feat: Add RAG pipeline #6461
Merged
Commits (12):
fdad616  add rag pipeline (ZanSara)
a4dec8b  pipeline_utils (ZanSara)
ef01154  typo (ZanSara)
c69c7b6  reno (ZanSara)
5985f93  Merge branch 'main' into rag_pipeline (ZanSara)
a1e5db7  Merge branch 'main' into rag_pipeline (vblagoje)
6ce1f2d  Pydoc spelling (vblagoje)
77298ea  Add getting started example, use given document_store reference (vblagoje)
ff182de  Small fixes (vblagoje)
18cbccf  Make setting OPENAI_API_KEY obvious (vblagoje)
a244094  Simpler api key setup (vblagoje)
2fa421e  Update examples/getting_started/rag.py (vblagoje)
examples/getting_started/rag.py (new file)
@@ -0,0 +1,22 @@
import os
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipeline_utils import build_rag_pipeline

API_KEY = None  # SET YOUR OPENAI API KEY HERE

# We support many different databases. Here we load a simple and lightweight in-memory document store.
document_store = InMemoryDocumentStore()

# Create some example documents and add them to the document store.
documents = [
    Document(content="My name is Jean and I live in Paris."),
    Document(content="My name is Mark and I live in Berlin."),
    Document(content="My name is Giorgio and I live in Rome."),
]
document_store.write_documents(documents)

# Let's now build a simple RAG pipeline that uses a generative model to answer questions.
rag_pipeline = build_rag_pipeline(llm_api_key=API_KEY, document_store=document_store)
answers = rag_pipeline.run(query="Who lives in Rome?")
print(answers.data)
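As a side note for reviewers: the `prompt_template` parameter (defined in `haystack/pipeline_utils/rag.py` below) lets callers swap out the default prompt. A minimal sketch reusing the store and `API_KEY` from the example above; the template wording here is illustrative, not from the PR:

```python
# Sketch (not part of the PR): the same getting-started pipeline with a custom
# Jinja2 prompt template. The variables `documents` and `question` match those
# used by the default template in haystack/pipeline_utils/rag.py below.
template = """
Answer the question using only these documents. Reply with a single name.

Documents:
{% for doc in documents %}
    {{ doc.content }}
{% endfor %}

Question: {{question}}

Answer:
"""

rag_pipeline = build_rag_pipeline(
    llm_api_key=API_KEY,
    document_store=document_store,
    prompt_template=template,
)
print(rag_pipeline.run(query="Who lives in Berlin?").data)
```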
haystack/pipeline_utils/__init__.py (new file)
@@ -0,0 +1,3 @@
from haystack.pipeline_utils.rag import build_rag_pipeline

__all__ = ["build_rag_pipeline"]
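This re-export means both import paths below resolve to the same function; the getting-started example above uses the shorter one. A quick illustration:

```python
# Both paths expose the same object, thanks to the re-export above.
from haystack.pipeline_utils import build_rag_pipeline as public_fn
from haystack.pipeline_utils.rag import build_rag_pipeline as module_fn

assert public_fn is module_fn
```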
haystack/pipeline_utils/rag.py (new file)
@@ -0,0 +1,134 @@
from typing import Optional

from haystack import Pipeline
from haystack.dataclasses import Answer
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import GPTGenerator
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder


def build_rag_pipeline(
    document_store: "InMemoryDocumentStore",
    generation_model: str = "gpt-3.5-turbo",
    prompt_template: Optional[str] = None,
    embedding_model: Optional[str] = None,
    llm_api_key: Optional[str] = None,
):
""" | ||
Returns a prebuilt pipeline to perform retrieval augmented generation with or without an embedding model | ||
(without embeddings, it performs retrieval using BM25). | ||
|
||
Example usage: | ||
|
||
```python | ||
from haystack.utils import build_rag_pipeline | ||
pipeline = build_rag_pipeline(document_store=your_document_store_instance) | ||
pipeline.run(query="What's the capital of France?") | ||
|
||
>>> Answer(data="The capital of France is Paris.") | ||
``` | ||
|
||
:param document_store: An instance of a DocumentStore to read from. | ||
:param generation_model: The name of the model to use for generation. | ||
:param prompt_template: The template to use for the prompt. If not given, a default template is used. | ||
:param embedding_model: The name of the model to use for embedding. If not given, BM25 is used. | ||
:param llm_api_key: The API key to use for the OpenAI Language Model. If not given, the value of the | ||
llm_api_key will be attempted to be read from the environment variable OPENAI_API_KEY. | ||
""" | ||
return _RAGPipeline( | ||
document_store=document_store, | ||
generation_model=generation_model, | ||
prompt_template=prompt_template, | ||
embedding_model=embedding_model, | ||
llm_api_key=llm_api_key, | ||
) | ||


class _RAGPipeline:
    """
    A simple ready-made pipeline for RAG. It requires a populated document store.

    If an embedding model is given, it uses embedding retrieval. Otherwise, it falls back to BM25 retrieval.

    Example usage:

    ```python
    rag_pipe = _RAGPipeline(document_store=InMemoryDocumentStore())
    answers = rag_pipe.run(query="Who lives in Rome?")
    >>> Answer(data="Giorgio")
    ```
    """

    def __init__(
        self,
        document_store: InMemoryDocumentStore,
        generation_model: str = "gpt-3.5-turbo",
        prompt_template: Optional[str] = None,
        embedding_model: Optional[str] = None,
        llm_api_key: Optional[str] = None,
    ):
        """
        :param document_store: An instance of a DocumentStore to retrieve documents from.
        :param generation_model: The name of the model to use for generation.
        :param prompt_template: The template to use for the prompt. If not given, a default template is used.
        :param embedding_model: The name of the model to use for embedding. If not given, BM25 is used.
        :param llm_api_key: The API key to use for the OpenAI Language Model.
        """
        prompt_template = (
            prompt_template
            or """
        Given these documents, answer the question.

        Documents:
        {% for doc in documents %}
            {{ doc.content }}
        {% endfor %}

        Question: {{question}}

        Answer:
        """
        )
        if not isinstance(document_store, InMemoryDocumentStore):
            raise ValueError("RAGPipeline only works with an InMemoryDocumentStore.")

        self.pipeline = Pipeline()

        if embedding_model:
            self.pipeline.add_component(
                instance=SentenceTransformersTextEmbedder(model_name_or_path=embedding_model), name="text_embedder"
            )
            self.pipeline.add_component(
                instance=InMemoryEmbeddingRetriever(document_store=document_store), name="retriever"
            )
            self.pipeline.connect("text_embedder", "retriever")
        else:
            self.pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")

        self.pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
        self.pipeline.add_component(instance=GPTGenerator(api_key=llm_api_key, model_name=generation_model), name="llm")
        self.pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
        self.pipeline.connect("retriever", "prompt_builder.documents")
        self.pipeline.connect("prompt_builder.prompt", "llm.prompt")
        self.pipeline.connect("llm.replies", "answer_builder.replies")
        self.pipeline.connect("llm.metadata", "answer_builder.metadata")
        self.pipeline.connect("retriever", "answer_builder.documents")

    def run(self, query: str) -> Answer:
        """
        Performs RAG using the given query.

        :param query: The query to ask.
        :return: An Answer object.
        """
        run_values = {"prompt_builder": {"question": query}, "answer_builder": {"query": query}}
        if self.pipeline.graph.nodes.get("text_embedder"):
            run_values["text_embedder"] = {"text": query}
        else:
            run_values["retriever"] = {"query": query}

        return self.pipeline.run(run_values)["answer_builder"]["answers"][0]
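For reviewers who prefer the wiring spelled out: a minimal sketch of the BM25 branch that `_RAGPipeline.__init__` assembles, written as a flat script. Component names, connections, template, and run values all mirror the code above; the query string is illustrative, and this is not part of the PR:

```python
# Sketch (not part of the PR): the BM25 branch of _RAGPipeline, wired by hand.
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.generators import GPTGenerator
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder

# The default template from _RAGPipeline above.
template = """
Given these documents, answer the question.

Documents:
{% for doc in documents %}
    {{ doc.content }}
{% endfor %}

Question: {{question}}

Answer:
"""

document_store = InMemoryDocumentStore()  # assume it is already populated

pipeline = Pipeline()
pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")
pipeline.add_component(instance=PromptBuilder(template=template), name="prompt_builder")
pipeline.add_component(instance=GPTGenerator(api_key=None, model_name="gpt-3.5-turbo"), name="llm")
pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")

# Documents flow retriever -> prompt_builder -> llm -> answer_builder, and the
# retrieved documents are also attached to the final Answer.
pipeline.connect("retriever", "prompt_builder.documents")
pipeline.connect("prompt_builder.prompt", "llm.prompt")
pipeline.connect("llm.replies", "answer_builder.replies")
pipeline.connect("llm.metadata", "answer_builder.metadata")
pipeline.connect("retriever", "answer_builder.documents")

# Equivalent of _RAGPipeline.run(query) in the no-embedder case:
query = "Who lives in Rome?"
result = pipeline.run(
    {
        "retriever": {"query": query},
        "prompt_builder": {"question": query},
        "answer_builder": {"query": query},
    }
)
print(result["answer_builder"]["answers"][0])
```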
Release note (reno, new file)
@@ -0,0 +1,4 @@
---
features:
  - |
    Add a `build_rag_pipeline` utility function
Tests for build_rag_pipeline (new file)
@@ -0,0 +1,61 @@
from unittest.mock import patch, Mock
import pytest

from haystack.dataclasses import Answer
from haystack.testing.factory import document_store_class
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipeline_utils.rag import build_rag_pipeline


@pytest.fixture
def mock_chat_completion():
    """
    Mock the OpenAI API completion response and reuse it for tests
    """
    with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create:
        # mimic the response from the OpenAI API
        mock_choice = Mock()
        mock_choice.index = 0
        mock_choice.finish_reason = "stop"

        mock_message = Mock()
        mock_message.content = "I'm fine, thanks. How are you?"
        mock_message.role = "user"

        mock_choice.message = mock_message

        mock_response = Mock()
        mock_response.model = "gpt-3.5-turbo"
        mock_response.usage = Mock()
        mock_response.usage.items.return_value = [
            ("prompt_tokens", 57),
            ("completion_tokens", 40),
            ("total_tokens", 97),
        ]
        mock_response.choices = [mock_choice]
        mock_chat_completion_create.return_value = mock_response
        yield mock_chat_completion_create


def test_rag_pipeline(mock_chat_completion):
    rag_pipe = build_rag_pipeline(document_store=InMemoryDocumentStore())
    answer = rag_pipe.run(query="question")
    assert isinstance(answer, Answer)


def test_rag_pipeline_other_docstore():
    FakeStore = document_store_class("FakeStore")
    with pytest.raises(ValueError, match="InMemoryDocumentStore"):
        assert build_rag_pipeline(document_store=FakeStore())


def test_rag_pipeline_no_embedder_if_no_model():
    rag_pipe = build_rag_pipeline(document_store=InMemoryDocumentStore())
    assert "text_embedder" not in rag_pipe.pipeline.graph.nodes


def test_rag_pipeline_embedder_exist_if_model_is_given():
    rag_pipe = build_rag_pipeline(
        document_store=InMemoryDocumentStore(), embedding_model="sentence-transformers/all-mpnet-base-v2"
    )
    assert "text_embedder" in rag_pipe.pipeline.graph.nodes
Review comment:
Ok, I was thinking that if they happen to have OPENAI_API_KEY set, this would work out of the box. But yes, I'll go with whatever you all think is better.