From ae957f1393449e2ff8a435857251eb9b2962943b Mon Sep 17 00:00:00 2001 From: iokk3732 <141700052+iokk3732@users.noreply.github.com> Date: Fri, 26 Apr 2024 19:49:03 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BB=A5http=E4=B8=BA?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=90=AF=E5=8A=A8=E7=9A=84=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- airda/server/agent_server/__init__.py | 40 ------------------ airda/server/agent_server/airda_server.py | 48 +++++++++++++++++++++ airda/server/api/__init__.py | 6 ++- airda/server/api/api.py | 51 ++++++++++++++++++----- airda/server/protocol/__init__.py | 16 ++++--- 5 files changed, 103 insertions(+), 58 deletions(-) create mode 100644 airda/server/agent_server/airda_server.py diff --git a/airda/server/agent_server/__init__.py b/airda/server/agent_server/__init__.py index 277a7ae..e69de29 100644 --- a/airda/server/agent_server/__init__.py +++ b/airda/server/agent_server/__init__.py @@ -1,40 +0,0 @@ -import fastapi -from overrides import overrides - -from airda.agent.env import DataAgentEnv -from airda.server import WebFrameworkServer -from airda.server.api.api import APIImpl -from airda.server.protocol import ChatCompletionRequest - - -class DataAgentServer(WebFrameworkServer): - def __init__(self, host="0.0.0.0", port=8888): - super().__init__(host, port) - self.router = None - - def init_api(self): - return APIImpl() - - @overrides - def create_app(self): - return fastapi.FastAPI(debug=True) - - @overrides - def run_server(self): - import uvicorn - - uvicorn.run(self.app, host=self.host, port=self.port, log_level="info") - - @overrides - def add_routes(self): - self.router = fastapi.APIRouter() - self.router.add_api_route( - "/v1/chat/completions", - self.create_completion, - methods=["POST"], - tags=["chat completions"], - ) - self.app.include_router(self.router) - - async def create_completion(self, request: ChatCompletionRequest): - return await self._api.create_completion(request) diff --git a/airda/server/agent_server/airda_server.py b/airda/server/agent_server/airda_server.py new file mode 100644 index 0000000..a9cd8f2 --- /dev/null +++ b/airda/server/agent_server/airda_server.py @@ -0,0 +1,48 @@ +import fastapi +from overrides import overrides + +from airda.server import WebFrameworkServer +from airda.server.api.api import APIImpl +from airda.server.protocol import ChatCompletionRequest, AddDatasourceRequest + + +class AirdaServer(WebFrameworkServer): + def __init__(self, host="0.0.0.0", port=8888): + super().__init__(host, port) + self.router = None + + def init_api(self): + return APIImpl() + + @overrides + def create_app(self): + return fastapi.FastAPI(debug=True) + + @overrides + def run_server(self): + import uvicorn + + uvicorn.run(self.app, host=self.host, port=self.port, log_level="info") + + @overrides + def add_routes(self): + self.router = fastapi.APIRouter() + self.router.add_api_route( + "/v1/chat/completions", + self.create_completion, + methods=["POST"], + tags=["chat completions"], + ) + self.router.add_api_route( + "/v1/datasource/add", + self.create_completion, + methods=["POST"], + tags=["datasource add"], + ) + self.app.include_router(self.router) + + async def create_completion(self, request: ChatCompletionRequest): + return await self._api.create_completion(request) + + def add_datasource(self, request: AddDatasourceRequest): + return self._api.add_datasource(request) diff --git a/airda/server/api/__init__.py b/airda/server/api/__init__.py index 5878d8d..bf39a1c 100644 --- a/airda/server/api/__init__.py +++ b/airda/server/api/__init__.py @@ -1,9 +1,13 @@ from abc import ABC, abstractmethod -from airda.server.protocol import ChatCompletionRequest +from airda.server.protocol import ChatCompletionRequest, AddDatasourceRequest class API(ABC): @abstractmethod async def create_completion(self, request: ChatCompletionRequest): pass + + @abstractmethod + async def add_datasource(self, request: AddDatasourceRequest): + pass diff --git a/airda/server/api/api.py b/airda/server/api/api.py index 3fa0bce..e25d8e0 100644 --- a/airda/server/api/api.py +++ b/airda/server/api/api.py @@ -3,18 +3,22 @@ from typing import AsyncGenerator from fastapi.responses import JSONResponse, StreamingResponse -from overrides import override +from overrides import overrides from airda.agent.agent import DataAgent from airda.agent.data_agent_context import DataAgentContext +from airda.agent.exception.already_exists_error import AlreadyExistsError from airda.agent.planner.data_agent_planner_params import DataAgentPlannerParams +from airda.agent.storage import StorageKey +from airda.agent.storage.entity.datasource import Kind, Datasource +from airda.agent.storage.repositories.datasource_repository import DatasourceRepository from airda.server.api import API from airda.server.protocol import ( ChatCompletionRequest, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, DeltaMessage, - ErrorResponse, + ErrorResponse, AddDatasourceRequest, ) logger = logging.getLogger(__name__) @@ -22,17 +26,17 @@ class APIImpl(API): _cache: dict[str, StreamingResponse] = {} - agent: DataAgentContext + context: DataAgentContext def __init__(self): super().__init__() - self.agent = DataAgent().run() + self.context = DataAgent().run() - @override + @overrides async def create_completion(self, request: ChatCompletionRequest): async def stream_generator() -> AsyncGenerator[str, None]: try: - pipeline = self.agent.get_planner().plan(DataAgentPlannerParams(**vars(request))) + pipeline = self.context.get_planner().plan(DataAgentPlannerParams(**vars(request))) async for item in pipeline.execute(): yield f"data: {make_stream_data(content=item)}\n\n" except Exception as e: @@ -40,13 +44,38 @@ async def stream_generator() -> AsyncGenerator[str, None]: return StreamingResponse(stream_generator(), media_type="text/event-stream") + @overrides + def add_datasource(self, request: AddDatasourceRequest): + kind = Kind.getKind(request.kind) + if kind is None: + # output_colored_text(f"不支持的数据源类型[{kind}], PS: 支持类型: [{Kind.MYSQL.value}]", "error") + message = f"不支持的数据源类型[{kind}], PS: 支持类型: [{Kind.MYSQL.value}]" + return JSONResponse(ErrorResponse(message=message, code=-1).dict()) + datasource_repository = self.context.get_repository(StorageKey.DATASOURCE).convert(DatasourceRepository) + try: + datasource_repository.add( + Datasource( + name=request.name, + host=request.host, + port=request.port, + database=request.database, + kind=kind, + username=request.username, + password=request.password, + ) + ) + # output_colored_text("执行成功", "success") + except AlreadyExistsError: + pass + # output_colored_text(f"执行失败, [{name}]数据源已存在", "error") + def make_stream_data( - content: str | dict | list, - rep_type: str = "stream", - model: str = "gpt-4-1106-preview", - finish_reason: str = "", - session_id: str = "", + content: str | dict | list, + rep_type: str = "stream", + model: str = "gpt-4-1106-preview", + finish_reason: str = "", + session_id: str = "", ): push_json = {"type": rep_type, "data": content} choice_data = ChatCompletionResponseStreamChoice( diff --git a/airda/server/protocol/__init__.py b/airda/server/protocol/__init__.py index 9992e4e..6c1971f 100644 --- a/airda/server/protocol/__init__.py +++ b/airda/server/protocol/__init__.py @@ -13,13 +13,17 @@ class ErrorResponse(BaseModel): class ChatCompletionRequest(BaseModel): question: str - datasource_id: str + datasource_name: str + + +class AddDatasourceRequest(BaseModel): + name: str + host: str + port: int database: str - knowledge: str - session_id: str - sql_type: str = "mysql" - file_name: str - file_id: str + kind: str + username: str | None + password: str | None class DeltaMessage(BaseModel): From bc6829ea9acf26211fc6dfb63de2253d8c729909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=B9=8F?= <692178807@qq.com> Date: Sun, 28 Apr 2024 10:08:41 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=90=91=E9=87=8F?= =?UTF-8?q?=E5=8C=96=E6=A8=A1=E5=9E=8B=E4=B8=8B=E8=BD=BD=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a32923f..9af6e9c 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,10 @@ airda env load -p {your_path}/.env_template airda log load -p {your_path}/log_config.yml.template ``` +Embedding Model + +airda默认使用[stella-large-zh-v2](https://huggingface.co/infgrad/stella-large-zh-v2)模型, 模型默认下载到~/.cache/huggingface/hub/路径,目录下没有需要手动下载 + ### 相关配置命令 @@ -100,9 +104,8 @@ airda run cli -n {datasource_name} 我们欢迎各种贡献和建议,共同努力,使本项目更上一层楼!麻烦遵循以下步骤: -- **步骤1:** 如果您想添加任何额外的功能、增强功能或在使用过程中遇到任何问题,请发布一个 [问题](https://github.com/hitsz-ids/SQLAgent/issues) 。如果您能遵循 [问题模板](https://github.com/hitsz-ids/SQLAgent/issues/1) 我们将不胜感激。问题将在那里被讨论和分配。 -- **步骤2:** 无论何时,当一个问题被分配后,您都可以按照 [PR模板](https://github.com/hitsz-ids/SQLAgent/pulls) 创建一个 [拉取请求](https://github.com/hitsz-ids/SQLAgent/pulls) 进行贡献。您也可以认领任何公开的问题。共同努力,我们可以使airda变得更好! +- **步骤1:** 如果您想添加任何额外的功能、增强功能或在使用过程中遇到任何问题,请发布一个 [问题](https://github.com/hitsz-ids/airda/issues) 。如果您能遵循 [问题模板](https://github.com/hitsz-ids/aird/issues/1) 我们将不胜感激。问题将在那里被讨论和分配。 +- **步骤2:** 无论何时,当一个问题被分配后,您都可以按照 [PR模板](https://github.com/hitsz-ids/aird/pulls) 创建一个 [拉取请求](https://github.com/hitsz-ids/aird/pulls) 进行贡献。您也可以认领任何公开的问题。共同努力,我们可以使airda变得更好! - **步骤3:** 在审查和讨论后,PR将被合并或迭代。感谢您的贡献! -在您开始之前,我们强烈建议您花一点时间检查 [这里](https://github.com/hitsz-ids/SQLAgent/blob/developing/CONTRIBUTING.md) 再进行贡献。 - +在您开始之前,我们强烈建议您花一点时间检查 [这里](https://github.com/hitsz-ids/aird/blob/developing/CONTRIBUTING.md) 再进行贡献。 From 86d2bc92eb22e2e66f091337950c05732cb0a8d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=B9=8F?= <692178807@qq.com> Date: Mon, 29 Apr 2024 11:32:21 +0800 Subject: [PATCH 3/5] =?UTF-8?q?readme=20=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9af6e9c..4385831 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ docker run -itd --name mongo -v /{path_of_mongo_data}:/data/db -p 27017:27017 mo 环境变量 -下载https://github.com/hitsz-ids/airda/blob/main/.env.template文件,自定义embedding模型,mongo配置,以及openai配置 +下载[.env.template](https://github.com/hitsz-ids/airda/blob/main/.env.template)自定义embedding模型,mongo配置,以及openai配置 ``` airda env load -p {your_path}/.env_template @@ -58,7 +58,7 @@ airda env load -p {your_path}/.env_template 日志文件(非必须) -下载https://github.com/hitsz-ids/airda/blob/main/log_config.yml.template文件,自定义日志配置 +下载[log_config.yml.template](https://github.com/hitsz-ids/airda/blob/main/log_config.yml.template),自定义日志配置 ``` airda log load -p {your_path}/log_config.yml.template From 7ab583d15758b3c95b8035daac1c1d3d6c6b978d Mon Sep 17 00:00:00 2001 From: iokk3732 <141700052+iokk3732@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:13:04 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E5=90=8D=E7=A7=B0=E4=B8=BAAirdaServer,=20=E5=8F=96=E6=B6=88dat?= =?UTF-8?q?asource=E7=9A=84enable=E5=92=8Cdisable=E9=80=89=E9=A1=B9?= =?UTF-8?q?=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=AF=8F=E6=AC=A1=E8=BF=9B=E5=85=A5?= =?UTF-8?q?=E9=97=AE=E7=AD=94datasource=E7=9A=84=E6=8C=87=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- airda/cli/startup.py | 6 +++--- airda/connector/mysql.py | 1 + pyproject.toml | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/airda/cli/startup.py b/airda/cli/startup.py index d25ed72..a63bd5c 100644 --- a/airda/cli/startup.py +++ b/airda/cli/startup.py @@ -18,7 +18,7 @@ from airda.agent.storage.entity.datasource import Datasource, Kind from airda.agent.storage.repositories.datasource_repository import DatasourceRepository from airda.connector.mysql import MysqlConnector -from airda.server.agent_server import DataAgentServer +from airda.server.agent_server.airda_server import AirdaServer style = Style.from_dict( { @@ -108,8 +108,8 @@ async def execute(): help="服务端口号", ) def server(port: int): - data_agent_server = DataAgentServer(port=port) - data_agent_server.run_server() + airda_server = AirdaServer(port=port) + airda_server.run_server() pass diff --git a/airda/connector/mysql.py b/airda/connector/mysql.py index 13cb992..310ad29 100644 --- a/airda/connector/mysql.py +++ b/airda/connector/mysql.py @@ -68,3 +68,4 @@ def query_schema(self): table_comment=table[1], ) self.context.sync_instruction(instruction) + cursor.close() diff --git a/pyproject.toml b/pyproject.toml index eb3b129..d42b1de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ torch = "2.0.1" pymongo = "^4.6.2" prompt-toolkit = "^3.0.43" pyyaml = "^6.0.1" -mysql-connector-python = "^8.3.0" fastapi = "0.99.0" [tool.poetry.scripts] From 5e7af5884f5f7c5472433aee9a5338fb5d47fe4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=B9=8F?= <692178807@qq.com> Date: Tue, 7 May 2024 15:40:10 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- airda/server/agent_server/airda_server.py | 2 +- airda/server/api/__init__.py | 2 +- airda/server/api/api.py | 19 +++++++++++-------- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/airda/server/agent_server/airda_server.py b/airda/server/agent_server/airda_server.py index a9cd8f2..e83d865 100644 --- a/airda/server/agent_server/airda_server.py +++ b/airda/server/agent_server/airda_server.py @@ -3,7 +3,7 @@ from airda.server import WebFrameworkServer from airda.server.api.api import APIImpl -from airda.server.protocol import ChatCompletionRequest, AddDatasourceRequest +from airda.server.protocol import AddDatasourceRequest, ChatCompletionRequest class AirdaServer(WebFrameworkServer): diff --git a/airda/server/api/__init__.py b/airda/server/api/__init__.py index bf39a1c..2664a60 100644 --- a/airda/server/api/__init__.py +++ b/airda/server/api/__init__.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from airda.server.protocol import ChatCompletionRequest, AddDatasourceRequest +from airda.server.protocol import AddDatasourceRequest, ChatCompletionRequest class API(ABC): diff --git a/airda/server/api/api.py b/airda/server/api/api.py index e25d8e0..3e76a8d 100644 --- a/airda/server/api/api.py +++ b/airda/server/api/api.py @@ -10,15 +10,16 @@ from airda.agent.exception.already_exists_error import AlreadyExistsError from airda.agent.planner.data_agent_planner_params import DataAgentPlannerParams from airda.agent.storage import StorageKey -from airda.agent.storage.entity.datasource import Kind, Datasource +from airda.agent.storage.entity.datasource import Datasource, Kind from airda.agent.storage.repositories.datasource_repository import DatasourceRepository from airda.server.api import API from airda.server.protocol import ( + AddDatasourceRequest, ChatCompletionRequest, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, DeltaMessage, - ErrorResponse, AddDatasourceRequest, + ErrorResponse, ) logger = logging.getLogger(__name__) @@ -51,7 +52,9 @@ def add_datasource(self, request: AddDatasourceRequest): # output_colored_text(f"不支持的数据源类型[{kind}], PS: 支持类型: [{Kind.MYSQL.value}]", "error") message = f"不支持的数据源类型[{kind}], PS: 支持类型: [{Kind.MYSQL.value}]" return JSONResponse(ErrorResponse(message=message, code=-1).dict()) - datasource_repository = self.context.get_repository(StorageKey.DATASOURCE).convert(DatasourceRepository) + datasource_repository = self.context.get_repository(StorageKey.DATASOURCE).convert( + DatasourceRepository + ) try: datasource_repository.add( Datasource( @@ -71,11 +74,11 @@ def add_datasource(self, request: AddDatasourceRequest): def make_stream_data( - content: str | dict | list, - rep_type: str = "stream", - model: str = "gpt-4-1106-preview", - finish_reason: str = "", - session_id: str = "", + content: str | dict | list, + rep_type: str = "stream", + model: str = "gpt-4-1106-preview", + finish_reason: str = "", + session_id: str = "", ): push_json = {"type": rep_type, "data": content} choice_data = ChatCompletionResponseStreamChoice(