diff --git a/griptape/engines/rag/modules/__init__.py b/griptape/engines/rag/modules/__init__.py index eb514331d..cb11551e3 100644 --- a/griptape/engines/rag/modules/__init__.py +++ b/griptape/engines/rag/modules/__init__.py @@ -25,5 +25,5 @@ "PromptResponseRagModule", "RulesetsBeforeResponseRagModule", "MetadataBeforeResponseRagModule", - "TextChunksResponseRagModule" + "TextChunksResponseRagModule", ] diff --git a/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py b/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py index 1a80fecbb..8a5f2cab4 100644 --- a/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py +++ b/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py @@ -14,8 +14,6 @@ def run(self, context: RagContext) -> RagContext: metadata = self.metadata if context_metadata is None else context_metadata if metadata is not None: - context.before_query.append( - J2("engines/rag/modules/response/metadata/system.j2").render(metadata=metadata) - ) + context.before_query.append(J2("engines/rag/modules/response/metadata/system.j2").render(metadata=metadata)) return context diff --git a/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py b/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py index 5ac5b923c..699aa1124 100644 --- a/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py +++ b/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py @@ -6,7 +6,6 @@ @define(kw_only=True) class TextChunksResponseRagModule(BaseResponseRagModule): - def run(self, context: RagContext) -> RagContext: context.output = ListArtifact(context.text_chunks) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 82152e3f5..9c94f9347 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence, Any, Callable from attrs import define, field, Factory -from griptape.artifacts import TextArtifact, BaseArtifact +from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import BaseRetrievalRagModule @@ -13,7 +13,7 @@ class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): vector_store_driver: BaseVectorStoreDriver = field() query_params: dict[str, Any] = field(factory=dict) - process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], BaseArtifact] = field( + process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]) ) diff --git a/griptape/engines/rag/rag_context.py b/griptape/engines/rag/rag_context.py index 211016e68..ce0d1adc6 100644 --- a/griptape/engines/rag/rag_context.py +++ b/griptape/engines/rag/rag_context.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Optional from attrs import define, field from griptape.artifacts import TextArtifact, BaseArtifact diff --git a/griptape/engines/rag/stages/base_rag_stage.py b/griptape/engines/rag/stages/base_rag_stage.py index bf1f5aa2e..4db996e30 100644 --- a/griptape/engines/rag/stages/base_rag_stage.py +++ b/griptape/engines/rag/stages/base_rag_stage.py @@ -17,5 +17,4 @@ def run(self, context: RagContext) -> RagContext: ... @property @abstractmethod - def modules(self) -> list[BaseRagModule]: - ... + def modules(self) -> list[BaseRagModule]: ... diff --git a/griptape/engines/rag/stages/query_rag_stage.py b/griptape/engines/rag/stages/query_rag_stage.py index 9c45822fe..84c200bc8 100644 --- a/griptape/engines/rag/stages/query_rag_stage.py +++ b/griptape/engines/rag/stages/query_rag_stage.py @@ -12,14 +12,12 @@ class QueryRagStage(BaseRagStage): @property def modules(self) -> list[BaseRagModule]: - return self.query_modules + return self.query_modules # pyright: ignore def run(self, context: RagContext) -> RagContext: logging.info(f"QueryStage: running {len(self.query_modules)} query generation modules in parallel") with self.futures_executor_fn() as executor: - utils.execute_futures_list( - [executor.submit(r.run, context) for r in self.query_modules] - ) + utils.execute_futures_list([executor.submit(r.run, context) for r in self.query_modules]) return context diff --git a/griptape/engines/rag/stages/response_rag_stage.py b/griptape/engines/rag/stages/response_rag_stage.py index c84d52b8c..aa0a15dfc 100644 --- a/griptape/engines/rag/stages/response_rag_stage.py +++ b/griptape/engines/rag/stages/response_rag_stage.py @@ -4,7 +4,8 @@ from griptape.engines.rag.modules import ( BaseResponseRagModule, BaseBeforeResponseRagModule, - BaseAfterResponseRagModule, BaseRagModule, + BaseAfterResponseRagModule, + BaseRagModule, ) from griptape.engines.rag.stages import BaseRagStage diff --git a/griptape/memory/task/storage/base_artifact_storage.py b/griptape/memory/task/storage/base_artifact_storage.py index fbd226363..e8378fee4 100644 --- a/griptape/memory/task/storage/base_artifact_storage.py +++ b/griptape/memory/task/storage/base_artifact_storage.py @@ -20,4 +20,4 @@ def can_store(self, artifact: BaseArtifact) -> bool: ... def summarize(self, namespace: str) -> TextArtifact | InfoArtifact: ... @abstractmethod - def query(self, namespace: str, query: str, metadata: Any = None) -> TextArtifact | InfoArtifact: ... + def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: ... diff --git a/griptape/memory/task/storage/blob_artifact_storage.py b/griptape/memory/task/storage/blob_artifact_storage.py index 79b5798df..9d09e17fa 100644 --- a/griptape/memory/task/storage/blob_artifact_storage.py +++ b/griptape/memory/task/storage/blob_artifact_storage.py @@ -26,5 +26,5 @@ def load_artifacts(self, namespace: str) -> ListArtifact: def summarize(self, namespace: str) -> InfoArtifact: return InfoArtifact("can't summarize artifacts") - def query(self, namespace: str, query: str, metadata: Any = None) -> InfoArtifact: + def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: return InfoArtifact("can't query artifacts") diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 0ccc2bdfa..e7df5555e 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -42,7 +42,7 @@ def summarize(self, namespace: str) -> TextArtifact: return self.summary_engine.summarize_artifacts(self.load_artifacts(namespace)) - def query(self, namespace: str, query: str, metadata: Any = None) -> TextArtifact | InfoArtifact: + def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: if self.rag_engine is None: raise ValueError("RAG engine is not set.") @@ -50,13 +50,13 @@ def query(self, namespace: str, query: str, metadata: Any = None) -> TextArtifac RagContext( query=query, module_params={ - self.retrieval_rag_module_name: { + self.retrieval_rag_module_name: { # pyright: ignore "query_params": { "namespace": namespace, - "metadata": None if metadata is None else str(metadata) + "metadata": None if metadata is None else str(metadata), } } - } + }, ) ).output diff --git a/griptape/memory/task/task_memory.py b/griptape/memory/task/task_memory.py index 2f1fdbe16..df9c9be30 100644 --- a/griptape/memory/task/task_memory.py +++ b/griptape/memory/task/task_memory.py @@ -124,7 +124,7 @@ def summarize_namespace(self, namespace: str) -> TextArtifact | InfoArtifact: else: return InfoArtifact("Can't find memory content") - def query_namespace(self, namespace: str, query: str) -> TextArtifact | InfoArtifact: + def query_namespace(self, namespace: str, query: str) -> BaseArtifact: storage = self.namespace_storage.get(namespace) if storage: diff --git a/griptape/tasks/rag_task.py b/griptape/tasks/rag_task.py index d50e881c1..e85c8f28a 100644 --- a/griptape/tasks/rag_task.py +++ b/griptape/tasks/rag_task.py @@ -1,6 +1,6 @@ from __future__ import annotations from attrs import define, field -from griptape.artifacts import TextArtifact, ErrorArtifact +from griptape.artifacts import ErrorArtifact, BaseArtifact from griptape.engines.rag import RagEngine from griptape.tasks import BaseTextInputTask @@ -22,7 +22,7 @@ def rag_engine(self) -> RagEngine: def rag_engine(self, value: RagEngine) -> None: self._rag_engine = value - def run(self) -> TextArtifact | ErrorArtifact: + def run(self) -> BaseArtifact: result = self.rag_engine.process_query(self.input.to_text()).output if result is None: diff --git a/griptape/tools/task_memory_client/tool.py b/griptape/tools/task_memory_client/tool.py index b60c0acc0..ce89da22e 100644 --- a/griptape/tools/task_memory_client/tool.py +++ b/griptape/tools/task_memory_client/tool.py @@ -1,7 +1,7 @@ from __future__ import annotations from attrs import define from schema import Schema, Literal -from griptape.artifacts import TextArtifact, ErrorArtifact, InfoArtifact +from griptape.artifacts import TextArtifact, ErrorArtifact, InfoArtifact, BaseArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -39,7 +39,7 @@ def summarize(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact ), } ) - def query(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact: + def query(self, params: dict) -> BaseArtifact: memory = self.find_input_memory(params["values"]["memory_name"]) artifact_namespace = params["values"]["artifact_namespace"] query = params["values"]["query"] diff --git a/tests/mocks/mock_rag_module.py b/tests/mocks/mock_rag_module.py index c2c5e9d59..98b69cee8 100644 --- a/tests/mocks/mock_rag_module.py +++ b/tests/mocks/mock_rag_module.py @@ -1,5 +1,4 @@ from griptape.engines.rag.modules import BaseRagModule -class MockRagModule(BaseRagModule): - ... +class MockRagModule(BaseRagModule): ... diff --git a/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py index 12ff8321b..43927bd8a 100644 --- a/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py @@ -6,19 +6,14 @@ class TestMetadataBeforeResponseRagModule: def test_run(self): module = MetadataBeforeResponseRagModule(name="foo") - assert "foo" in module.run( - RagContext( - module_params={"foo": {"metadata": "foo"}}, - query="test" - ) - ).before_query[0] + assert "foo" in module.run(RagContext(module_params={"foo": {"metadata": "foo"}}, query="test")).before_query[0] def test_run_with_override(self): module = MetadataBeforeResponseRagModule(name="foo", metadata="bar") - assert "bar" in module.run( - RagContext( - module_params={"foo": {"query_params": {"metadata": "foo"}}}, - query="test" - ) - ).before_query[0] + assert ( + "bar" + in module.run( + RagContext(module_params={"foo": {"query_params": {"metadata": "foo"}}}, query="test") + ).before_query[0] + ) diff --git a/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py index bade567ca..3a01b3eaa 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py @@ -8,18 +8,12 @@ class TestVectorStoreRetrievalRagModule: def test_run_without_namespace(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - module = VectorStoreRetrievalRagModule( - vector_store_driver=vector_store_driver - ) + module = VectorStoreRetrievalRagModule(vector_store_driver=vector_store_driver) vector_store_driver.upsert_text_artifact(TextArtifact("foobar1"), namespace="test") vector_store_driver.upsert_text_artifact(TextArtifact("foobar2"), namespace="test") - result = module.run( - RagContext( - query="test", - ) - ) + result = module.run(RagContext(query="test")) assert len(result) == 2 assert result[0].value == "foobar1" @@ -28,18 +22,13 @@ def test_run_without_namespace(self): def test_run_with_namespace(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) module = VectorStoreRetrievalRagModule( - vector_store_driver=vector_store_driver, - query_params={"namespace": "test"} + vector_store_driver=vector_store_driver, query_params={"namespace": "test"} ) vector_store_driver.upsert_text_artifact(TextArtifact("foobar1"), namespace="test") vector_store_driver.upsert_text_artifact(TextArtifact("foobar2"), namespace="test") - result = module.run( - RagContext( - query="test", - ) - ) + result = module.run(RagContext(query="test")) assert len(result) == 2 assert result[0].value == "foobar1" @@ -48,8 +37,7 @@ def test_run_with_namespace(self): def test_run_with_namespace_overrides(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) module = VectorStoreRetrievalRagModule( - vector_store_driver=vector_store_driver, - query_params={"namespace": "test"} + vector_store_driver=vector_store_driver, query_params={"namespace": "test"} ) vector_store_driver.upsert_text_artifact(TextArtifact("foobar1"), namespace="test") @@ -57,17 +45,15 @@ def test_run_with_namespace_overrides(self): result1 = module.run( RagContext( - query="test", - module_params={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "empty"}}} + query="test", module_params={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "empty"}}} ) ) result2 = module.run( RagContext( - query="test", - module_params={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "test"}}} + query="test", module_params={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "test"}}} ) ) assert len(result1) == 0 - assert len(result2) == 2 \ No newline at end of file + assert len(result2) == 2 diff --git a/tests/unit/engines/rag/test_rag_engine.py b/tests/unit/engines/rag/test_rag_engine.py index 7df6582d6..a39c0c2f1 100644 --- a/tests/unit/engines/rag/test_rag_engine.py +++ b/tests/unit/engines/rag/test_rag_engine.py @@ -18,9 +18,7 @@ def engine(self): ) ] ), - response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver()) - ), + response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver())), ) def test_module_name_uniqueness(self): @@ -30,14 +28,8 @@ def test_module_name_uniqueness(self): RagEngine( retrieval_stage=RetrievalRagStage( retrieval_modules=[ - VectorStoreRetrievalRagModule( - name="test", - vector_store_driver=vector_store_driver - ), - VectorStoreRetrievalRagModule( - name="test", - vector_store_driver=vector_store_driver - ) + VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver), + VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver), ] ) ) @@ -45,14 +37,8 @@ def test_module_name_uniqueness(self): assert RagEngine( retrieval_stage=RetrievalRagStage( retrieval_modules=[ - VectorStoreRetrievalRagModule( - name="test1", - vector_store_driver=vector_store_driver - ), - VectorStoreRetrievalRagModule( - name="test2", - vector_store_driver=vector_store_driver - ) + VectorStoreRetrievalRagModule(name="test1", vector_store_driver=vector_store_driver), + VectorStoreRetrievalRagModule(name="test2", vector_store_driver=vector_store_driver), ] ) )