feat(ask): non-streaming now calls streaming (#3409)
# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
StanGirard authored Oct 21, 2024
1 parent 2b347c9 commit e71e46b
Showing 11 changed files with 52 additions and 815 deletions.
43 changes: 19 additions & 24 deletions core/quivr_core/brain/brain.py
@@ -545,34 +545,28 @@ def ask(
print(answer.answer)
```
"""
llm = self.llm

# If you passed a different llm model we'll override the brain one
if retrieval_config:
if retrieval_config.llm_config != self.llm.get_config():
llm = LLMEndpoint.from_config(config=retrieval_config.llm_config)
else:
retrieval_config = RetrievalConfig(llm_config=self.llm.get_config())

if rag_pipeline is None:
rag_pipeline = QuivrQARAGLangGraph

rag_instance = rag_pipeline(
retrieval_config=retrieval_config, llm=llm, vector_store=self.vector_db
)
async def collect_streamed_response():
full_answer = ""
async for response in self.ask_streaming(
question=question,
retrieval_config=retrieval_config,
rag_pipeline=rag_pipeline,
list_files=list_files,
chat_history=chat_history
):
full_answer += response.answer
return full_answer

# Run the async function in the event loop
loop = asyncio.get_event_loop()
full_answer = loop.run_until_complete(collect_streamed_response())

chat_history = self.default_chat if chat_history is None else chat_history
list_files = [] if list_files is None else list_files

parsed_response = rag_instance.answer(
question=question, history=chat_history, list_files=list_files
)

chat_history.append(HumanMessage(content=question))
chat_history.append(AIMessage(content=parsed_response.answer))
chat_history.append(AIMessage(content=full_answer))

# Save answer to the chat history
return parsed_response
# Return the final response
return ParsedRAGResponse(answer=full_answer)

async def ask_streaming(
self,
@@ -635,3 +629,4 @@ async def ask_streaming(
chat_history.append(HumanMessage(content=question))
chat_history.append(AIMessage(content=full_answer))
yield response

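The heart of this commit: the synchronous `ask` no longer builds its own RAG pipeline and calls `rag_instance.answer`; it now drives `ask_streaming` to completion, concatenates the streamed chunks, and wraps the result in a `ParsedRAGResponse`. Below is a minimal, self-contained sketch of that "non-streaming calls streaming" pattern, using hypothetical `Assistant` and `Chunk` names rather than the actual quivr_core classes:

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator


@dataclass
class Chunk:
    answer: str


class Assistant:
    async def ask_streaming(self, question: str) -> AsyncIterator[Chunk]:
        # Stand-in for the real streaming pipeline: yields partial answers.
        for token in ["Hello", ", ", "world", "!"]:
            yield Chunk(answer=token)

    def ask(self, question: str) -> str:
        # Non-streaming entry point: drain the streaming generator and
        # return the concatenated answer, mirroring the change in ask().
        async def collect() -> str:
            full_answer = ""
            async for chunk in self.ask_streaming(question):
                full_answer += chunk.answer
            return full_answer

        # The commit uses get_event_loop().run_until_complete; asyncio.run
        # is the usual alternative when no loop is already running.
        return asyncio.get_event_loop().run_until_complete(collect())


print(Assistant().ask("say hi"))  # -> "Hello, world!"
```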
4 changes: 2 additions & 2 deletions core/quivr_core/config.py
@@ -233,7 +233,7 @@ class LLMEndpointConfig(QuivrBaseConfig):
Attributes:
supplier (DefaultModelSuppliers): The LLM provider (default: OPENAI).
model (str): The specific model to use (default: "gpt-3.5-turbo-0125").
model (str): The specific model to use (default: "gpt-4o").
context_length (int | None): The maximum context length for the model.
tokenizer_hub (str | None): The tokenizer to use for this model.
llm_base_url (str | None): Base URL for the LLM API.
@@ -247,7 +247,7 @@
"""

supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
model: str = "gpt-3.5-turbo-0125"
model: str = "gpt-4o"
context_length: int | None = None
tokenizer_hub: str | None = None
llm_base_url: str | None = None
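Because the library default moves from `gpt-3.5-turbo-0125` to `gpt-4o`, callers that want the old behaviour can keep pinning the model explicitly. A small sketch, assuming `LLMEndpointConfig` is importable from `quivr_core.config` as the changed file and tests suggest:

```python
from quivr_core.config import LLMEndpointConfig

# Rely on the new default ("gpt-4o")...
default_config = LLMEndpointConfig()

# ...or pin the previous model explicitly if you still want gpt-3.5.
pinned_config = LLMEndpointConfig(model="gpt-3.5-turbo-0125")

print(default_config.model, pinned_config.model)
```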
2 changes: 1 addition & 1 deletion core/quivr_core/utils.py
@@ -26,7 +26,7 @@ def model_supports_function_calling(model_name: str):
"gpt-4",
"gpt-4-1106-preview",
"gpt-4-0613",
"gpt-3.5-turbo-0125",
"gpt-4o",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
"gpt-4-0125-preview",
2 changes: 1 addition & 1 deletion core/tests/rag_config.yaml
@@ -27,7 +27,7 @@ retrieval_config:
supplier: "openai"

# The model to use for the LLM for the given supplier
model: "gpt-3.5-turbo-0125"
model: "gpt-4o"

max_input_tokens: 2000

2 changes: 1 addition & 1 deletion core/tests/rag_config_workflow.yaml
@@ -40,7 +40,7 @@ retrieval_config:
supplier: "openai"

# The model to use for the LLM for the given supplier
model: "gpt-3.5-turbo-0125"
model: "gpt-4o"

max_input_tokens: 2000

2 changes: 1 addition & 1 deletion core/tests/test_config.py
@@ -5,7 +5,7 @@ def test_default_llm_config():
config = LLMEndpointConfig()

assert config.model_dump(exclude={"llm_api_key"}) == LLMEndpointConfig(
model="gpt-3.5-turbo-0125",
model="gpt-4o",
llm_base_url=None,
llm_api_key=None,
max_input_tokens=2000,