Skip to content

Commit

Permalink
Merge pull request #12 from hitsz-ids/developing
Browse files Browse the repository at this point in the history
Developing
  • Loading branch information
lpxiangyan9 authored May 8, 2024
2 parents 64fcfcb + a3cc78b commit 9f20a04
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 62 deletions.
15 changes: 9 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,24 @@ docker run -itd --name mongo -v /{path_of_mongo_data}:/data/db -p 27017:27017 mo

环境变量

下载https://github.com/hitsz-ids/airda/blob/main/.env.template文件,自定义embedding模型,mongo配置,以及openai配置
下载[.env.template](https://github.com/hitsz-ids/airda/blob/main/.env.template)自定义embedding模型,mongo配置,以及openai配置

```
airda env load -p {your_path}/.env_template
```

日志文件(非必须)

下载https://github.com/hitsz-ids/airda/blob/main/log_config.yml.template文件,自定义日志配置
下载[log_config.yml.template](https://github.com/hitsz-ids/airda/blob/main/log_config.yml.template),自定义日志配置

```
airda log load -p {your_path}/log_config.yml.template
```

Embedding Model

airda默认使用[stella-large-zh-v2](https://huggingface.co/infgrad/stella-large-zh-v2)模型, 模型默认下载到~/.cache/huggingface/hub/路径,目录下没有需要手动下载



### 相关配置命令
Expand Down Expand Up @@ -100,9 +104,8 @@ airda run cli -n {datasource_name}

我们欢迎各种贡献和建议,共同努力,使本项目更上一层楼!麻烦遵循以下步骤:

- **步骤1:** 如果您想添加任何额外的功能、增强功能或在使用过程中遇到任何问题,请发布一个 [问题](https://github.com/hitsz-ids/SQLAgent/issues) 。如果您能遵循 [问题模板](https://github.com/hitsz-ids/SQLAgent/issues/1) 我们将不胜感激。问题将在那里被讨论和分配。
- **步骤2:** 无论何时,当一个问题被分配后,您都可以按照 [PR模板](https://github.com/hitsz-ids/SQLAgent/pulls) 创建一个 [拉取请求](https://github.com/hitsz-ids/SQLAgent/pulls) 进行贡献。您也可以认领任何公开的问题。共同努力,我们可以使airda变得更好!
- **步骤1:** 如果您想添加任何额外的功能、增强功能或在使用过程中遇到任何问题,请发布一个 [问题](https://github.com/hitsz-ids/airda/issues) 。如果您能遵循 [问题模板](https://github.com/hitsz-ids/aird/issues/1) 我们将不胜感激。问题将在那里被讨论和分配。
- **步骤2:** 无论何时,当一个问题被分配后,您都可以按照 [PR模板](https://github.com/hitsz-ids/aird/pulls) 创建一个 [拉取请求](https://github.com/hitsz-ids/aird/pulls) 进行贡献。您也可以认领任何公开的问题。共同努力,我们可以使airda变得更好!
- **步骤3:** 在审查和讨论后,PR将被合并或迭代。感谢您的贡献!

在您开始之前,我们强烈建议您花一点时间检查 [这里](https://github.com/hitsz-ids/SQLAgent/blob/developing/CONTRIBUTING.md) 再进行贡献。

在您开始之前,我们强烈建议您花一点时间检查 [这里](https://github.com/hitsz-ids/aird/blob/developing/CONTRIBUTING.md) 再进行贡献。
6 changes: 3 additions & 3 deletions airda/cli/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from airda.agent.storage.entity.datasource import Datasource, Kind
from airda.agent.storage.repositories.datasource_repository import DatasourceRepository
from airda.connector.mysql import MysqlConnector
from airda.server.agent_server import DataAgentServer
from airda.server.agent_server.airda_server import AirdaServer

style = Style.from_dict(
{
Expand Down Expand Up @@ -108,8 +108,8 @@ async def execute():
help="服务端口号",
)
def server(port: int):
data_agent_server = DataAgentServer(port=port)
data_agent_server.run_server()
airda_server = AirdaServer(port=port)
airda_server.run_server()
pass


Expand Down
1 change: 1 addition & 0 deletions airda/connector/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,4 @@ def query_schema(self):
table_comment=table[1],
)
self.context.sync_instruction(instruction)
cursor.close()
40 changes: 0 additions & 40 deletions airda/server/agent_server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +0,0 @@
import fastapi
from overrides import overrides

from airda.agent.env import DataAgentEnv
from airda.server import WebFrameworkServer
from airda.server.api.api import APIImpl
from airda.server.protocol import ChatCompletionRequest


class DataAgentServer(WebFrameworkServer):
def __init__(self, host="0.0.0.0", port=8888):
super().__init__(host, port)
self.router = None

def init_api(self):
return APIImpl()

@overrides
def create_app(self):
return fastapi.FastAPI(debug=True)

@overrides
def run_server(self):
import uvicorn

uvicorn.run(self.app, host=self.host, port=self.port, log_level="info")

@overrides
def add_routes(self):
self.router = fastapi.APIRouter()
self.router.add_api_route(
"/v1/chat/completions",
self.create_completion,
methods=["POST"],
tags=["chat completions"],
)
self.app.include_router(self.router)

async def create_completion(self, request: ChatCompletionRequest):
return await self._api.create_completion(request)
48 changes: 48 additions & 0 deletions airda/server/agent_server/airda_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import fastapi
from overrides import overrides

from airda.server import WebFrameworkServer
from airda.server.api.api import APIImpl
from airda.server.protocol import AddDatasourceRequest, ChatCompletionRequest


class AirdaServer(WebFrameworkServer):
def __init__(self, host="0.0.0.0", port=8888):
super().__init__(host, port)
self.router = None

def init_api(self):
return APIImpl()

@overrides
def create_app(self):
return fastapi.FastAPI(debug=True)

@overrides
def run_server(self):
import uvicorn

uvicorn.run(self.app, host=self.host, port=self.port, log_level="info")

@overrides
def add_routes(self):
self.router = fastapi.APIRouter()
self.router.add_api_route(
"/v1/chat/completions",
self.create_completion,
methods=["POST"],
tags=["chat completions"],
)
self.router.add_api_route(
"/v1/datasource/add",
self.create_completion,
methods=["POST"],
tags=["datasource add"],
)
self.app.include_router(self.router)

async def create_completion(self, request: ChatCompletionRequest):
return await self._api.create_completion(request)

def add_datasource(self, request: AddDatasourceRequest):
return self._api.add_datasource(request)
6 changes: 5 additions & 1 deletion airda/server/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from abc import ABC, abstractmethod

from airda.server.protocol import ChatCompletionRequest
from airda.server.protocol import AddDatasourceRequest, ChatCompletionRequest


class API(ABC):
@abstractmethod
async def create_completion(self, request: ChatCompletionRequest):
pass

@abstractmethod
async def add_datasource(self, request: AddDatasourceRequest):
pass
42 changes: 37 additions & 5 deletions airda/server/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
from typing import AsyncGenerator

from fastapi.responses import JSONResponse, StreamingResponse
from overrides import override
from overrides import overrides

from airda.agent.agent import DataAgent
from airda.agent.data_agent_context import DataAgentContext
from airda.agent.exception.already_exists_error import AlreadyExistsError
from airda.agent.planner.data_agent_planner_params import DataAgentPlannerParams
from airda.agent.storage import StorageKey
from airda.agent.storage.entity.datasource import Datasource, Kind
from airda.agent.storage.repositories.datasource_repository import DatasourceRepository
from airda.server.api import API
from airda.server.protocol import (
AddDatasourceRequest,
ChatCompletionRequest,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
Expand All @@ -22,24 +27,51 @@

class APIImpl(API):
_cache: dict[str, StreamingResponse] = {}
agent: DataAgentContext
context: DataAgentContext

def __init__(self):
super().__init__()
self.agent = DataAgent().run()
self.context = DataAgent().run()

@override
@overrides
async def create_completion(self, request: ChatCompletionRequest):
async def stream_generator() -> AsyncGenerator[str, None]:
try:
pipeline = self.agent.get_planner().plan(DataAgentPlannerParams(**vars(request)))
pipeline = self.context.get_planner().plan(DataAgentPlannerParams(**vars(request)))
async for item in pipeline.execute():
yield f"data: {make_stream_data(content=item)}\n\n"
except Exception as e:
print(e)

return StreamingResponse(stream_generator(), media_type="text/event-stream")

@overrides
def add_datasource(self, request: AddDatasourceRequest):
kind = Kind.getKind(request.kind)
if kind is None:
# output_colored_text(f"不支持的数据源类型[{kind}], PS: 支持类型: [{Kind.MYSQL.value}]", "error")
message = f"不支持的数据源类型[{kind}], PS: 支持类型: [{Kind.MYSQL.value}]"
return JSONResponse(ErrorResponse(message=message, code=-1).dict())
datasource_repository = self.context.get_repository(StorageKey.DATASOURCE).convert(
DatasourceRepository
)
try:
datasource_repository.add(
Datasource(
name=request.name,
host=request.host,
port=request.port,
database=request.database,
kind=kind,
username=request.username,
password=request.password,
)
)
# output_colored_text("执行成功", "success")
except AlreadyExistsError:
pass
# output_colored_text(f"执行失败, [{name}]数据源已存在", "error")


def make_stream_data(
content: str | dict | list,
Expand Down
16 changes: 10 additions & 6 deletions airda/server/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@ class ErrorResponse(BaseModel):

class ChatCompletionRequest(BaseModel):
question: str
datasource_id: str
datasource_name: str


class AddDatasourceRequest(BaseModel):
name: str
host: str
port: int
database: str
knowledge: str
session_id: str
sql_type: str = "mysql"
file_name: str
file_id: str
kind: str
username: str | None
password: str | None


class DeltaMessage(BaseModel):
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ torch = "2.0.1"
pymongo = "^4.6.2"
prompt-toolkit = "^3.0.43"
pyyaml = "^6.0.1"
mysql-connector-python = "^8.3.0"
fastapi = "0.99.0"

[tool.poetry.scripts]
Expand Down

0 comments on commit 9f20a04

Please sign in to comment.