Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RAG and datasets API #279

Merged
merged 12 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,21 @@ INFERENCE_ENG:=llamacpp
INFERENCE_ENG_PORT:=8080
INFERENCE_ENG_VERSION:=server--b1-2321a5e
NUM_CPU_CORES:=8.00
NUM_CPU_CORES_EMBEDDING:=4.00

# Embedding engine and it uses same version with Inference Engine
EMBEDDING_ENG:=embedding_eng
EMBEDDING_ENG_PORT:=8080

# Language model, default is phi3-mini-4k-instruct-q4.gguf
# https://github.com/SkywardAI/llama.cpp/blob/9b2f16f8055265c67e074025350736adc1ea0666/tests/test-chat-template.cpp#L91-L92
LANGUAGE_MODEL_NAME:=Phi-3-mini-4k-instruct-q4.gguf
LANGUAGE_MODEL_URL:=https://huggingface.co/aisuko/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi3-mini-4k-instruct-Q4.gguf?download=true
INSTRUCTION:="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the questions from human."

EMBEDDING_MODEL_NAME:=all-MiniLM-L6-v2-Q4_K_M-v2.gguf
EMBEDDING_MODEL_URL:=https://huggingface.co/aisuko/all-MiniLM-L6-v2-gguf/resolve/main/all-MiniLM-L6-v2-Q4_K_M-v2.gguf?download=true

ADMIN_USERNAME:=admin
ADMIN_EMAIL:[email protected]
ADMIN_PASS:=admin
Expand Down Expand Up @@ -120,13 +128,17 @@ env:
@echo "INFERENCE_ENG=$(INFERENCE_ENG)">> $(FILE_NAME)
@echo "INFERENCE_ENG_PORT=$(INFERENCE_ENG_PORT)">> $(FILE_NAME)
@echo "INFERENCE_ENG_VERSION=$(INFERENCE_ENG_VERSION)">> $(FILE_NAME)
@echo "EMBEDDING_ENG=$(EMBEDDING_ENG)">> $(FILE_NAME)
@echo "EMBEDDING_ENG_PORT=$(EMBEDDING_ENG_PORT)">> $(FILE_NAME)
@echo "NUM_CPU_CORES=$(NUM_CPU_CORES)">> $(FILE_NAME)
@echo "NUM_CPU_CORES_EMBEDDING=$(NUM_CPU_CORES_EMBEDDING)" >> $(FILE_NAME)
@echo "LANGUAGE_MODEL_NAME=$(LANGUAGE_MODEL_NAME)">> $(FILE_NAME)
@echo "ADMIN_USERNAME=$(ADMIN_USERNAME)">> $(FILE_NAME)
@echo "ADMIN_EMAIL=$(ADMIN_EMAIL)">> $(FILE_NAME)
@echo "ADMIN_PASS=$(ADMIN_PASS)">> $(FILE_NAME)
@echo "TIMEZONE=$(TIMEZONE)">> $(FILE_NAME)
@echo "INSTRUCTION"=$(INSTRUCTION)>> $(FILE_NAME)
@echo "EMBEDDING_MODEL_NAME"=$(EMBEDDING_MODEL_NAME) >> $(FILE_NAME)


.PHONY: prepare
Expand Down Expand Up @@ -195,6 +207,8 @@ ruff:
.PHONY: lm
# Download the language model and the embedding model into volumes/models.
# Both downloads are skipped when the file already exists, so repeated
# invocations (e.g. `make localinfer`) do not re-fetch multi-hundred-MB files.
lm:
	@mkdir -p volumes/models && [ -f volumes/models/$(LANGUAGE_MODEL_NAME) ] || wget -O volumes/models/$(LANGUAGE_MODEL_NAME) $(LANGUAGE_MODEL_URL)
	@[ -f volumes/models/$(EMBEDDING_MODEL_NAME) ] || wget -O volumes/models/$(EMBEDDING_MODEL_NAME) $(EMBEDDING_MODEL_URL)


.PHONY: localinfer
localinfer: lm
Expand Down
2 changes: 1 addition & 1 deletion backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ RUN pip --disable-pip-version-check --no-cache-dir install -r requirements.txt &

EXPOSE 8000

HEALTHCHECK --interval=60s --timeout=30s --retries=5 CMD ["curl", "-f", "http://localhost:8000/api/health"]
HEALTHCHECK --interval=300s --timeout=30s --retries=5 CMD ["curl", "-f", "http://localhost:8000/api/health"]

# Execute entrypoint.sh
ENTRYPOINT ["./entrypoint.sh"]
Expand Down
2 changes: 2 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ sqlalchemy = "2.0.29"
trio = "0.25.0"
uvicorn = "0.29.0"
openai = "1.35.7"
datasets = "2.18.0"


[tool.poetry.dev-dependencies]
Expand Down Expand Up @@ -66,6 +67,7 @@ uvicorn = "0.29.0"
openai = "1.35.7"
pre-commit="3.7.0"
pytest="8.1.1"
datasets = "2.18.0"


