From 87a10e1e506ba45e359d083b2a115adce5bd3b40 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 20 Aug 2024 17:59:33 +0800 Subject: [PATCH] feat(huixiangdou): add `chat_with_repo` pipeline (#362) * feat(service): add parallel pipeline * feat(service): gradio streaming chat * style(llm_client.py): remove useless --- .github/scripts/doc_link_checker.py | 1 + README.md | 15 +- README_zh.md | 19 +- config.ini | 2 +- docs/full_dev_en.md | 4 +- docs/full_dev_zh.md | 4 +- huixiangdou/__init__.py | 2 +- huixiangdou/frontend/wechat.py | 2 +- huixiangdou/gradio.py | 150 ++++++-- huixiangdou/main.py | 11 +- huixiangdou/primitive/file_operation.py | 2 - huixiangdou/rag.py | 9 +- huixiangdou/server.py | 97 ++++- huixiangdou/service/__init__.py | 3 +- huixiangdou/service/feature_store.py | 3 +- huixiangdou/service/helper.py | 1 - huixiangdou/service/llm_client.py | 24 +- huixiangdou/service/llm_server_hybrid.py | 2 + huixiangdou/service/parallel_pipeline.py | 355 ++++++++++++++++++ huixiangdou/service/retriever.py | 121 +++--- .../service/{worker.py => serial_pipeline.py} | 110 +----- huixiangdou/service/session.py | 54 +++ huixiangdou/service/web_search.py | 11 +- requirements.txt | 2 +- sft/reconstruct_filter_annotate.py | 1 - tests/test_bge_reranker.py | 1 - tests/test_build_milvus_and_filter.py | 2 - tests/test_m3.py | 1 - tests/test_query_gradio.py | 8 +- web/proxy/main.py | 4 +- web/proxy/web_worker.py | 10 +- 31 files changed, 773 insertions(+), 258 deletions(-) create mode 100644 huixiangdou/service/parallel_pipeline.py rename huixiangdou/service/{worker.py => serial_pipeline.py} (86%) create mode 100644 huixiangdou/service/session.py diff --git a/.github/scripts/doc_link_checker.py b/.github/scripts/doc_link_checker.py index 9434fa83..cba51520 100644 --- a/.github/scripts/doc_link_checker.py +++ b/.github/scripts/doc_link_checker.py @@ -58,6 +58,7 @@ def analyze_doc(home, path): ref = ref[ref.find('#'):] fullpath = os.path.join(home, ref) if not os.path.exists(fullpath): + raise ValueError(fullpath) problem_list.append(ref) else: continue diff --git a/README.md b/README.md index a5b98c68..4b71eac4 100644 --- a/README.md +++ b/README.md @@ -30,12 +30,14 @@ English | [简体中文](README_zh.md) -HuixiangDou is a **group chat** assistant based on LLM (Large Language Model). +HuixiangDou is a **professional knowledge assistant** based on LLM. Advantages: -1. Design a three-stage pipeline of preprocess, rejection and response to cope with group chat scenario, answer user questions without message flooding, see [2401.08772](https://arxiv.org/abs/2401.08772), [2405.02817](https://arxiv.org/abs/2405.02817), [Hybrid Retrieval](./docs/knowledge_graph_en.md) and [Precision Report](./evaluation/). -2. No training required, with CPU-only, 2G, 10G and 80G configuration +1. Design three-stage pipelines of preprocess, rejection and response + * `chat_in_group` copes with **group chat** scenario, answer user questions without message flooding, see [2401.08772](https://arxiv.org/abs/2401.08772), [2405.02817](https://arxiv.org/abs/2405.02817), [Hybrid Retrieval](./docs/knowledge_graph_en.md) and [Precision Report](./evaluation/) + * `chat_with_repo` for **real-time streaming** chat +2. No training required, with CPU-only, 2G, 10G, 20G and 80G configuration 3. 
Offers a complete suite of Web, Android, and pipeline source code, industrial-grade and commercially viable Check out the [scenes in which HuixiangDou are running](./huixiangdou-inside.md) and join [WeChat Group](resource/figures/wechat.jpg) to try AI assistant inside. @@ -46,6 +48,7 @@ If this helps you, please give it a star ⭐ Our Web version has been released to [OpenXLab](https://openxlab.org.cn/apps/detail/tpoisonooo/huixiangdou-web), where you can create knowledge base, update positive and negative examples, turn on web search, test chat, and integrate into Feishu/WeChat groups. See [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn) and [YouTube](https://www.youtube.com/watch?v=ylXrT-Tei-Y) ! +- \[2024/08\] `chat_with_repo` [pipeline](./huixiangdou/service/parallel_pipeline.py) 👍 - \[2024/07\] Image and text retrieval & Removal of `langchain` 👍 - \[2024/07\] [Hybrid Knowledge Graph and Dense Retrieval](./docs/knowledge_graph_en.md) improve 1.7% F1 score 🎯 - \[2024/06\] [Evaluation of chunksize, splitter, and text2vec model](./evaluation) 🎯 @@ -221,7 +224,9 @@ python3 -m huixiangdou.main --standalone python3 -m huixiangdou.gradio ``` -Or run a server to listen 23333: +https://github.com/user-attachments/assets/9e5dbb30-1dc1-42ad-a7d4-dc7380676554 + +Or run a server to listen 23333, default pipeline is `chat_with_repo`: ```bash python3 -m huixiangdou.server @@ -368,7 +373,7 @@ Contributors have provided [Android tools](./android) to interact with WeChat. T 3. How to access other local LLM / After access, the effect is not ideal? - Open [hybrid llm service](./huixiangdou/service/llm_server_hybrid.py), add a new LLM inference implementation. - - Refer to [test_intention_prompt and test data](./tests/test_intention_prompt.py), adjust prompt and threshold for the new model, and update them into [worker.py](./huixiangdou/service/worker.py). + - Refer to [test_intention_prompt and test data](./tests/test_intention_prompt.py), adjust prompt and threshold for the new model, and update them into [prompt.py](./huixiangdou/service/prompt.py). 4. What if the response is too slow/request always fails? diff --git a/README_zh.md b/README_zh.md index 861181fc..1b39f858 100644 --- a/README_zh.md +++ b/README_zh.md @@ -29,10 +29,12 @@ -茴香豆是一个基于 LLM 的**群聊**知识助手,优势: +茴香豆是一个基于 LLM 的专业知识助手,优势: -1. 设计预处理、拒答、响应三阶段 pipeline 应对群聊场景,解答问题同时不会消息泛滥。精髓见 [2401.08772](https://arxiv.org/abs/2401.08772),[2405.02817](https://arxiv.org/abs/2405.02817),[混合检索](./docs/knowledge_graph_zh.md)和[业务数据精度测试](./evaluation) -2. 无需训练适用各行业,提供 CPU-only、2G、10G、80G 规格配置 +1. 设计预处理、拒答、响应三阶段 pipeline: + * `chat_in_group` 群聊场景,解答问题时不会消息泛滥。见 [2401.08772](https://arxiv.org/abs/2401.08772),[2405.02817](https://arxiv.org/abs/2405.02817),[混合检索](./docs/knowledge_graph_zh.md)和[业务数据精度测试](./evaluation) + * `chat_with_repo` 实时聊天场景,响应更快 +2. 无需训练适用各行业,提供 CPU-only、2G、10G、20G、80G 规格配置 3. 
提供一整套前后端 web、android、算法源码,工业级开源可商用 查看[茴香豆已运行在哪些场景](./huixiangdou-inside.md);加入[微信群](resource/figures/wechat.jpg)直接体验群聊助手效果。 @@ -45,6 +47,7 @@ Web 版视频教程见 [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn) 和 [YouTube](https://www.youtube.com/watch?v=ylXrT-Tei-Y)。 +- \[2024/08\] `chat_with_repo` [pipeline](./huixiangdou/service/parallel_pipeline.py) - \[2024/07\] 图文检索 & 移除 `langchain` 👍 - \[2024/07\] [混合知识图谱和稠密检索,F1 提升 1.7%](./docs/knowledge_graph_zh.md) 🎯 - \[2024/06\] [评估 chunksize,splitter 和 text2vec 模型](./evaluation) 🎯 @@ -216,10 +219,14 @@ python3 -m huixiangdou.main --standalone 💡 也可以启动 `gradio` 搭建一个简易的 Web UI,默认绑定 7860 端口: ```bash -python3 -m huixiangdou.gradio +python3 -m huixiangdou.gradio +# 若已单独运行 `llm_server_hybrid.py`,可以 +# python3 -m huixiangdou.gradio --no-standalone ``` -或者启动服务端,监听 23333 端口: +https://github.com/user-attachments/assets/9e5dbb30-1dc1-42ad-a7d4-dc7380676554 + +或者启动服务端,监听 23333 端口。默认使用 `chat_with_repo` pipeline: ```bash python3 -m huixiangdou.server @@ -364,7 +371,7 @@ python3 tests/test_query_gradio.py 3. 如何接入其他 local LLM / 接入后效果不理想怎么办? - 打开 [hybrid llm service](./huixiangdou/service/llm_server_hybrid.py),增加新的 LLM 推理实现 - - 参照 [test_intention_prompt 和测试数据](./tests/test_intention_prompt.py),针对新模型调整 prompt 和阈值,更新到 [worker.py](./huixiangdou/service/worker.py) + - 参照 [test_intention_prompt 和测试数据](./tests/test_intention_prompt.py),针对新模型调整 prompt 和阈值,更新到 [prompt.py](./huixiangdou/service/prompt.py) 4. 响应太慢/网络请求总是失败怎么办? diff --git a/config.ini b/config.ini index f40da1c4..c5f4de14 100644 --- a/config.ini +++ b/config.ini @@ -25,7 +25,7 @@ engine = "serper" # For ddgs, see https://pypi.org/project/duckduckgo-search # For serper, check https://serper.dev/api-key to get a free API key serper_x_api_key = "YOUR-API-KEY-HERE" -domain_partial_order = ["openai.com", "pytorch.org", "readthedocs.io", "nvidia.com", "stackoverflow.com", "juejin.cn", "zhuanlan.zhihu.com", "www.cnblogs.com"] +domain_partial_order = ["arxiv.org", "openai.com", "pytorch.org", "readthedocs.io", "nvidia.com", "stackoverflow.com", "juejin.cn", "zhuanlan.zhihu.com", "www.cnblogs.com"] save_dir = "logs/web_search_result" [llm] diff --git a/docs/full_dev_en.md b/docs/full_dev_en.md index a1e13d0c..ed6139cd 100644 --- a/docs/full_dev_en.md +++ b/docs/full_dev_en.md @@ -74,6 +74,6 @@ The basic version may not perform well. You can enable these features to enhance It is often unavoidable to adjust parameters with respect to business scenarios. - - Refer to [data.json](./tests/data.json) to add real data, run [test_intention_prompt.py](./tests/test_intention_prompt.py) to get suitable prompts and thresholds, and update them into [worker](./huixiangdou/service/worker.py). - - Adjust the [number of search results](./huixiangdou/service/worker.py) based on the maximum length supported by the model. + - Refer to [data.json](../tests/data.json) to add real data, run [test_intention_prompt.py](../tests/test_intention_prompt.py) to get suitable prompts and thresholds, and update them into [prompt.py](../huixiangdou/service/prompt.py). + - Adjust the [number of search results](../huixiangdou/service/serial_pipeline.py) based on the maximum length supported by the model. - Update `web_search.domain_partial_order` in `config.ini` according to your scenarios. 
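A minimal sketch of the tuning knob described above, using the retriever API introduced in this patch. The work directory, config path, question, and the 16000-character budget are placeholder assumptions, not values from the repository: recall is now done by `text2vec_retrieve`, and `rerank_fuse` concatenates reranked chunks until `context_max_length` is reached, so that is the parameter to align with the model's maximum input length.

```python
# Sketch only: assumes a feature store has already been built under `workdir`
# and that `config.ini` is a valid HuixiangDou configuration.
from huixiangdou.primitive import Query
from huixiangdou.service import CacheRetriever

retriever = CacheRetriever(config_path='config.ini').get()
query = Query(text='how to install mmpose ?')

# Recall candidate chunks with the text2vec model (and knowledge graph, if enabled).
chunks = retriever.text2vec_retrieve(query)

# Rerank the candidates and join them until the context budget is exhausted;
# choose `context_max_length` according to the LLM's maximum input length.
chunk_str, context, references = retriever.rerank_fuse(
    query=query, chunks=chunks, context_max_length=16000)
print(references)
```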
diff --git a/docs/full_dev_zh.md b/docs/full_dev_zh.md index 76fc9363..dc53b20c 100644 --- a/docs/full_dev_zh.md +++ b/docs/full_dev_zh.md @@ -73,6 +73,6 @@ 针对业务场景调参往往不可避免。 - - 参照 [data.json](./tests/data.json) 增加真实数据,运行 [test_intention_prompt.py](./tests/test_intention_prompt.py) 得到合适的 prompt 和阈值,更新进 [worker](./huixiangdou/service/worker.py) - - 根据模型支持的最大长度,调整[搜索结果个数](./huixiangdou/service/worker.py) + - 参照 [data.json](../tests/data.json) 增加真实数据,运行 [test_intention_prompt.py](../tests/test_intention_prompt.py) 得到合适的 prompt 和阈值,更新进 [prompt.py](../huixiangdou/service/prompt.py) + - 根据模型支持的最大长度,调整[搜索结果个数](../huixiangdou/service/serial_pipeline.py) - 按照场景偏好,修改 config.ini 中的 `web_search.domain_partial_order`,即搜索结果偏序 diff --git a/huixiangdou/__init__.py b/huixiangdou/__init__.py index d037b85f..e397d2cb 100644 --- a/huixiangdou/__init__.py +++ b/huixiangdou/__init__.py @@ -6,7 +6,7 @@ from .service import FeatureStore # noqa E401 from .service import HybridLLMServer # noqa E401 from .service import WebSearch # noqa E401 -from .service import Worker # noqa E401 +from .service import SerialPipeline, ParallelPipeline # no E401 from .service import build_reply_text # noqa E401 from .service import llm_serve # noqa E401 from .version import __version__ diff --git a/huixiangdou/frontend/wechat.py b/huixiangdou/frontend/wechat.py index 316ce3ab..f2fffc92 100644 --- a/huixiangdou/frontend/wechat.py +++ b/huixiangdou/frontend/wechat.py @@ -845,7 +845,7 @@ def loop(self, worker): def parse_args(): """Parse args.""" - parser = argparse.ArgumentParser(description='Worker.') + parser = argparse.ArgumentParser(description='wechat server.') parser.add_argument('--work_dir', type=str, default='workdir', diff --git a/huixiangdou/gradio.py b/huixiangdou/gradio.py index d0d2b5f4..ebdc2b14 100644 --- a/huixiangdou/gradio.py +++ b/huixiangdou/gradio.py @@ -4,19 +4,19 @@ import time import pdb from multiprocessing import Process, Value - +import asyncio import cv2 import gradio as gr import pytoml from loguru import logger - +from typing import List from huixiangdou.primitive import Query -from huixiangdou.service import ErrorCode, Worker, llm_serve, start_llm_server - +from huixiangdou.service import ErrorCode, SerialPipeline, ParallelPipeline, llm_serve, start_llm_server +import json def parse_args(): """Parse args.""" - parser = argparse.ArgumentParser(description='Worker.') + parser = argparse.ArgumentParser(description='SerialPipeline.') parser.add_argument('--work_dir', type=str, default='workdir', @@ -25,7 +25,7 @@ def parse_args(): '--config_path', default='config.ini', type=str, - help='Worker configuration path. Default value is config.ini') + help='SerialPipeline configuration path. 
Default value is config.ini') parser.add_argument('--standalone', action='store_true', default=True, @@ -37,7 +37,60 @@ def parse_args(): args = parser.parse_args() return args -def predict(text, image): +language='en' +enable_web_search=False +pipeline='chat_with_repo' +main_args = None +paralle_assistant = None +serial_assistant = None + +def on_language_changed(value:str): + global language + print(value) + language = value + +def on_pipeline_changed(value:str): + global pipeline + print(value) + pipeline = value + +def on_web_search_changed(value: str): + global enable_web_search + print(value) + if 'no' in value: + enable_web_search = False + else: + enable_web_search = True + + +def format_refs(refs: List[str]): + refs_filter = list(set(refs)) + if len(refs) < 1: + return '' + text = '' + if language == 'zh': + text += '参考资料:\r\n' + else: + text += '**References:**\r\n' + + for file_or_url in refs_filter: + text += '* {}\r\n'.format(file_or_url) + text += '\r\n' + return text + + +async def predict(text:str, image:str): + global language + global enable_web_search + global pipeline + global main_args + global serial_assistant + global paralle_assistant + + with open('query.txt', 'a') as f: + f.write(json.dumps({'data': text})) + f.write('\n') + if image is not None: filename = 'image.png' image_path = os.path.join(args.work_dir, filename) @@ -45,42 +98,81 @@ def predict(text, image): else: image_path = None - assistant = Worker(work_dir=args.work_dir, config_path=args.config_path) query = Query(text, image_path) + if 'chat_in_group' in pipeline: + if serial_assistant is None: + serial_assistant = SerialPipeline(work_dir=main_args.work_dir, config_path=main_args.config_path) + args = {'query':query, 'history': [], 'groupname':''} + pipeline = {'status': {}} + debug = dict() + stream_chat_content = '' + for sess in serial_assistant.generate(**args): + if len(sess.delta) > 0: + # start chat, display + stream_chat_content += sess.delta + yield stream_chat_content + else: + status = { + "state":str(sess.code), + "response": sess.response, + "refs": sess.references + } + pipeline['status'] = status + pipeline['debug'] = sess.debug + + json_str = json.dumps(pipeline, indent=2, ensure_ascii=False) + yield json_str - pipeline = {'step': []} - debug = dict() - for sess in assistant.generate(query=query, history=[], groupname=''): - status = { - "state":str(sess.code), - "response": sess.response, - "refs": sess.references - } + else: + if paralle_assistant is None: + paralle_assistant = ParallelPipeline(work_dir=main_args.work_dir, config_path=main_args.config_path) + args = {'query':query, 'history':[], 'language':language} + args['enable_web_search'] = enable_web_search - print(status) - pipeline['step'].append(status) - pipeline['debug'] = sess.debug + sentence = '' + async for sess in paralle_assistant.generate(**args): + if sentence == '' and len(sess.references) > 0: + sentence = format_refs(sess.references) - json_str = json.dumps(pipeline, indent=2, ensure_ascii=False) - yield json_str + if len(sess.delta) > 0: + sentence += sess.delta + yield sentence + + yield sentence if __name__ == '__main__': - args = parse_args() + main_args = parse_args() # start service - if args.standalone is True: + if main_args.standalone is True: # hybrid llm serve - start_llm_server(config_path=args.config_path) + start_llm_server(config_path=main_args.config_path) - with gr.Blocks() as demo: + with gr.Blocks(theme=gr.themes.Soft(), title='HuixiangDou AI assistant', analytics_enabled=True) as demo: + with 
gr.Row(): + gr.Markdown(""" + #### [HuixiangDou](https://github.com/internlm/huixiangdou) AI assistant + """, label='Reply', header_links=True, line_breaks=True,) with gr.Row(): - input_question = gr.TextArea(label='Input the question.') - input_image = gr.Image(label='Upload Image.') + with gr.Column(): + ui_pipeline = gr.Radio(["chat_with_repo", "chat_in_group"], label="Pipeline type", info="Group-chat is slow but accurate and safe, default value is `chat_with_repo`") + ui_pipeline.change(fn=on_pipeline_changed, inputs=ui_pipeline, outputs=[]) + with gr.Column(): + ui_language = gr.Radio(["en", "zh"], label="Language", info="Use `en` by default ") + ui_language.change(fn=on_language_changed, inputs=ui_language, outputs=[]) + with gr.Column(): + ui_web_search = gr.Radio(["no", "yes"], label="Enable web search", info="Disable by default ") + ui_web_search.change(on_web_search_changed, inputs=ui_web_search, outputs=[]) + + with gr.Row(): + input_question = gr.TextArea(label='Input your question', placeholder='how to install mmpose ?', show_copy_button=True, lines=9) + input_image = gr.Image(label='[Optional] Image-text retrieval needs `config-multimodal.ini`') with gr.Row(): run_button = gr.Button() with gr.Row(): - result = gr.TextArea(label='HuixiangDou pipline status', show_copy_button=True) + result = gr.Markdown('>Text reply or inner status callback here, depends on `pipeline type`', label='Reply', show_label=True, header_links=True, line_breaks=True, show_copy_button=True) + # result = gr.TextArea(label='Reply', show_copy_button=True, placeholder='Text Reply or inner status callback, depends on `pipeline type`') + run_button.click(predict, [input_question, input_image], [result]) - demo.queue() demo.launch(share=False, server_name='0.0.0.0', debug=True) diff --git a/huixiangdou/main.py b/huixiangdou/main.py index 18ad4cbc..f7dddf99 100755 --- a/huixiangdou/main.py +++ b/huixiangdou/main.py @@ -11,12 +11,12 @@ from loguru import logger from termcolor import colored -from .service import ErrorCode, Worker, build_reply_text, start_llm_server +from .service import ErrorCode, SerialPipeline, build_reply_text, start_llm_server def parse_args(): """Parse args.""" - parser = argparse.ArgumentParser(description='Worker.') + parser = argparse.ArgumentParser(description='SerialPipeline.') parser.add_argument('--work_dir', type=str, default='workdir', @@ -25,7 +25,7 @@ def parse_args(): '--config_path', default='config.ini', type=str, - help='Worker configuration path. Default value is config.ini') + help='SerialPipeline configuration path. Default value is config.ini') parser.add_argument('--standalone', action='store_true', default=False, @@ -191,7 +191,7 @@ def run(): with open(args.config_path, encoding='utf8') as f: fe_config = pytoml.load(f)['frontend'] logger.info('Config loaded.') - assistant = Worker(work_dir=args.work_dir, config_path=args.config_path) + assistant = SerialPipeline(work_dir=args.work_dir, config_path=args.config_path) fe_type = fe_config['type'] if fe_type == 'none': @@ -209,8 +209,5 @@ def run(): f'unsupported fe_config.type {fe_type}, please read `config.ini` description.' 
# noqa E501 ) - # server_process.join() - - if __name__ == '__main__': run() diff --git a/huixiangdou/primitive/file_operation.py b/huixiangdou/primitive/file_operation.py index 5279bcd0..3123e10d 100644 --- a/huixiangdou/primitive/file_operation.py +++ b/huixiangdou/primitive/file_operation.py @@ -252,12 +252,10 @@ def get_pdf_files(directory): text, error = opr.read(pdf_path) print('processing {}'.format(pdf_path)) if error is not None: - # pdb.set_trace() print('') else: if text is not None: print(len(text)) else: - # pdb.set_trace() print('') diff --git a/huixiangdou/rag.py b/huixiangdou/rag.py index b5aefe9b..dc88f421 100644 --- a/huixiangdou/rag.py +++ b/huixiangdou/rag.py @@ -11,7 +11,7 @@ import requests from loguru import logger -from .service import ErrorCode, Worker, llm_serve +from .service import ErrorCode, SerialPipeline, llm_serve class Task: @@ -41,7 +41,7 @@ def to_json_str(self): def parse_args(): """Parse args.""" - parser = argparse.ArgumentParser(description='Worker.') + parser = argparse.ArgumentParser(description='SerialPipeline.') parser.add_argument('--work_dir', type=str, default='workdir', @@ -50,7 +50,7 @@ def parse_args(): '--config_path', default='config-alignment.ini', type=str, - help='Worker configuration path. Default value is config.ini') + help='SerialPipeline configuration path. Default value is config.ini') parser.add_argument( '--input', default='resource/rag_example_input.json', @@ -75,10 +75,9 @@ def parse_args(): def rag(process_id: int, task: list, output_dir: str): """Extract structured output with RAG.""" - assistant = Worker(work_dir=args.work_dir, config_path=args.config_path) + assistant = SerialPipeline(work_dir=args.work_dir, config_path=args.config_path) # assistant.TOPIC_TEMPLATE = '告诉我这句话的关键字和主题,直接说主题不要解释:“{}”' - output_path = os.path.join(output_dir, 'output{}.json'.format(process_id)) for item in task: query = item.query diff --git a/huixiangdou/server.py b/huixiangdou/server.py index 7d7e164d..b48cefd7 100644 --- a/huixiangdou/server.py +++ b/huixiangdou/server.py @@ -8,7 +8,7 @@ from loguru import logger from termcolor import colored -from .service import ErrorCode, Worker, start_llm_server +from .service import ErrorCode, SerialPipeline, ParallelPipeline, start_llm_server from .primitive import Query import asyncio from fastapi import FastAPI, APIRouter @@ -17,14 +17,26 @@ from pydantic import BaseModel import uvicorn import json +from typing import List -assistant = Worker(work_dir='workdir', config_path='config.ini') +assistant = None app = FastAPI(docs_url='/') class Talk(BaseModel): text: str image: str = '' +def format_refs(refs: List[str]): + refs_filter = list(set(refs)) + if len(refs) < 1: + return '' + + text = '**References:**\r\n' + for file_or_url in refs_filter: + text += '* {}\r\n'.format(file_or_url) + text += '\r\n' + return text + @app.post("/huixiangdou_inference") async def huixiangdou_inference(talk: Talk): global assistant @@ -32,17 +44,27 @@ async def huixiangdou_inference(talk: Talk): pipeline = {'step': []} debug = dict() - for sess in assistant.generate(query=query, history=[], groupname=''): - status = { - "state":str(sess.code), - "response": sess.response, - "refs": sess.references - } + if type(assistant) is SerialPipeline: + for sess in assistant.generate(query=query): + status = { + "state":str(sess.code), + "response": sess.response, + "refs": sess.references + } - pipeline['step'].append(status) - pipeline['debug'] = sess.debug + pipeline['step'].append(status) + pipeline['debug'] = sess.debug + 
return pipeline + + else: + sentence = '' + async for sess in assistant.generate(query=query, enable_web_search=False): + if sentence == '' and len(sess.references) > 0: + sentence = format_refs(sess.references) - return pipeline + if len(sess.delta) > 0: + sentence += sess.delta + return sentence @app.post("/huixiangdou_stream") @@ -54,7 +76,7 @@ async def huixiangdou_stream(talk: Talk): debug = dict() def event_stream(): - for sess in assistant.generate(query=query, history=[], groupname=''): + for sess in assistant.generate(query=query): status = { "state":str(sess.code), "response": sess.response, @@ -64,7 +86,56 @@ def event_stream(): pipeline['step'].append(status) pipeline['debug'] = sess.debug yield json.dumps(pipeline) - return StreamingResponse(event_stream(), media_type="text/event-stream") + + async def event_stream_async(): + sentence = '' + async for sess in assistant.generate(query=query, enable_web_search=False): + if sentence == '' and len(sess.references) > 0: + sentence = format_refs(sess.references) + + if len(sess.delta) > 0: + sentence += sess.delta + yield sentence + + if type(assistant) is SerialPipeline: + return StreamingResponse(event_stream(), media_type="text/event-stream") + else: + return StreamingResponse(event_stream_async(), media_type="text/event-stream") + +def parse_args(): + """Parse args.""" + parser = argparse.ArgumentParser(description='SerialPipeline.') + parser.add_argument('--work_dir', + type=str, + default='workdir', + help='Working directory.') + parser.add_argument( + '--config_path', + default='config.ini', + type=str, + help='Configuration path. Default value is config.ini') + parser.add_argument('--pipeline', type=str, choices=['chat_with_repo', 'chat_in_group'], default='chat_with_repo', + help='Select pipeline type for difference scenario, default value is `chat_with_repo`') + parser.add_argument('--standalone', + action='store_true', + default=True, + help='Auto deploy required Hybrid LLM Service.') + parser.add_argument('--no-standalone', + action='store_false', + dest='standalone', # 指定与上面参数相同的目标 + help='Do not auto deploy required Hybrid LLM Service.') + args = parser.parse_args() + return args if __name__ == '__main__': + args = parse_args() + # start service + if args.standalone is True: + # hybrid llm serve + start_llm_server(config_path=main_args.config_path) + # setup chat service + if 'chat_with_repo' in args.pipeline: + assistant = ParallelPipeline(work_dir=args.work_dir, config_path=args.config_path) + elif 'chat_in_group' in args.pipeline: + assistant = SerialPipeline(work_dir=args.work_dir, config_path=args.config_path) uvicorn.run(app, host='0.0.0.0', port=23333, log_level='info') diff --git a/huixiangdou/service/__init__.py b/huixiangdou/service/__init__.py index fc30a7c0..1ee853c5 100644 --- a/huixiangdou/service/__init__.py +++ b/huixiangdou/service/__init__.py @@ -12,4 +12,5 @@ start_llm_server) from .retriever import CacheRetriever, Retriever # noqa E401 from .web_search import WebSearch # noqa E401 -from .worker import Worker # noqa E401 +from .serial_pipeline import SerialPipeline +from .parallel_pipeline import ParallelPipeline diff --git a/huixiangdou/service/feature_store.py b/huixiangdou/service/feature_store.py index 7c4d99fe..fe31c786 100644 --- a/huixiangdou/service/feature_store.py +++ b/huixiangdou/service/feature_store.py @@ -314,7 +314,7 @@ def test_reject(retriever: Retriever, sample: str = None): real_questions = json.load(f) for example in real_questions: - relative = retriever.is_relative(example) + 
relative, score = retriever.is_relative(example) if relative: logger.warning(f'process query: {example}') @@ -330,7 +330,6 @@ def test_reject(retriever: Retriever, sample: str = None): with open('workdir/negative.txt', 'a+') as f: f.write(example) f.write('\n') - empty_cache() diff --git a/huixiangdou/service/helper.py b/huixiangdou/service/helper.py index f942601a..968dec0f 100644 --- a/huixiangdou/service/helper.py +++ b/huixiangdou/service/helper.py @@ -24,7 +24,6 @@ class TaskCode(Enum): CHAT = 'chat' CHAT_RESPONSE = 'chat_response' - class ErrorCode(Enum): """Define an enumerated type for error codes, each has a numeric value and a description. diff --git a/huixiangdou/service/llm_client.py b/huixiangdou/service/llm_client.py index d7cde265..800c4d1a 100644 --- a/huixiangdou/service/llm_client.py +++ b/huixiangdou/service/llm_client.py @@ -3,6 +3,7 @@ import argparse import json import aiohttp +import re import pytoml import requests @@ -132,7 +133,7 @@ def generate_response(self, prompt, history=[], backend='local'): ) return '' - async def generate_response_async(self, prompt, history=[], backend='local'): + async def chat_stream(self, prompt, history=[], backend='local'): """Generate a stream response from the chat service. Args: @@ -143,7 +144,8 @@ async def generate_response_async(self, prompt, history=[], backend='local'): Returns: str: Generated response from the chat service. """ - url = self.llm_config['client_stream_url'] + sync_url = self.llm_config['client_url'] + stream_url = sync_url.replace('/inference', '/stream') real_backend, max_length = self.auto_fix(backend=backend) if len(prompt) > max_length: @@ -152,6 +154,7 @@ async def generate_response_async(self, prompt, history=[], backend='local'): ) prompt = prompt[0:max_length] + sse_pattern = re.compile(r'data: (.*?)(?=\r\n\r\n)', re.DOTALL) try: headers = {'Content-Type': 'application/json'} data_history = [] @@ -164,17 +167,16 @@ async def generate_response_async(self, prompt, history=[], backend='local'): } async with aiohttp.ClientSession() as session: - async with session.post(url, headers=headers, data=json.dumps(data)) as response: + async with session.post(stream_url, headers=headers, data=json.dumps(data)) as response: # 确保请求成功 if response.status == 200: async for chunk in response.content.iter_any(): - chunk_str = chunk.decode().strip() - mines = chunk_str.split('\r\n\r\n') - - for mime_str in mines: - pos = mime_str.find('data: ') + len('data: ') - content = mime_str[pos:] - yield content + chunk_data = chunk.decode() + messages = sse_pattern.findall(chunk_data) + for message in messages: + if '\r\ndata: ' in message: + message = message.replace('\r\ndata: ', '\r\n') + yield message else: raise Exception(response.status) @@ -211,7 +213,7 @@ def parse_args(): # backend='remote')) async def wrap_as_coroutine(): - async for text in client.generate_response_async('请问 ncnn 全称是啥'): + async for text in client.chat_stream('请问 ncnn 全称是啥'): print(text, end='', flush=True) import asyncio diff --git a/huixiangdou/service/llm_server_hybrid.py b/huixiangdou/service/llm_server_hybrid.py index e38d782f..4543e0e5 100644 --- a/huixiangdou/service/llm_server_hybrid.py +++ b/huixiangdou/service/llm_server_hybrid.py @@ -506,6 +506,7 @@ def llm_serve(config_path: str, server_ready: Value): async def inference(talk: Talk): """Call local llm inference.""" + logger.info(talk) prompt = talk.prompt history = talk.history @@ -521,6 +522,7 @@ async def inference(talk: Talk): async def stream(talk: Talk): """Call local llm inference.""" 
+ logger.info(talk) prompt = talk.prompt history = talk.history diff --git a/huixiangdou/service/parallel_pipeline.py b/huixiangdou/service/parallel_pipeline.py new file mode 100644 index 00000000..321ec8ff --- /dev/null +++ b/huixiangdou/service/parallel_pipeline.py @@ -0,0 +1,355 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Pipeline.""" +import argparse +import asyncio +import datetime +import json +import os +import re +import time +import pdb +import copy +from abc import ABC, abstractmethod +from typing import List, Tuple, Union, Generator + +import pytoml +from loguru import logger + +from huixiangdou.primitive import Query, Chunk + +from .helper import ErrorCode, is_truth +from .llm_client import ChatClient +from .retriever import CacheRetriever, Retriever +from .sg_search import SourceGraphProxy +from .session import Session +from .web_search import WebSearch +from .prompt import (SCORING_QUESTION_TEMPLTE_CN, CR_NEED_CN, CR_CN, TOPIC_TEMPLATE_CN, SCORING_RELAVANCE_TEMPLATE_CN, GENERATE_TEMPLATE_CN, KEYWORDS_TEMPLATE_CN, PERPLESITY_TEMPLATE_CN, SECURITY_TEMAPLTE_CN) +from .prompt import (SCORING_QUESTION_TEMPLTE_EN, CR_NEED_EN, CR_EN, TOPIC_TEMPLATE_EN, SCORING_RELAVANCE_TEMPLATE_EN, GENERATE_TEMPLATE_EN, KEYWORDS_TEMPLATE_EN, PERPLESITY_TEMPLATE_EN, SECURITY_TEMAPLTE_EN) + +class PreprocNode: + """PreprocNode is for coreference resolution and scoring based on group + chats. + + See https://arxiv.org/abs/2405.02817 + """ + + def __init__(self, config: dict, llm: ChatClient, language: str): + self.llm = llm + self.enable_cr = config['worker']['enable_cr'] + + if language == 'zh': + self.SCORING_QUESTION_TEMPLTE = SCORING_QUESTION_TEMPLTE_CN + self.CR = CR_CN + else: + self.SCORING_QUESTION_TEMPLTE = SCORING_QUESTION_TEMPLTE_EN + self.CR = CR_EN + + def process(self, sess: Session) -> Generator[Session, None, None]: + # check input + if sess.query.text is None or len(sess.query.text) < 6: + sess.code = ErrorCode.QUESTION_TOO_SHORT + yield sess + return + + prompt = self.SCORING_QUESTION_TEMPLTE.format(sess.query.text) + truth, logs = is_truth(llm=self.llm, + prompt=prompt, + throttle=6, + default=3) + sess.debug['PreprocNode_is_question'] = logs + if not truth: + sess.code = ErrorCode.NOT_A_QUESTION + yield sess + return + + if not self.enable_cr: + yield sess + return + + if len(sess.groupchats) < 1: + logger.debug('history conversation empty, skip CR') + yield sess + return + + talks = [] + + # rewrite user_id to ABCD.. 
+ name_map = dict() + name_int = ord('A') + for msg in sess.groupchats: + sender = msg.sender + if sender not in name_map: + name_map[sender] = chr(name_int) + name_int += 1 + talks.append({'sender': name_map[sender], 'content': msg.query}) + + talk_str = json.dumps(talks, ensure_ascii=False) + prompt = self.CR.format(talk_str, sess.query.text) + self.cr = self.llm.generate_response(prompt=prompt, backend='remote') + if self.cr.startswith('“') and self.cr.endswith('”'): + self.cr = self.cr[1:len(self.cr) - 1] + if self.cr.startswith('"') and self.cr.endswith('"'): + self.cr = self.cr[1:len(self.cr) - 1] + sess.debug['cr'] = self.cr + + # rewrite query + queries = [sess.query.text, self.cr] + self.query = '\n'.join(queries) + logger.debug('merge query and cr, query: {} cr: {}'.format( + self.query, self.cr)) + + +class Text2vecRetrieval: + """Text2vecNode is for retrieve from knowledge base.""" + + def __init__(self, config: dict, llm: ChatClient, retriever: Retriever, + language: str): + self.llm = llm + self.retriever = retriever + llm_config = config['llm'] + self.context_max_length = llm_config['server'][ + 'local_llm_max_text_length'] + if llm_config['enable_remote']: + self.context_max_length = llm_config['server'][ + 'remote_llm_max_text_length'] + if language == 'zh': + self.TOPIC_TEMPLATE = TOPIC_TEMPLATE_CN + self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_CN + self.GENERATE_TEMPLATE = GENERATE_TEMPLATE_CN + else: + self.TOPIC_TEMPLATE = TOPIC_TEMPLATE_EN + self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_EN + self.GENERATE_TEMPLATE = GENERATE_TEMPLATE_EN + self.max_length = self.context_max_length - 2 * len( + self.GENERATE_TEMPLATE) + + async def process_coroutine(self, sess: Session) -> Session: + """Try get reply with text2vec & rerank model.""" + + # retrieve from knowledge base + sess.parallel_chunks = await asyncio.to_thread(self.retriever.text2vec_retrieve, sess.query.text) + # sess.parallel_chunks = self.retriever.text2vec_retrieve(query=sess.query.text) + return sess + +class WebSearchRetrieval: + """WebSearchNode is for web search, use `ddgs` or `serper`""" + + def __init__(self, config: dict, config_path: str, llm: ChatClient, + language: str): + self.llm = llm + self.config_path = config_path + self.enable = config['worker']['enable_web_search'] + llm_config = config['llm'] + self.context_max_length = llm_config['server'][ + 'local_llm_max_text_length'] + self.language = language + if llm_config['enable_remote']: + self.context_max_length = llm_config['server'][ + 'remote_llm_max_text_length'] + if language == 'zh': + self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_CN + self.KEYWORDS_TEMPLATE = KEYWORDS_TEMPLATE_CN + else: + self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_EN + self.KEYWORDS_TEMPLATE = KEYWORDS_TEMPLATE_EN + + async def process(self, sess: Session) -> Generator[Session, None, None]: + """Try web search.""" + + if not self.enable: + logger.debug('disable web_search') + yield sess + return + + engine = WebSearch(config_path=self.config_path, language=self.language) + + prompt = self.KEYWORDS_TEMPLATE.format(sess.groupname, sess.query.text) + search_keywords = self.llm.generate_response(prompt) + search_keywords = search_keywords.replace('"', '') + sess.debug['WebSearchNode_keywords'] = prompt + + articles, error = await asyncio.to_thread(engine.get, search_keywords, 4) + + if error is not None: + sess.code = ErrorCode.WEB_SEARCH_FAIL + sess.parallel_chunks = [] + yield sess + return + + if len(articles) 
< 1: + sess.code = ErrorCode.NO_SEARCH_RESULT + sess.parallel_chunks = [] + yield sess + return + + for article_id, article in enumerate(articles): + article.cut(0, self.context_max_length) + c = Chunk(content_or_path=article.content, metadata={'source': article.source}) + sess.parallel_chunks.append(c) + yield sess + + async def process_coroutine(self, sess: Session) -> Session: + results = [] + async for value in self.process(sess=sess): + results.append(value) + return results[-1] + + +class ReduceGenerate: + def __init__(self, config: dict, llm: ChatClient, retriever: CacheRetriever, language: str): + self.llm = llm + self.retriever = retriever + if language == 'zh': + self.GENERATE_TEMPLATE = GENERATE_TEMPLATE_CN + else: + self.GENERATE_TEMPLATE = GENERATE_TEMPLATE_EN + llm_config = config['llm'] + self.context_max_length = llm_config['server']['local_llm_max_text_length'] + if llm_config['enable_remote']: + self.context_max_length = llm_config['server']['remote_llm_max_text_length'] + + async def process(self, sess: Session) -> Generator[Session, None, None]: + question = sess.query.text + history = sess.history + + if len(sess.parallel_chunks) < 1: + # direct chat + async for part in self.llm.chat_stream(prompt=question, history=history): + sess.delta = part + yield sess + else: + _, context_str, references = self.retriever.rerank_fuse(query=sess.query, chunks=sess.parallel_chunks, context_max_length=self.context_max_length) + sess.references = references + prompt = self.GENERATE_TEMPLATE.format(context_str, sess.query) + async for part in self.llm.chat_stream(prompt=prompt, history=history): + sess.delta = part + yield sess + + +class ParallelPipeline: + """The ParallelPipeline class orchestrates the logic of handling user queries, + generating responses and managing several aspects of a chat assistant. It + enables feature storage, language model client setup, time scheduling and + much more. + + Attributes: + llm: A ChatClient instance that communicates with the language model. + fs: An instance of FeatureStore for loading and querying features. + config_path: A string indicating the path of the configuration file. + config: A dictionary holding the configuration settings. + context_max_length: An integer representing the maximum length of the context used by the language model. # noqa E501 + + Several template strings for various prompts are also defined. + """ + + def __init__(self, work_dir: str, config_path: str): + """Constructs all the necessary attributes for the worker object. + + Args: + work_dir (str): The working directory where feature files are located. + config_path (str): The location of the configuration file. + """ + self.llm = ChatClient(config_path=config_path) + self.retriever = CacheRetriever(config_path=config_path).get() + + self.config_path = config_path + self.config = None + with open(config_path, encoding='utf8') as f: + self.config = pytoml.load(f) + if self.config is None: + raise Exception('worker config can not be None') + + + async def generate(self, + query: Union[Query, str], + history: List[Tuple[str]]=[], + language: str='zh', + enable_web_search: bool=True): + """Processes user queries and generates appropriate responses. It + involves several steps including checking for valid questions, + extracting topics, querying the feature store, searching the web, and + generating responses from the language model. + + Args: + query (Union[Query,str]): User's multimodal query. + history (str): Chat history. 
+ groupname (str): The group name in which user asked the query. + groupchats (List[str]): The history conversation in group before user query. + + Returns: + Session: Sync generator, this function would yield session which contains: + ErrorCode: An error code indicating the status of response generation. # noqa E501 + str: Generated response to the user query. + references: List for referenced filename or web url + """ + # format input + if type(query) is str: + query = Query(text=query) + + # build input session + sess = Session(query=query, + history=history, + log_path=self.config['worker']['save_path']) + + # build pipeline + preproc = PreprocNode(self.config, self.llm, language) + text2vec = Text2vecRetrieval(self.config, self.llm, self.retriever, language) + websearch = WebSearchRetrieval(self.config, self.config_path, self.llm, language) + reduce = ReduceGenerate(self.config, self.llm, self.retriever, language) + pipeline = [preproc, [text2vec, websearch], reduce] + + direct_chat_states = [ + ErrorCode.QUESTION_TOO_SHORT, ErrorCode.NOT_A_QUESTION, + ErrorCode.NO_TOPIC, ErrorCode.UNRELATED + ] + + # if not a good question, return + for sess in preproc.process(sess): + if sess.code in direct_chat_states: + async for resp in reduce.process(sess): + yield resp + return + + # parallel run text2vec and websearch + + tasks = [text2vec.process_coroutine(copy.deepcopy(sess))] + if enable_web_search: + tasks.append(websearch.process_coroutine(copy.deepcopy(sess))) + + task_results = await asyncio.gather(*tasks, return_exceptions=True) + for result in task_results: + if type(result) is Session: + sess.parallel_chunks += result.parallel_chunks + continue + logger.error(result) + + async for sess in reduce.process(sess): + yield sess + return + + +def parse_args(): + """Parses command-line arguments.""" + parser = argparse.ArgumentParser(description='SerialPipeline.') + parser.add_argument('work_dir', type=str, help='Working directory.') + parser.add_argument( + '--config_path', + default='config.ini', + help='SerialPipeline configuration path. 
Default value is config.ini') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + bot = ParallelPipeline(work_dir=args.work_dir, config_path=args.config_path) + loop = asyncio.get_event_loop() + queries = ['茴香豆是什么?', 'HuixiangDou 是什么?'] + + for q in queries: + async def wrap_async_as_coroutine(): + async for sess in bot.generate(query=q, history=[], enable_web_search=False): + print(sess.delta, end='', flush=True) + pass + print('\n') + print(sess.references) + loop.run_until_complete(wrap_async_as_coroutine()) diff --git a/huixiangdou/service/retriever.py b/huixiangdou/service/retriever.py index 9909c124..935495d8 100644 --- a/huixiangdou/service/retriever.py +++ b/huixiangdou/service/retriever.py @@ -4,14 +4,14 @@ import os import pdb import time -from typing import Any, Union, Tuple +from typing import Any, Union, Tuple, List import numpy as np import pytoml from loguru import logger from sklearn.metrics import precision_recall_curve -from huixiangdou.primitive import Embedder, Faiss, LLMReranker, Query +from huixiangdou.primitive import Embedder, Faiss, LLMReranker, Query, Chunk from ..primitive import FileOperation from .helper import QueryTracker @@ -83,37 +83,18 @@ def update_throttle(self, f'The optimal threshold is: {optimal_threshold}, saved it to {config_path}' # noqa E501 ) - def query(self, - query: Union[Query, str], - context_max_length: int = 40000, - tracker: QueryTracker = None): - """Processes a query and returns the best match from the vector store - database. If the question is rejected, returns None. - + def text2vec_retrieve(self, query: Union[Query, str]): + """Retrieve chunks by text2vec model or knowledge graph. + Args: query (Query): The multimodal question asked by the user. - context_max_length (int): Max contenxt length for LLM. - tracker (QueryTracker): Log tracker. - + Returns: - str: Matched chunks, or empty string - str: Matched context from origin file content - List[str]: References + List[Chunk]: ref chunks. """ if type(query) is str: query = Query(text=query) - if query.text is None or len(query.text) < 1 or self.faiss is None: - return None, None, [] - - if len(query.text) > 512: - logger.warning('input too long, truncate to 512') - query.text = query.text[0:512] - - chunks = [] - context = '' - references = [] - graph_delta = 0.0 if self.kg.is_available(): try: @@ -125,45 +106,56 @@ def query(self, threshold = self.reject_throttle - graph_delta pairs = self.faiss.similarity_search_with_query(self.embedder, - query=query) - # logger.debug('retriever.docs {}'.format(docs)) - - if len(pairs) < 1 or pairs[0][1] < threshold: - references.append(pairs[0][0].metadata['source']) - return None, None, references + query=query, threshold=threshold) + chunks = [pair[0] for pair in pairs] + return chunks - high_score_chunks = [] - for pair in pairs: - if pair[1] >= threshold: - high_score_chunks.append(pair[0]) + def rerank_fuse(self, query: Union[Query, str], chunks: List[Chunk], context_max_length:int): + """Rerank chunks and extract content + + Args: + chunks (List[Chunk]): filtered chunks. 
+ + Returns: + str: Joined chunks, or empty string + str: Matched context from origin file content + List[str]: References + """ + if type(query) is str: + query = Query(text=query) - chunks = self.reranker.rerank(query=query.text, - chunks=high_score_chunks) - if tracker is not None: - tracker.log('retrieve', [c.metadata['source'] for c in chunks]) + rerank_chunks = self.reranker.rerank(query=query.text, + chunks=chunks) file_opr = FileOperation() - splits = [] # Add file text to context, until exceed `context_max_length` # If `file_length` > `context_max_length` (for example file_length=300 and context_max_length=100) # then centered on the chunk, read a length of 200 - for idx, chunk in enumerate(chunks): + splits = [] + context = '' + references = [] + for idx, chunk in enumerate(rerank_chunks): content = chunk.content_or_path splits.append(content) - file_text, error = file_opr.read(chunk.metadata['read']) - if error is not None: - # read file failed, skip - continue - source = chunk.metadata['source'] - logger.info('target {} file length {}'.format( + if '://' in source: + # url + file_text = content + else: + file_text, error = file_opr.read(chunk.metadata['read']) + if error is not None: + # read file failed, skip + continue + + logger.info('target {} content length {}'.format( source, len(file_text))) if len(file_text) + len(context) > context_max_length: if source in references: continue references.append(source) + # add and break add_len = context_max_length - len(context) if add_len <= 0: @@ -191,6 +183,39 @@ def query(self, os.path.basename(r) for r in references ] + def query(self, + query: Union[Query, str], + context_max_length: int = 40000, + tracker: QueryTracker = None): + """Processes a query and returns the best match from the vector store + database. If the question is rejected, returns None. + + Args: + query (Query): The multimodal question asked by the user. + context_max_length (int): Max contenxt length for LLM. + tracker (QueryTracker): Log tracker. 
+ + Returns: + str: Matched chunks, or empty string + str: Matched context from origin file content + List[str]: References + """ + if type(query) is str: + query = Query(text=query) + + if query.text is None or len(query.text) < 1 or self.faiss is None: + return None, None, [] + + if len(query.text) > 512: + logger.warning('input too long, truncate to 512') + query.text = query.text[0:512] + + high_score_chunks = self.text2vec_retrieve(query=query) + if tracker is not None: + tracker.log('retrieve', [c.metadata['source'] for c in high_score_chunks]) + + return self.rerank_fuse(query=query, chunks=high_score_chunks, context_max_length=context_max_length) + def is_relative(self, query, k=30, diff --git a/huixiangdou/service/worker.py b/huixiangdou/service/serial_pipeline.py similarity index 86% rename from huixiangdou/service/worker.py rename to huixiangdou/service/serial_pipeline.py index 270a9dd5..552997ae 100644 --- a/huixiangdou/service/worker.py +++ b/huixiangdou/service/serial_pipeline.py @@ -12,7 +12,6 @@ import pytoml from loguru import logger -from openai import OpenAI from huixiangdou.primitive import Query @@ -20,67 +19,18 @@ from .llm_client import ChatClient from .retriever import CacheRetriever, Retriever from .sg_search import SourceGraphProxy +from .session import Session from .web_search import WebSearch from .prompt import (SCORING_QUESTION_TEMPLTE_CN, CR_NEED_CN, CR_CN, TOPIC_TEMPLATE_CN, SCORING_RELAVANCE_TEMPLATE_CN, GENERATE_TEMPLATE_CN, KEYWORDS_TEMPLATE_CN, PERPLESITY_TEMPLATE_CN, SECURITY_TEMAPLTE_CN) from .prompt import (SCORING_QUESTION_TEMPLTE_EN, CR_NEED_EN, CR_EN, TOPIC_TEMPLATE_EN, SCORING_RELAVANCE_TEMPLATE_EN, GENERATE_TEMPLATE_EN, KEYWORDS_TEMPLATE_EN, PERPLESITY_TEMPLATE_EN, SECURITY_TEMAPLTE_EN) -class Session: - """For compute graph, `session` takes all parameter.""" - - def __init__(self, - query: Query, - history: list, - groupname: str, - log_path: str = 'logs/generate.jsonl', - groupchats: list = []): - self.stage = 'init' - self.query = query - self.history = history - self.groupname = groupname - self.groupchats = groupchats - - # init - self.response = '' - self.references = [] - self.topic = '' - self.code = ErrorCode.INIT - - # coreference resolution results - self.cr = '' - - # text2vec results - self.chunk = '' - self.knowledge = '' - - # web search results - self.web_knowledge = '' - - # source graph search results - self.sg_knowledge = '' - - # debug logs - self.debug = dict() - self.log_path = log_path - - def __del__(self): - dirname = os.path.dirname(self.log_path) - if not os.path.exists(dirname): - os.makedirs(dirname) - - with open(self.log_path, 'a') as f: - json_str = json.dumps(self.debug, indent=2, ensure_ascii=False) - f.write(json_str) - f.write('\n') - - class Node(ABC): - """Base abstractfor compute graph.""" + """Base abstract for compute graph.""" @abstractmethod def process(self, sess: Session) -> Generator[Session, None, None]: pass - class PreprocNode(Node): """PreprocNode is for coreference resolution and scoring based on group chats. 
@@ -91,22 +41,16 @@ class PreprocNode(Node): def __init__(self, config: dict, llm: ChatClient, language: str): self.llm = llm self.enable_cr = config['worker']['enable_cr'] - self.cr_client = OpenAI( - base_url=config['coreference_resolution']['base_url'], - api_key=config['coreference_resolution']['api_key']) if language == 'zh': self.SCORING_QUESTION_TEMPLTE = SCORING_QUESTION_TEMPLTE_CN - self.CR_NEED = CR_NEED_CN self.CR = CR_CN else: self.SCORING_QUESTION_TEMPLTE = SCORING_QUESTION_TEMPLTE_EN - self.CR_NEED = CR_NEED_EN self.CR = CR_EN def process(self, sess: Session) -> Generator[Session, None, None]: # check input - sess.stage = str(type(self).__name__) if sess.query.text is None or len(sess.query.text) < 6: sess.code = ErrorCode.QUESTION_TOO_SHORT yield sess @@ -124,7 +68,6 @@ def process(self, sess: Session) -> Generator[Session, None, None]: return if not self.enable_cr: - yield sess return if len(sess.groupchats) < 1: @@ -145,36 +88,6 @@ def process(self, sess: Session) -> Generator[Session, None, None]: talks.append({'sender': name_map[sender], 'content': msg.query}) talk_str = json.dumps(talks, ensure_ascii=False) - prompt = self.CR_NEED.format(talk_str, sess.query.text) - - # need coreference resolution or not - response = '' - try: - completion = self.cr_client.chat.completions.create( - model='coref-res', - messages=[{ - 'role': 'user', - 'content': prompt - }]) - response = completion.choices[0].message.content.lower() - except Exception as e: - logger.error(str(e)) - yield sess - return - sess.debug['PreprocNode_need_cr'] = response - need_cr = False - - if response.startswith('b') or response == '需要': - need_cr = True - else: - for sentence in ['因此需要', '因此选择b', '需要进行指代消解', '需要指代消解', 'b:需要']: - if sentence in response: - need_cr = True - break - - if not need_cr: - yield sess - return prompt = self.CR.format(talk_str, sess.query.text) self.cr = self.llm.generate_response(prompt=prompt, backend='remote') @@ -218,7 +131,6 @@ def __init__(self, config: dict, llm: ChatClient, retriever: Retriever, def process(self, sess: Session) -> Generator[Session, None, None]: """Try get reply with text2vec & rerank model.""" - sess.stage = str(type(self).__name__) # get query topic prompt = self.TOPIC_TEMPLATE.format(sess.query.text) sess.topic = self.llm.generate_response(prompt) @@ -301,7 +213,6 @@ def process(self, sess: Session) -> Generator[Session, None, None]: logger.debug('disable web_search') yield sess return - sess.stage = str(type(self).__name__) engine = WebSearch(config_path=self.config_path) @@ -368,7 +279,6 @@ def process(self, sess: Session) -> Generator[Session, None, None]: logger.debug('disable sg_search') yield sess return - sess.stage = str(type(self).__name__) # if exit for other status (SECURITY or SEARCH_FAIL), still quit `sg_search` if sess.code != ErrorCode.BAD_ANSWER and sess.code != ErrorCode.NO_SEARCH_RESULT and sess.code != ErrorCode.WEB_SEARCH_FAIL: @@ -414,8 +324,6 @@ def __init__(self, llm: ChatClient, language: str): def process(self, sess: Session) -> Generator[Session, None, None]: """Check result with security.""" - sess.stage = str(type(self).__name__) - if len(sess.response) < 1: sess.code = ErrorCode.BAD_ANSWER yield sess @@ -443,8 +351,8 @@ def process(self, sess: Session) -> Generator[Session, None, None]: yield sess -class Worker: - """The Worker class orchestrates the logic of handling user queries, +class SerialPipeline: + """The SerialPipeline class orchestrates the logic of handling user queries, generating responses and managing several 
aspects of a chat assistant. It enables feature storage, language model client setup, time scheduling and much more. @@ -530,8 +438,8 @@ def work_time(self): def generate(self, query: Union[Query, str], - history: List, - groupname: str, + history: List[str] = [], + groupname: str = '', groupchats: List[str] = []): """Processes user queries and generates appropriate responses. It involves several steps including checking for valid questions, @@ -600,18 +508,18 @@ def generate(self, def parse_args(): """Parses command-line arguments.""" - parser = argparse.ArgumentParser(description='Worker.') + parser = argparse.ArgumentParser(description='SerialPipeline.') parser.add_argument('work_dir', type=str, help='Working directory.') parser.add_argument( '--config_path', default='config.ini', - help='Worker configuration path. Default value is config.ini') + help='SerialPipeline configuration path. Default value is config.ini') return parser.parse_args() if __name__ == '__main__': args = parse_args() - bot = Worker(work_dir=args.work_dir, config_path=args.config_path) + bot = SerialPipeline(work_dir=args.work_dir, config_path=args.config_path) queries = ['茴香豆是怎么做的'] for example in queries: print(bot.generate(query=example, history=[], groupname='')) diff --git a/huixiangdou/service/session.py b/huixiangdou/service/session.py new file mode 100644 index 00000000..1c7c6ac8 --- /dev/null +++ b/huixiangdou/service/session.py @@ -0,0 +1,54 @@ +from huixiangdou.primitive import Query +from .helper import ErrorCode +import os +import json + +class Session: + """For compute graph, `session` takes all parameter.""" + + def __init__(self, + query: Query, + history: list, + groupname: str = '', + log_path: str = 'logs/generate.jsonl', + groupchats: list = []): + self.query = query + self.history = history + self.groupname = groupname + self.groupchats = groupchats + + # init + # Same as `chunk.choices[0].delta` + self.delta = '' + self.parallel_chunks = [] + self.response = '' + self.references = [] + self.topic = '' + self.code = ErrorCode.INIT + + # coreference resolution results + self.cr = '' + + # text2vec results + self.chunk = '' + self.knowledge = '' + + # web search results + self.web_knowledge = '' + + # source graph search results + self.sg_knowledge = '' + + # debug logs + self.debug = dict() + self.log_path = log_path + + def __del__(self): + dirname = os.path.dirname(self.log_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + with open(self.log_path, 'a') as f: + json_str = json.dumps(self.debug, indent=2, ensure_ascii=False) + f.write(json_str) + f.write('\n') diff --git a/huixiangdou/service/web_search.py b/huixiangdou/service/web_search.py index 32605cce..0961817f 100644 --- a/huixiangdou/service/web_search.py +++ b/huixiangdou/service/web_search.py @@ -80,7 +80,7 @@ class WebSearch: get(query: str, max_article=1): Searches with cache. If the query already exists in the cache, return the cached result. 
# noqa E501 """ - def __init__(self, config_path: str, retry: int = 1) -> None: + def __init__(self, config_path: str, retry: int = 1, language:str='zh') -> None: """Initializes the WebSearch object with the given config path and retry count.""" @@ -89,6 +89,7 @@ def __init__(self, config_path: str, retry: int = 1) -> None: config = pytoml.load(f) self.search_config = types.SimpleNamespace(**config['web_search']) self.retry = retry + self.language = language def fetch_url(self, query: str, target_link: str, brief: str = ''): if not target_link.startswith('http'): @@ -182,7 +183,11 @@ def google(self, query: str, max_article: int): """ url = 'https://google.serper.dev/search' - payload = json.dumps({'q': f'{query}', 'hl': 'zh-cn'}) + if 'zh' in self.language: + lang = 'zh-cn' + else: + lang = 'en' + payload = json.dumps({'q': f'{query}', 'hl': lang}) headers = { 'X-API-KEY': self.search_config.serper_x_api_key, 'Content-Type': 'application/json' @@ -191,7 +196,7 @@ def google(self, query: str, max_article: int): url, headers=headers, data=payload, - timeout=15) # noqa E501 + timeout=5) # noqa E501 jsonobj = json.loads(response.text) logger.debug(jsonobj) keys = self.search_config.domain_partial_order diff --git a/requirements.txt b/requirements.txt index 0ead912d..1422dfc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,4 +33,4 @@ fastapi uvicorn termcolor opencv-python-headless -gradio \ No newline at end of file +gradio>=4.41 \ No newline at end of file diff --git a/sft/reconstruct_filter_annotate.py b/sft/reconstruct_filter_annotate.py index 50df6b59..cb99ca25 100644 --- a/sft/reconstruct_filter_annotate.py +++ b/sft/reconstruct_filter_annotate.py @@ -251,7 +251,6 @@ def metric(llm_type: str, else: unknow_count += 1 print(dt) - pdb.set_trace() dts.append(bool(1 - cr_need_gt)) assert len(gts) == len(dts) diff --git a/tests/test_bge_reranker.py b/tests/test_bge_reranker.py index 691030b0..b8c51392 100644 --- a/tests/test_bge_reranker.py +++ b/tests/test_bge_reranker.py @@ -21,7 +21,6 @@ print(scores) # [-8.1875, 5.26171875] import pdb -pdb.set_trace() # You can map the scores into 0-1 by set "normalize=True", which will apply sigmoid function to the score scores = reranker.compute_score([ ['what is panda?', 'hi'], diff --git a/tests/test_build_milvus_and_filter.py b/tests/test_build_milvus_and_filter.py index 2cdd3040..4d00c61c 100644 --- a/tests/test_build_milvus_and_filter.py +++ b/tests/test_build_milvus_and_filter.py @@ -190,7 +190,6 @@ def calculate(chunk_size: int): col = init_milvus(col_name='test2', max_length_bytes=3 * chunk_size) subdocs = split_by_group(docs) - pdb.set_trace() for idx, docs in enumerate(subdocs): print('build step {}'.format(idx)) texts = [] @@ -209,7 +208,6 @@ def calculate(chunk_size: int): col.flush() except Exception as e: print(e) - pdb.set_trace() print('insert finished') # start = 0.4 diff --git a/tests/test_m3.py b/tests/test_m3.py index 12cc9f9c..1fa5fec5 100644 --- a/tests/test_m3.py +++ b/tests/test_m3.py @@ -12,7 +12,6 @@ import pdb -pdb.set_trace() embeddings_1 = model.encode(sentences_1, max_length=512)['dense_vecs'] embeddings_2 = model.encode(sentences_2)['dense_vecs'] similarity = embeddings_1 @ embeddings_2.T diff --git a/tests/test_query_gradio.py b/tests/test_query_gradio.py index 44e79b84..c34e8a85 100644 --- a/tests/test_query_gradio.py +++ b/tests/test_query_gradio.py @@ -10,11 +10,11 @@ from loguru import logger from huixiangdou.primitive import Query -from huixiangdou.service import ErrorCode, Worker, llm_serve, 
start_llm_server +from huixiangdou.service import ErrorCode, SerialPipeline, ParallelPipeline, llm_serve, start_llm_server def parse_args(): """Parse args.""" - parser = argparse.ArgumentParser(description='Worker.') + parser = argparse.ArgumentParser(description='SerialPipeline Gradio WebUI.') parser.add_argument('--work_dir', type=str, default='workdir', @@ -23,7 +23,7 @@ def parse_args(): '--config_path', default='config.ini', type=str, - help='Worker configuration path. Default value is config.ini') + help='SerialPipeline configuration path. Default value is config.ini') parser.add_argument('--standalone', action='store_true', default=True, @@ -40,7 +40,7 @@ def get_reply(text, image): else: image_path = None - assistant = Worker(work_dir=args.work_dir, config_path=args.config_path) + assistant = SerialPipeline(work_dir=args.work_dir, config_path=args.config_path) query = Query(text, image_path) code, reply, references = assistant.generate(query=query, diff --git a/web/proxy/main.py b/web/proxy/main.py index 5991f368..a254cf3b 100644 --- a/web/proxy/main.py +++ b/web/proxy/main.py @@ -19,7 +19,7 @@ feature_store_base_dir, parse_json_str, redis_host, redis_passwd, redis_port) -from .web_worker import WebWorker +from .web_worker import OpenXLabWorker def callback_task_state(feature_store_id: str, @@ -140,7 +140,7 @@ def chat_with_featue_store(cache: CacheRetriever, config_path=configpath, work_dir=workdir) - worker = WebWorker(work_dir=workdir, config_path=configpath) + worker = OpenXLabWorker(work_dir=workdir, config_path=configpath) history = format_history(payload.history) query_log = '{} {}\n'.format(fs_id, payload.content) diff --git a/web/proxy/web_worker.py b/web/proxy/web_worker.py index b464e401..49c76bbe 100644 --- a/web/proxy/web_worker.py +++ b/web/proxy/web_worker.py @@ -58,8 +58,8 @@ def openxlab_security(query: str, retry=1): return False -class WebWorker: - """The Worker class orchestrates the logic of handling user queries, +class OpenXLabWorker: + """The OpenXLab Worker class orchestrates the logic of handling user queries, generating responses and managing several aspects of a chat assistant. It enables feature storage, language model client setup, time scheduling and much more. @@ -329,18 +329,18 @@ def generate(self, query, history, retriever, groupname): def parse_args(): """Parses command-line arguments.""" - parser = argparse.ArgumentParser(description='Worker.') + parser = argparse.ArgumentParser(description='OpenXLabWorker.') parser.add_argument('work_dir', type=str, help='Working directory.') parser.add_argument( '--config_path', default='config.ini', - help='Worker configuration path. Default value is config.ini') + help='OpenXLabWorker configuration path. Default value is config.ini') return parser.parse_args() if __name__ == '__main__': args = parse_args() - bot = Worker(work_dir=args.work_dir, config_path=args.config_path) + bot = OpenXLabWorker(work_dir=args.work_dir, config_path=args.config_path) queries = ['茴香豆是怎么做的'] for example in queries: print(bot.generate(query=example, history=[], groupname=''))
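For reference, a minimal usage sketch of the new `chat_with_repo` entry point (`ParallelPipeline`) added in this patch, driven directly rather than through the bundled Gradio or FastAPI frontends. It assumes a knowledge base already exists under `workdir` and that `config.ini` points at a reachable LLM service (for example one started with `--standalone`); the question text is a placeholder.

```python
# Sketch only: stream a reply from the parallel pipeline.
import asyncio

from huixiangdou.primitive import Query
from huixiangdou.service import ParallelPipeline


async def main():
    bot = ParallelPipeline(work_dir='workdir', config_path='config.ini')
    query = Query(text='How to install mmpose?')
    sess = None
    async for sess in bot.generate(query=query, history=[], language='en',
                                   enable_web_search=False):
        # `sess.delta` carries the incremental LLM output;
        # `sess.references` lists matched files or URLs once retrieval finishes.
        print(sess.delta, end='', flush=True)
    if sess is not None:
        print('\n', sess.references)


asyncio.run(main())
```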