Skip to content

Commit

Permalink
Extract embedding model to TEI endpoint (#26)
Browse files Browse the repository at this point in the history
* support TEI for embedding serving

Signed-off-by: LetongHan <[email protected]>

* update code

Signed-off-by: LetongHan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: LetongHan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
letonghan and pre-commit-ci[bot] authored Mar 28, 2024
1 parent 38e8945 commit 9a2439e
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions ChatQnA/langchain/docker/qna-app/app/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from fastapi import APIRouter, FastAPI, File, Request, UploadFile
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from guardrails import moderation_prompt_for_chat, unsafe_dict
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceHubEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import Redis
from langchain_core.messages import HumanMessage
Expand All @@ -47,7 +47,7 @@

class RAGAPIRouter(APIRouter):

def __init__(self, upload_dir, entrypoint, safety_guard_endpoint) -> None:
def __init__(self, upload_dir, entrypoint, safety_guard_endpoint, tei_endpoint=None) -> None:
super().__init__()
self.upload_dir = upload_dir
self.entrypoint = entrypoint
Expand Down Expand Up @@ -81,7 +81,13 @@ def __init__(self, upload_dir, entrypoint, safety_guard_endpoint) -> None:
print("[rag - router] LLM initialized.")

# Define LLM Chain
self.embeddings = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)
if tei_endpoint:
# create embeddings using TEI endpoint service
self.embeddings = HuggingFaceHubEmbeddings(model=tei_endpoint)
else:
# create embeddings using local embedding model
self.embeddings = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)

rds = Redis.from_existing_index(
self.embeddings,
index_name=INDEX_NAME,
Expand Down Expand Up @@ -130,7 +136,8 @@ def handle_rag_chat(self, query: str):
# --- Module-level runtime configuration ---
# NOTE(review): this excerpt is a rendered diff, so it shows BOTH the removed
# and the added RAGAPIRouter(...) construction; in the actual file only the
# four-argument call (last line) exists after this commit.
# Directory where uploaded documents are stored for RAG ingestion.
upload_dir = os.getenv("RAG_UPLOAD_DIR", "./upload_dir")
# URL of the TGI (Text Generation Inference) service backing the LLM.
tgi_llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
# Optional safety-guard service endpoint; None when the env var is unset.
safety_guard_endpoint = os.getenv("SAFETY_GUARD_ENDPOINT")
router = RAGAPIRouter(upload_dir, tgi_llm_endpoint, safety_guard_endpoint)
# Optional TEI (Text Embeddings Inference) endpoint; when None, RAGAPIRouter
# falls back to the local HuggingFaceBgeEmbeddings model (see the diff hunk
# above: `if tei_endpoint: ... else: ...`).
tei_embedding_endpoint = os.getenv("TEI_ENDPOINT")
router = RAGAPIRouter(upload_dir, tgi_llm_endpoint, safety_guard_endpoint, tei_embedding_endpoint)


@router.post("/v1/rag/chat")
Expand Down

0 comments on commit 9a2439e

Please sign in to comment.