Skip to content

Commit

Permalink
Rename TextRerankRagModule to TextChunkRerankRagModule
Browse files Browse the repository at this point in the history
  • Loading branch information
vasinov committed Jul 2, 2024
1 parent d22d1bf commit ed166c3
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- RAG modules:
- Retrieval:
- `VectorStoreRetrievalRagModule` for retrieving text chunks from vector stores.
- `TextRerankRagModule` for re-ranking retrieved results.
- `TextChunksRerankRagModule` for re-ranking retrieved results.
- Response:
- `MetadataBeforeResponseRagModule` for appending metadata.
- `RulesetsBeforeResponseRagModule` for appending rulesets.
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/engines/rag-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ No modules implemented yet.

#### Retrieval
- `TextRetrievalRagModule` is for retrieving text chunks.
- `TextRerankRagModule` is for re-ranking retrieved results.
- `TextChunksRerankRagModule` is for re-ranking retrieved results.

#### Response
- `MetadataBeforeResponseRagModule` is for appending metadata.
Expand Down
4 changes: 2 additions & 2 deletions griptape/engines/rag/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .query.base_query_rag_module import BaseQueryRagModule
from .retrieval.base_retrieval_rag_module import BaseRetrievalRagModule
from .retrieval.base_rerank_rag_module import BaseRerankRagModule
from .retrieval.text_rerank_rag_module import TextRerankRagModule
from .retrieval.text_chunks_rerank_rag_module import TextChunksRerankRagModule
from .retrieval.vector_store_retrieval_rag_module import VectorStoreRetrievalRagModule
from .response.base_before_response_rag_module import BaseBeforeResponseRagModule
from .response.base_after_response_rag_module import BaseAfterResponseRagModule
Expand All @@ -17,7 +17,7 @@
"BaseQueryRagModule",
"BaseRetrievalRagModule",
"BaseRerankRagModule",
"TextRerankRagModule",
"TextChunksRerankRagModule",
"VectorStoreRetrievalRagModule",
"BaseBeforeResponseRagModule",
"BaseAfterResponseRagModule",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@


@define(kw_only=True)
class TextRerankRagModule(BaseRerankRagModule):
class TextChunksRerankRagModule(BaseRerankRagModule):
def run(self, context: RagContext) -> Sequence[BaseArtifact]:
return self.rerank_driver.run(context.query, context.text_chunks)
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import pytest
from griptape.drivers import CohereRerankDriver
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import TextRerankRagModule
from griptape.engines.rag.modules import TextChunksRerankRagModule


class TestTextRerankRagModule:
class TestTextChunksRerankRagModule:
@pytest.fixture
def mock_client(self, mocker):
mock_client = mocker.patch("cohere.Client").return_value
Expand All @@ -14,7 +14,7 @@ def mock_client(self, mocker):
return mock_client

def test_run(self, mock_client):
module = TextRerankRagModule(rerank_driver=CohereRerankDriver(api_key="api-key"))
module = TextChunksRerankRagModule(rerank_driver=CohereRerankDriver(api_key="api-key"))
result = module.run(RagContext(query="test"))

assert len(result) == 2

0 comments on commit ed166c3

Please sign in to comment.