Implement multi-turn conversations (#92)
This PR implements multi-turn conversations by exposing an OpenAI-compatible
API and keeping track of conversation state on the client side.
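
For illustration, the client-side flow this enables looks roughly like the following
(a minimal sketch, assuming the Serve app is reachable at http://localhost:8000 and
using the /chat endpoint added in rag/serve.py below; host, port, and the example
questions are assumptions):

import requests  # assumption: any HTTP client works; the endpoint streams plain text

BASE_URL = "http://localhost:8000"  # assumption: wherever the Serve app is deployed

# The server is stateless: the client keeps the full history and resends it each turn.
messages = []
for user_turn in ["What is Ray Serve?", "Does it autoscale?"]:
    messages.append({"role": "user", "content": user_turn})
    chunks = []
    with requests.post(f"{BASE_URL}/chat", json={"messages": messages}, stream=True) as resp:
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            chunks.append(chunk)
    # Append the assistant's reply so the next request carries the whole conversation.
    messages.append({"role": "assistant", "content": "".join(chunks)})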

---------

Signed-off-by: Philipp Moritz <[email protected]>
pcmoritz authored Jan 26, 2024
1 parent 1a9ce96 commit 1168cbd
Showing 2 changed files with 78 additions and 16 deletions.
40 changes: 26 additions & 14 deletions rag/generate.py
@@ -30,29 +30,17 @@ def prepare_response(chat_completion, stream):
         return chat_completion.choices[0].message.content


-def generate_response(
+def send_request(
     llm,
+    messages,
     max_tokens=None,
     temperature=0.0,
     stream=False,
-    system_content="",
-    assistant_content="",
-    user_content="",
     max_retries=1,
     retry_interval=60,
 ):
-    """Generate response from an LLM."""
     retry_count = 0
     client = get_client(llm=llm)
-    messages = [
-        {"role": role, "content": content}
-        for role, content in [
-            ("system", system_content),
-            ("assistant", assistant_content),
-            ("user", user_content),
-        ]
-        if content
-    ]
     while retry_count <= max_retries:
         try:
             chat_completion = client.chat.completions.create(
@@ -71,6 +59,30 @@ def generate_response(
     return ""


+def generate_response(
+    llm,
+    max_tokens=None,
+    temperature=0.0,
+    stream=False,
+    system_content="",
+    assistant_content="",
+    user_content="",
+    max_retries=1,
+    retry_interval=60,
+):
+    """Generate response from an LLM."""
+    messages = [
+        {"role": role, "content": content}
+        for role, content in [
+            ("system", system_content),
+            ("assistant", assistant_content),
+            ("user", user_content),
+        ]
+        if content
+    ]
+    return send_request(llm, messages, max_tokens, temperature, stream, max_retries, retry_interval)
+
+
 class QueryAgent:
     def __init__(
         self,
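With the messages construction factored out, send_request can now be called with an
arbitrary conversation history, while generate_response keeps the original single-turn
interface. A usage sketch (illustrative values only; the model name and token limit are
assumptions, not taken from the repo's config):

from rag.generate import generate_response, send_request

# Multi-turn call: the caller builds the message list, including prior assistant turns.
history = [
    {"role": "system", "content": "Answer questions about Ray."},
    {"role": "user", "content": "What is Ray Serve?"},
    {"role": "assistant", "content": "Ray Serve is a scalable model serving library."},
    {"role": "user", "content": "How do I deploy a FastAPI app with it?"},  # follow-up turn
]
answer = send_request(
    llm="meta-llama/Llama-2-70b-chat-hf",  # assumption: any model name accepted by get_client
    messages=history,
    max_tokens=512,
    temperature=0.0,
)

# Single-turn behavior is unchanged; generate_response still builds the message list itself.
single = generate_response(
    llm="meta-llama/Llama-2-70b-chat-hf",
    system_content="Answer questions about Ray.",
    user_content="What is Ray Serve?",
)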
54 changes: 52 additions & 2 deletions rag/serve.py
@@ -11,15 +11,15 @@
 import structlog
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from rank_bm25 import BM25Okapi
 from ray import serve
 from slack_bolt import App
 from slack_bolt.adapter.socket_mode import SocketModeHandler
 from starlette.responses import StreamingResponse

 from rag.config import EMBEDDING_DIMENSIONS, MAX_CONTEXT_LENGTHS
-from rag.generate import QueryAgent
+from rag.generate import QueryAgent, send_request
 from rag.index import load_index

 app = FastAPI()
@@ -67,6 +67,15 @@ class Query(BaseModel):
     query: str


+class Message(BaseModel):
+    role: str = Field(..., description="The role of the author of the message, typically 'user', or 'assistant'.")
+    content: str = Field(..., description="The content of the message.")
+
+
+class Request(BaseModel):
+    messages: List[Message] = Field(..., description="A list of messages that make up the conversation.")
+
+
 class Answer(BaseModel):
     question: str
     answer: str
@@ -194,6 +203,7 @@ def query(self, query: Query) -> Answer:
         result = self.predict(query, stream=False)
         return Answer.parse_obj(result)

+    # This will be removed after all traffic is migrated to the /chat endpoint
     def produce_streaming_answer(self, query, result):
         answer = []
         for answer_piece in result["answer"]:
@@ -213,13 +223,53 @@ def produce_streaming_answer(self, query, result):
             answer="".join(answer),
         )

+    # This will be removed after all traffic is migrated to the /chat endpoint
     @app.post("/stream")
     def stream(self, query: Query) -> StreamingResponse:
         result = self.predict(query, stream=True)
         return StreamingResponse(
             self.produce_streaming_answer(query.query, result), media_type="text/plain"
         )

+    def produce_chat_answer(self, request, result):
+        answer = []
+        for answer_piece in result["answer"]:
+            answer.append(answer_piece)
+            yield answer_piece
+
+        if result["sources"]:
+            yield "\n\n**Sources:**\n"
+            for source in result["sources"]:
+                yield "* " + source + "\n"
+
+        self.logger.info(
+            "finished chat query",
+            request=request.dict(),
+            document_ids=result["document_ids"],
+            llm=result["llm"],
+            answer="".join(answer),
+        )
+
+    @app.post("/chat")
+    def chat(self, request: Request) -> StreamingResponse:
+        if len(request.messages) == 1:
+            query = Query(query=request.messages[0].content)
+            result = self.predict(query, stream=True)
+        else:
+            # For now, we always use the OSS agent for follow up questions
+            agent = self.oss_agent
+            answer = send_request(
+                llm=agent.llm,
+                messages=request.messages,
+                max_tokens=agent.max_tokens,
+                temperature=agent.temperature,
+                stream=True)
+            result = {"answer": answer, "llm": agent.llm, "sources": [], "document_ids": []}
+
+        return StreamingResponse(
+            self.produce_chat_answer(request, result),
+            media_type="text/plain")
+

 # Deploy the Ray Serve app
 deployment = RayAssistantDeployment.bind(
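The /chat handler above routes on history length: a single message is treated as a fresh
question and answered through self.predict (retrieval plus streamed sources), while a longer
history is passed straight to send_request on the OSS agent with no retrieval. The two request
shapes look roughly like this (hypothetical message contents):

# First turn: one message -> answered via self.predict, so sources are streamed too.
first_turn = {"messages": [{"role": "user", "content": "What is Ray Serve?"}]}

# Follow-up turn: the full history -> forwarded to the OSS agent's LLM, no retrieval.
follow_up = {
    "messages": [
        {"role": "user", "content": "What is Ray Serve?"},
        {"role": "assistant", "content": "Ray Serve is a scalable model serving library..."},
        {"role": "user", "content": "Can I deploy a FastAPI app with it?"},
    ]
}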
