Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RAG and datasets API #279

Merged
merged 12 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ sqlalchemy = "2.0.29"
trio = "0.25.0"
uvicorn = "0.29.0"
openai = "1.35.7"
datasets = "2.18.0"


[tool.poetry.dev-dependencies]
Expand Down Expand Up @@ -66,6 +67,7 @@ uvicorn = "0.29.0"
openai = "1.35.7"
pre-commit="3.7.0"
pytest="8.1.1"
datasets = "2.18.0"


[build-system]
Expand Down
621 changes: 578 additions & 43 deletions backend/requirements.txt

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions backend/src/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@
import fastapi

from src.api.routes.account import router as account_router
from src.api.routes.ai_model import router as ai_model_router
from src.api.routes.authentication import router as auth_router
from src.api.routes.chat import router as chat_router
from src.api.routes.file import router as file_router
from src.api.routes.train import router as train_router
from src.api.routes.version import router as version_router
from src.api.routes.health import router as health_router
from src.api.routes.rag_datasets import router as datasets_router

# Top-level API router that aggregates every feature router of the backend.
router = fastapi.APIRouter()

# Sub-routers are attached in the same order as before: account, auth, chat,
# ai_model, train, file, version, health, datasets.
for _sub_router in (
    account_router,
    auth_router,
    chat_router,
    ai_model_router,
    train_router,
    file_router,
    version_router,
    health_router,
    datasets_router,
):
    router.include_router(router=_sub_router)
16 changes: 11 additions & 5 deletions backend/src/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,16 @@ async def chat(

match session.type:
case "rag":
# TODO: Implement RAG
pass
case _:
stream_func: ContentStream = rag_chat_repo.inference(
stream_func: ContentStream = rag_chat_repo.inference_with_rag(
session_id=session.id,
input_msg=chat_in_msg.message,
temperature=chat_in_msg.temperature,
top_k=chat_in_msg.top_k,
top_p=chat_in_msg.top_p,
n_predict=chat_in_msg.n_predict,
)
case _: # default is chat robot
stream_func: ContentStream = rag_chat_repo.inference_with_rag(
session_id=session.id,
input_msg=chat_in_msg.message,
temperature=chat_in_msg.temperature,
Expand Down Expand Up @@ -307,7 +313,7 @@ async def get_chathistory(
```
"""
current_user = await account_repo.read_account_by_username(username=jwt_payload.username)
if session_repo.verify_session_by_account_id(session_uuid=uuid, account_id=current_user.id) is False:
if await session_repo.verify_session_by_account_id(session_uuid=uuid, account_id=current_user.id) is False:
Aisuko marked this conversation as resolved.
Show resolved Hide resolved
raise http_404_exc_uuid_not_found_request(uuid=uuid)
session = await session_repo.read_sessions_by_uuid(session_uuid=uuid)
chats = await chat_repo.read_chat_history_by_session_id(id=session.id)
Expand Down
83 changes: 83 additions & 0 deletions backend/src/api/routes/rag_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fastapi
from src.models.schemas.dataset import RagDatasetCreate, RagDatasetResponse
from src.repository.rag_datasets_eng import DatasetEng

router = fastapi.APIRouter(prefix="/ds", tags=["datasets"])


@router.get(
    path="/list",
    name="datasets:get-dataset-list",
    response_model=list[RagDatasetResponse],
    status_code=fastapi.status.HTTP_200_OK,
)
async def get_dataset_list() -> list[RagDatasetResponse]:
    """
    List the RAG datasets.

    Placeholder endpoint: no listing logic exists yet, so the handler
    produces ``None``.
    """
    # TODO: return the available datasets once the listing logic exists.
    # NOTE(review): a ``None`` result presumably does not satisfy the declared
    # ``list[RagDatasetResponse]`` response model - confirm before exposing.
    return None


@router.get(
    path="/{name}",
    name="datasets:get-dataset-by-name",
    response_model=RagDatasetResponse,
    status_code=fastapi.status.HTTP_200_OK,
)
async def get_dataset_by_name(name: str) -> RagDatasetResponse:
    """
    Fetch a single RAG dataset by its name.

    Placeholder endpoint: ``name`` is accepted from the path but not used
    yet, and the handler produces ``None``.
    """
    # TODO: look the dataset up by ``name`` once the lookup logic exists.
    # NOTE(review): a ``None`` result presumably does not satisfy the declared
    # ``RagDatasetResponse`` response model - confirm before exposing.
    return None


@router.post(
    path="/load",
    name="datasets:load-dataset",
    response_model=RagDatasetResponse,
    status_code=fastapi.status.HTTP_201_CREATED,
)
async def load_dataset(
    rag_ds_create: RagDatasetCreate,
) -> RagDatasetResponse:
    r"""
    Load the specified dataset into the vector db.

    Only ``rag_ds_create.name`` is used; it is passed to
    ``DatasetEng.load_dataset``. The response reports ``"Success"`` when the
    returned ``insert_count`` is greater than zero and ``"Failed"`` otherwise.

    Example request:

    curl -X 'POST' \
      'http://127.0.0.1:8000/api/ds/load' \
      -H 'accept: application/json' \
      -H 'Content-Type: application/json' \
      -d '{
        "name": "aisuko/squad01",
        "des": "string",
        "ratio": 0
      }'

    Return:
    {
      "name": "aisuko/squad01",
      "status": "Success"
    }
    """
    # NOTE(review): this is a plain (non-awaited) call inside an async handler;
    # if it performs blocking I/O it will hold up the event loop - confirm.
    res: dict = DatasetEng.load_dataset(rag_ds_create.name)

    # Treat a missing result or a missing/None "insert_count" as zero inserted
    # rows instead of letting `None > 0` raise TypeError.
    inserted = (res or {}).get("insert_count") or 0
    status = "Success" if inserted > 0 else "Failed"

    return RagDatasetResponse(name=rag_ds_create.name, status=status)
1 change: 1 addition & 0 deletions backend/src/config/settings/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
MAX_SQL_LENGTH = 200

DEFAULT_COLLECTION = "default_collection"
# Embedding dimension; depends on the embedding model in use
DEFAULT_DIM = 384

# DEFAULT MODELS
Expand Down
29 changes: 29 additions & 0 deletions backend/src/models/schemas/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
import datetime

from pydantic import Field
Expand All @@ -14,3 +30,16 @@ class DatasetResponse(BaseSchemaModel):
dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name")
created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time")
updated_at: datetime.datetime | None = Field(..., title="Update time", description="Update time")


class RagDatasetCreate(BaseSchemaModel):
    """
    Request body for loading a RAG dataset.

    ``name`` is required. ``des`` and ``ratio`` are nullable and default to
    ``None``, so a client may omit them; payloads that already send them keep
    validating exactly as before.
    """

    name: str = Field(..., title="DataSet Name", description="DataSet Name")
    des: str | None = Field(default=None, title="Details", description="Details")
    ratio: float | None = Field(default=None, title="Ratio", description="Ratio")

class RagDatasetResponse(BaseSchemaModel):
    """
    Response body describing a RAG dataset and the outcome of an operation.

    Both fields must be supplied when the model is built; ``status`` may be
    ``None``.
    """

    name: str = Field(..., description="DataSet Name", title="DataSet Name")
    # created_at, updated_at and ratio are not part of the response yet.
    status: str | None = Field(..., description="Status", title="Status")
15 changes: 15 additions & 0 deletions backend/src/repository/crud/vectors_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Loading
Loading