[build-system]
Expand Down
621 changes: 578 additions & 43 deletions backend/requirements.txt

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions backend/src/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@
import fastapi

from src.api.routes.account import router as account_router
from src.api.routes.ai_model import router as ai_model_router
from src.api.routes.authentication import router as auth_router
from src.api.routes.chat import router as chat_router
from src.api.routes.file import router as file_router
from src.api.routes.train import router as train_router
from src.api.routes.version import router as version_router
from src.api.routes.health import router as health_router
from src.api.routes.rag_datasets import router as datasets_router

# Top-level API router: aggregates every feature router so the application
# only needs to mount this single router.
router = fastapi.APIRouter()

# Each sub-router carries its own prefix/tags; registration order below is
# preserved in the generated OpenAPI docs.
router.include_router(router=account_router)
router.include_router(router=auth_router)
router.include_router(router=chat_router)
router.include_router(router=ai_model_router)
router.include_router(router=train_router)
router.include_router(router=file_router)
router.include_router(router=version_router)
router.include_router(router=health_router)
router.include_router(router=datasets_router)
16 changes: 11 additions & 5 deletions backend/src/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,16 @@ async def chat(

match session.type:
case "rag":
# TODO: Implement RAG
pass
case _:
stream_func: ContentStream = rag_chat_repo.inference(
stream_func: ContentStream = rag_chat_repo.inference_with_rag(
session_id=session.id,
input_msg=chat_in_msg.message,
temperature=chat_in_msg.temperature,
top_k=chat_in_msg.top_k,
top_p=chat_in_msg.top_p,
n_predict=chat_in_msg.n_predict,
)
case _: # default is chat robot
stream_func: ContentStream = rag_chat_repo.inference_with_rag(
session_id=session.id,
input_msg=chat_in_msg.message,
temperature=chat_in_msg.temperature,
Expand Down Expand Up @@ -307,7 +313,7 @@ async def get_chathistory(
```
"""
current_user = await account_repo.read_account_by_username(username=jwt_payload.username)
if session_repo.verify_session_by_account_id(session_uuid=uuid, account_id=current_user.id) is False:
if await session_repo.verify_session_by_account_id(session_uuid=uuid, account_id=current_user.id) is False:
Aisuko marked this conversation as resolved.
Show resolved Hide resolved
raise http_404_exc_uuid_not_found_request(uuid=uuid)
session = await session_repo.read_sessions_by_uuid(session_uuid=uuid)
chats = await chat_repo.read_chat_history_by_session_id(id=session.id)
Expand Down
87 changes: 87 additions & 0 deletions backend/src/api/routes/rag_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fastapi
from src.models.schemas.dataset import RagDatasetCreate, RagDatasetResponse
from src.repository.rag_datasets_eng import DatasetEng

router = fastapi.APIRouter(prefix="/ds", tags=["datasets"])


@router.get(
    path="/list",
    name="datasets:get-dataset-list",
    response_model=list[RagDatasetResponse],
    status_code=fastapi.status.HTTP_200_OK,
)
async def get_dataset_list() -> list[RagDatasetResponse]:
    """
    List the datasets available for RAG.

    Not implemented yet. The previous stub returned ``None``, which fails
    ``response_model`` validation and surfaces to clients as an opaque 500;
    an explicit 501 communicates the real state of this endpoint.
    """
    raise fastapi.HTTPException(
        status_code=fastapi.status.HTTP_501_NOT_IMPLEMENTED,
        detail="Dataset listing is not implemented yet",
    )


@router.get(
    path="/{name}",
    name="datasets:get-dataset-by-name",
    response_model=RagDatasetResponse,
    status_code=fastapi.status.HTTP_200_OK,
)
async def get_dataset_by_name(name: str) -> RagDatasetResponse:
    """
    Fetch details of a single dataset by its name.

    Not implemented yet. The previous stub returned ``None``, which fails
    ``response_model`` validation and surfaces to clients as an opaque 500;
    an explicit 501 communicates the real state of this endpoint.
    """
    raise fastapi.HTTPException(
        status_code=fastapi.status.HTTP_501_NOT_IMPLEMENTED,
        detail="Dataset lookup is not implemented yet",
    )


@router.post(
    path="/load",
    name="datasets:load-dataset",
    response_model=RagDatasetResponse,
    status_code=fastapi.status.HTTP_201_CREATED,
)
async def load_dataset(
    rag_ds_create: RagDatasetCreate,
) -> RagDatasetResponse:
    """
    Load the specified dataset into the vector db.

    Example:

    curl -X 'POST' \
        'http://127.0.0.1:8000/api/ds/load' \
        -H 'accept: application/json' \
        -H 'Content-Type: application/json' \
        -d '{
        "name": "aisuko/squad01",
        "des": "string",
        "ratio": 0
        }'

    Return:
        {
        "name": "aisuko/squad01",
        "status": "Success"
        }
    """
    # NOTE(review): assumes DatasetEng.load_dataset reports the number of
    # inserted rows under "insert_count" — confirm against the repository layer.
    res: dict = DatasetEng.load_dataset(rag_ds_create.name)

    # Default to 0 so a response without "insert_count" is reported as a
    # failure instead of raising `TypeError: '>' not supported` on None.
    status = "Success" if res.get("insert_count", 0) > 0 else "Failed"

    return RagDatasetResponse(name=rag_ds_create.name, status=status)
5 changes: 5 additions & 0 deletions backend/src/config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class BackendBaseSettings(BaseSettings):

# Configurations for language model
LANGUAGE_MODEL_NAME: str = decouple.config("LANGUAGE_MODEL_NAME", cast=str) # type: ignore
EMBEDDING_MODEL_NAME: str = decouple.config("EMBEDDING_MODEL_NAME", cast=str) # type: ignore

# Admin setting
ADMIN_USERNAME: str = decouple.config("ADMIN_USERNAME", cast=str) # type: ignore
Expand All @@ -110,6 +111,10 @@ class BackendBaseSettings(BaseSettings):
ETCD_QUOTA_BACKEND_BYTES: int = decouple.config("ETCD_QUOTA_BACKEND_BYTES", cast=int) # type: ignore
NUM_CPU_CORES: float = decouple.config("NUM_CPU_CORES", cast=float) # type: ignore

EMBEDDING_ENG: str = decouple.config("EMBEDDING_ENG", cast=str) # type: ignore
EMBEDDING_ENG_PORT: int = decouple.config("EMBEDDING_ENG_PORT", cast=int) # type: ignore
NUM_CPU_CORES_EMBEDDING: int = decouple.config("NUM_CPU_CORES_EMBEDDING", cast=str) # type: ignore

class Config(pydantic.ConfigDict):
case_sensitive: bool = True
env_file: str = f"{str(ROOT_DIR)}/.env"
Expand Down
1 change: 1 addition & 0 deletions backend/src/config/settings/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
MAX_SQL_LENGTH = 200

DEFAULT_COLLECTION = "default_collection"
# embedding dimension depending on model
DEFAULT_DIM = 384

# DEFAULT MODELS
Expand Down
30 changes: 30 additions & 0 deletions backend/src/models/schemas/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
import datetime

from pydantic import Field
Expand All @@ -14,3 +30,17 @@ class DatasetResponse(BaseSchemaModel):
dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name")
created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time")
updated_at: datetime.datetime | None = Field(..., title="Update time", description="Update time")


class RagDatasetCreate(BaseSchemaModel):
    """Request body for loading a dataset into the vector db (POST /api/ds/load)."""

    # Dataset identifier, e.g. "aisuko/squad01".
    name: str = Field(..., title="DataSet Name", description="DataSet Name")
    # Free-form description; required but may be null.
    des: str | None = Field(..., title="Details", description="Details")
    # `float | None` instead of Optional[float] for consistency with the
    # union syntax used throughout this module.
    ratio: float | None = Field(..., title="Ratio", description="Ratio")


class RagDatasetResponse(BaseSchemaModel):
    """Response for the dataset endpoints: the dataset name plus a load status."""

    name: str = Field(..., title="DataSet Name", description="DataSet Name")
    # "Success" / "Failed" as set by the load endpoint; required but nullable.
    # (Commented-out created_at/updated_at/ratio fields removed as dead code.)
    status: str | None = Field(..., title="Status", description="Status")
15 changes: 15 additions & 0 deletions backend/src/repository/crud/vectors_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

7 changes: 7 additions & 0 deletions backend/src/repository/inference_eng.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class InferenceHelper:
infer_eng_url: pydantic.StrictStr = settings.INFERENCE_ENG
infer_eng_port: pydantic.PositiveInt = settings.INFERENCE_ENG_PORT
instruction: pydantic.StrictStr = settings.INSTRUCTION
embedding_url: pydantic.StrictStr = settings.EMBEDDING_ENG
embedding_port: pydantic.PositiveInt = settings.EMBEDDING_ENG_PORT

def init(self) -> None:
raise NotImplementedError("InferenceHelper is a singleton class. Use inference_helper instead.")
Expand Down Expand Up @@ -60,3 +62,8 @@ def instruct_infer_url(cls) -> str:
str: URL for the inference engine
"""
return f"http://{cls.infer_eng_url}:{cls.infer_eng_port}/completion"

@classmethod
def instruct_embedding_url(cls) -> str:
    """Build the URL of the embedding engine's /embedding endpoint.

    Returns:
        str: URL in the form "http://<EMBEDDING_ENG>:<EMBEDDING_ENG_PORT>/embedding",
        from the class-level `embedding_url` / `embedding_port` settings.
    """
    return f"http://{cls.embedding_url}:{cls.embedding_port}/embedding"
Loading