diff --git a/ChatQnA/langchain/docker/qna-app/app/server.py b/ChatQnA/langchain/docker/qna-app/app/server.py index c34daaa7..a50ac6ba 100644 --- a/ChatQnA/langchain/docker/qna-app/app/server.py +++ b/ChatQnA/langchain/docker/qna-app/app/server.py @@ -20,7 +20,7 @@ from fastapi import APIRouter, FastAPI, File, Request, UploadFile from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse from guardrails import moderation_prompt_for_chat, unsafe_dict -from langchain_community.embeddings import HuggingFaceBgeEmbeddings +from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceHubEmbeddings from langchain_community.llms import HuggingFaceEndpoint from langchain_community.vectorstores import Redis from langchain_core.messages import HumanMessage @@ -47,7 +47,7 @@ class RAGAPIRouter(APIRouter): - def __init__(self, upload_dir, entrypoint, safety_guard_endpoint) -> None: + def __init__(self, upload_dir, entrypoint, safety_guard_endpoint, tei_endpoint=None) -> None: super().__init__() self.upload_dir = upload_dir self.entrypoint = entrypoint @@ -81,7 +81,13 @@ def __init__(self, upload_dir, entrypoint, safety_guard_endpoint) -> None: print("[rag - router] LLM initialized.") # Define LLM Chain - self.embeddings = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL) + if tei_endpoint: + # create embeddings using TEI endpoint service + self.embeddings = HuggingFaceHubEmbeddings(model=tei_endpoint) + else: + # create embeddings using local embedding model + self.embeddings = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL) + rds = Redis.from_existing_index( self.embeddings, index_name=INDEX_NAME, @@ -130,7 +136,8 @@ def handle_rag_chat(self, query: str): upload_dir = os.getenv("RAG_UPLOAD_DIR", "./upload_dir") tgi_llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080") safety_guard_endpoint = os.getenv("SAFETY_GUARD_ENDPOINT") -router = RAGAPIRouter(upload_dir, tgi_llm_endpoint, safety_guard_endpoint) +tei_embedding_endpoint = os.getenv("TEI_ENDPOINT") +router = RAGAPIRouter(upload_dir, tgi_llm_endpoint, safety_guard_endpoint, tei_embedding_endpoint) @router.post("/v1/rag/chat")