* draft searchQnA example
* skip the aio bug, test pass 1/2
* fix dep
* add copyright
* add README
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9a2439e, commit 393f6f8. Showing 6 changed files with 347 additions and 0 deletions.
@@ -0,0 +1,27 @@
# Search Question and Answering

Search Question and Answering (SearchQnA) uses a search engine (e.g. Google Search) to improve QA quality. Large language models are limited to their prior training data, so they struggle with real-time information and specific details. A search engine can make up for this limitation: the SearchQnA service first looks up relevant source web pages and feeds them as context to the LLM, so the LLM can compose more precise answers.

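
Conceptually, the service composes a web-research retriever with a question-answering chain. The snippet below is a simplified sketch of that wiring, mirroring the bundled `server.py`; the TGI endpoint URL and the sample query are placeholders, and it assumes `GOOGLE_API_KEY` and `GOOGLE_CSE_ID` are set in the environment.

```python
# Simplified sketch of the SearchQnA flow (illustrative only; see server.py for the full service).
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.retrievers.web_research import WebResearchRetriever
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain_community.vectorstores import Chroma

# Placeholder TGI endpoint; GOOGLE_API_KEY and GOOGLE_CSE_ID are assumed to be exported.
llm = HuggingFaceEndpoint(endpoint_url="http://localhost:8080", max_new_tokens=1024, temperature=0.01)
vectorstore = Chroma(embedding_function=HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large"))
retriever = WebResearchRetriever.from_llm(vectorstore=vectorstore, llm=llm, search=GoogleSearchAPIWrapper())
chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=retriever)

# The chain searches the web, stores the pages in the vector db, retrieves context, and answers.
result = chain({"question": "Give me some latest news?"})
print(result["answer"], result["sources"])
```
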
## Start Service

- Start the TGI service to deploy your LLM

```sh
cd serving/tgi_gaudi
bash build_docker.sh
bash launch_tgi_service.sh
```

- Start the SearchQnA application using Google Search

```sh
cd GenAIExamples/SearchQnA/langchain/docker
docker build . --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} -t intel/gen-ai-examples:searchqna-gaudi --no-cache
docker run -e TGI_ENDPOINT=<TGI ENDPOINT> -e GOOGLE_CSE_ID=<GOOGLE CSE ID> -e GOOGLE_API_KEY=<GOOGLE API KEY> -e HUGGINGFACEHUB_API_TOKEN=<HUGGINGFACE API TOKEN> -p 8085:8000 -e http_proxy=$http_proxy -e https_proxy=$https_proxy -v $PWD/qna-app:/qna-app --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host intel/gen-ai-examples:searchqna-gaudi
```

- Test

```sh
curl http://localhost:8085/v1/rag/web_search_chat_stream -X POST -d '{"query":"Give me some latest news?"}' -H 'Content-Type: application/json'
```

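
The streaming endpoint returns server-sent events in which spaces are encoded as `@#$` and newlines as `<br/>` (see `stream_generator` in `server.py`). A minimal, hypothetical Python client that decodes this framing is sketched below; the host and port follow the `docker run` mapping above, and the `requests` dependency is an assumption rather than part of this commit.

```python
# Hypothetical streaming client for the SearchQnA SSE endpoint (host/port assumed from the README).
import requests

url = "http://localhost:8085/v1/rag/web_search_chat_stream"
with requests.post(url, json={"query": "Give me some latest news?"}, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank separator lines between events
        token = line[len("data: "):]
        if token == "[DONE]":
            break
        # Undo the server-side encoding: "@#$" stands for a space, "<br/>" for a newline.
        print(token.replace("@#$", " ").replace("<br/>", "\n"), end="", flush=True)
print()
```
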
@@ -0,0 +1,45 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# HABANA environment
FROM vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1 AS hpu
RUN rm -rf /etc/ssh/ssh_host*

# Set environment variables
ENV LANG=en_US.UTF-8
ENV PYTHONPATH=/home/user:/usr/lib/habanalabs/:/langchain/libs/community:/langchain/libs/langchain

# Install required branch
RUN git clone https://github.com/Spycsh/langchain.git /langchain -b master
RUN cd /langchain/libs/langchain && \
    pip install -e . && \
    cd -

RUN useradd -m -s /bin/bash user && \
    mkdir -p /home/user && \
    chown -R user /home/user/

USER user

COPY requirements.txt /tmp/requirements.txt

# Install dependencies
RUN pip install --no-cache-dir -U -r /tmp/requirements.txt

# The work dir should contain the server;
# make sure it can be edited by the user
WORKDIR /home/user/qna-app
COPY qna-app /home/user/qna-app

ENTRYPOINT ["python", "server.py"]
@@ -0,0 +1,196 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from queue import Empty, Queue
from threading import Thread

from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import StreamingResponse
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.retrievers.web_research import WebResearchRetriever
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain_community.vectorstores import Chroma
from starlette.middleware.cors import CORSMiddleware

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class QueueCallbackHandler(BaseCallbackHandler):
    """A callback handler that pushes generated answer tokens into a queue for the streaming response."""

    def __init__(self, queue: Queue):
        self.queue = queue
        self.enter_answer_phase = False

    def on_llm_new_token(self, token: str, **kwargs):
        sys.stdout.write(token)
        sys.stdout.flush()
        if self.enter_answer_phase:
            self.queue.put(
                {
                    "answer": token,
                }
            )

    def on_llm_end(self, *args, **kwargs):
        # Each finished LLM call flips the phase: the retriever's query-generation call comes first,
        # so only tokens from the subsequent answer-generation call are queued for streaming.
        self.enter_answer_phase = not self.enter_answer_phase
        return True


class SearchQuestionAnsweringAPIRouter(APIRouter):
    """The router for the SearchQnA example.

    The input request first goes through Google Search, and the fetched HTML pages are stored in the vector db.
    Then the input request, together with the relevant retrieved documents, is forwarded to the LLM to get the answer.
    """

    def __init__(
        self,
        entrypoint: str,
        vectordb_embedding_model: str = "hkunlp/instructor-large",
        vectordb_persistent_directory: str = "/home/user/chroma_db_oai",
    ) -> None:
        super().__init__()
        self.entrypoint = entrypoint
        self.queue = Queue()  # For streaming output tokens

        # Set up the TGI endpoint
        self.llm = HuggingFaceEndpoint(
            endpoint_url=entrypoint,
            max_new_tokens=1024,
            top_k=10,
            top_p=0.95,
            typical_p=0.95,
            temperature=0.01,
            repetition_penalty=1.03,
            streaming=True,
            callbacks=[QueueCallbackHandler(queue=self.queue)],
        )

        # Check that the Google Search credentials are provided
        if "GOOGLE_CSE_ID" not in os.environ or "GOOGLE_API_KEY" not in os.environ:
            raise Exception("Please make sure to set the GOOGLE_CSE_ID and GOOGLE_API_KEY environment variables!")

        # Notice: please check or manually delete the vectordb directory if you do not want previous histories
        self.vectorstore = Chroma(
            embedding_function=HuggingFaceInstructEmbeddings(model_name=vectordb_embedding_model),
            persist_directory=vectordb_persistent_directory,
        )

        # Build up the Google Search service
        self.search = GoogleSearchAPIWrapper()

        # Compose the web search retriever
        self.web_search_retriever = WebResearchRetriever.from_llm(
            vectorstore=self.vectorstore, llm=self.llm, search=self.search
        )

        # Compose the whole chain
        self.llm_chain = RetrievalQAWithSourcesChain.from_chain_type(
            self.llm,
            retriever=self.web_search_retriever,
        )

    def handle_search_chat(self, query: str):
        response = self.llm_chain({"question": query})
        return response["answer"], response["sources"]


tgi_endpoint = os.getenv("TGI_ENDPOINT", "http://localhost:8080")

router = SearchQuestionAnsweringAPIRouter(
    entrypoint=tgi_endpoint,
)


@router.post("/v1/rag/web_search_chat")
async def web_search_chat(request: Request):
    params = await request.json()
    print(f"[websearch - chat] POST request: /v1/rag/web_search_chat, params:{params}")
    query = params["query"]
    answer, sources = router.handle_search_chat(query=query)
    print(f"[websearch - chat] answer: {answer}, sources: {sources}")
    return {"answer": answer, "sources": sources}


@router.post("/v1/rag/web_search_chat_stream")
async def web_search_chat_stream(request: Request):
    params = await request.json()
    print(tgi_endpoint)
    print(f"[websearch - streaming chat] POST request: /v1/rag/web_search_chat_stream, params:{params}")
    query = params["query"]

    def stream_callback(query):
        finished = object()

        def task():
            _ = router.llm_chain({"question": query})
            router.queue.put(finished)

        t = Thread(target=task)
        t.start()
        while True:
            try:
                item = router.queue.get()
                if item is finished:
                    break
                yield item
            except Empty:
                continue

    def stream_generator():
        chat_response = ""
        # FIXME need to add the sources and chat_history
        for res_dict in stream_callback(query):
            text = res_dict["answer"]
            chat_response += text
            if text == " ":
                yield "data: @#$\n\n"
                continue
            if text.isspace():
                continue
            if "\n" in text:
                yield "data: <br/>\n\n"
            # Spaces are sent as the placeholder "@#$" (decoded back to spaces by the client)
            new_text = text.replace(" ", "@#$")
            yield f"data: {new_text}\n\n"
        chat_response = chat_response.split("</s>")[0]
        print(f"[rag - chat_stream] stream response: {chat_response}")
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream_generator(), media_type="text/event-stream")


app.include_router(router)


if __name__ == "__main__":
    import uvicorn

    fastapi_port = os.getenv("FASTAPI_PORT", "8000")
    uvicorn.run(app, host="0.0.0.0", port=int(fastapi_port))
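
For reference, the non-streaming `/v1/rag/web_search_chat` route defined above returns a JSON body with `answer` and `sources`. A minimal, hypothetical client is sketched below; the host and port follow the README's port mapping, and `requests` is an assumed client-side dependency rather than part of this commit.

```python
import requests

# Hypothetical client for the non-streaming SearchQnA route (port mapping assumed from the README).
resp = requests.post(
    "http://localhost:8085/v1/rag/web_search_chat",
    json={"query": "Give me some latest news?"},
    timeout=300,
)
resp.raise_for_status()
body = resp.json()
print("Answer:", body["answer"])
print("Sources:", body["sources"])
```
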
@@ -0,0 +1,10 @@
beautifulsoup4
chromadb
eager
fastapi
google-api-python-client>=2.100.0
html2text
InstructorEmbedding
optimum[habana]
sentence-transformers==2.2.2
uvicorn
@@ -0,0 +1,19 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash

git clone https://github.com/huggingface/tgi-gaudi.git
cd ./tgi-gaudi/
docker build -t ghcr.io/huggingface/tgi-gaudi:1.2.1 . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy
@@ -0,0 +1,50 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash

# Set default values
default_port=8080
default_model="Intel/neural-chat-7b-v3-3"
default_num_cards=1

# Check that no more than three arguments are provided
if [ "$#" -gt 3 ]; then
    echo "Usage: $0 [num_cards] [port_number] [model_name]"
    exit 1
fi

# Assign arguments to variables, falling back to the defaults
num_cards=${1:-$default_num_cards}
port_number=${2:-$default_port}
model_name=${3:-$default_model}

# Check if num_cards is within the valid range (1-8)
if [ "$num_cards" -lt 1 ] || [ "$num_cards" -gt 8 ]; then
    echo "Error: num_cards must be between 1 and 8."
    exit 1
fi

# Set the volume variable
volume=$PWD/data

# Build the Docker run command based on the number of cards
if [ "$num_cards" -eq 1 ]; then
    docker_cmd="docker run -p $port_number:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$http_proxy ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id $model_name"
else
    docker_cmd="docker run -p $port_number:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$http_proxy ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id $model_name --sharded true --num-shard $num_cards"
fi

# Execute the Docker run command
eval $docker_cmd