diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e016228c..e7c2f70be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Method `try_find_task` to `Structure`. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. +- Unique name generation for all `RagEngine` modules. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. +- **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly. +- **BREAKING**: Removed before and after response modules from `ResponseRagStage`. +- **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index 8b0adfb9a..b5d2b0a01 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -40,7 +40,7 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ) diff --git a/docs/examples/src/talk_to_a_pdf_1.py b/docs/examples/src/talk_to_a_pdf_1.py index ee309cba2..2ac184a22 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -22,7 +22,7 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ) vector_store_tool = RagClient( diff --git a/docs/examples/src/talk_to_a_webpage_1.py b/docs/examples/src/talk_to_a_webpage_1.py index 76638113f..d24eb9427 100644 --- a/docs/examples/src/talk_to_a_webpage_1.py +++ b/docs/examples/src/talk_to_a_webpage_1.py @@ -22,7 +22,7 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ) diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index 435a5d470..c257cd4df 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -4,19 +4,24 @@ from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule from griptape.engines.rag.stages import QueryRagStage, ResponseRagStage, RetrievalRagStage from griptape.loaders import WebLoader +from griptape.rules import Rule, Ruleset prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) - artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai") + if isinstance(artifacts, ErrorArtifact): - raise ValueError(artifacts.value) + raise Exception(artifacts.value) -vector_store.upsert_text_artifacts({"griptape": artifacts}) +vector_store.upsert_text_artifacts( + { + "griptape": artifacts, + } +) rag_engine = RagEngine( - query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="English")]), + query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="english")]), retrieval_stage=RetrievalRagStage( max_chunks=5, retrieval_modules=[ @@ -25,7 +30,13 @@ ) ], ), - response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=prompt_driver)), + response_stage=ResponseRagStage( + response_modules=[ + PromptResponseRagModule( + prompt_driver=prompt_driver, rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])] + ) + ] + ), ) rag_context = RagContext( @@ -33,4 +44,4 @@ module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}}, ) -print(rag_engine.process(rag_context).output.to_text()) +print(rag_engine.process(rag_context).outputs[0].to_text()) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 70c3bde55..7bbc5614a 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -34,7 +34,7 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ), retrieval_rag_module_name="VectorStoreRetrievalRagModule", diff --git a/docs/griptape-framework/structures/src/tasks_9.py b/docs/griptape-framework/structures/src/tasks_9.py index 6fca66cc5..1033f1b2f 100644 --- a/docs/griptape-framework/structures/src/tasks_9.py +++ b/docs/griptape-framework/structures/src/tasks_9.py @@ -29,7 +29,7 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ), ) diff --git a/docs/griptape-tools/official-tools/src/rag_client_1.py b/docs/griptape-tools/official-tools/src/rag_client_1.py index 8751e80b1..01e71e253 100644 --- a/docs/griptape-tools/official-tools/src/rag_client_1.py +++ b/docs/griptape-tools/official-tools/src/rag_client_1.py @@ -27,7 +27,7 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ), ) diff --git a/griptape/engines/rag/modules/__init__.py b/griptape/engines/rag/modules/__init__.py index be66082f0..ace9f4a3b 100644 --- a/griptape/engines/rag/modules/__init__.py +++ b/griptape/engines/rag/modules/__init__.py @@ -10,8 +10,6 @@ from .response.base_after_response_rag_module import BaseAfterResponseRagModule from .response.base_response_rag_module import BaseResponseRagModule from .response.prompt_response_rag_module import PromptResponseRagModule -from .response.rulesets_before_response_rag_module import RulesetsBeforeResponseRagModule -from .response.metadata_before_response_rag_module import MetadataBeforeResponseRagModule from .response.text_chunks_response_rag_module import TextChunksResponseRagModule from .response.footnote_prompt_response_rag_module import FootnotePromptResponseRagModule @@ -28,8 +26,6 @@ "BaseAfterResponseRagModule", "BaseResponseRagModule", "PromptResponseRagModule", - "RulesetsBeforeResponseRagModule", - "MetadataBeforeResponseRagModule", "TextChunksResponseRagModule", "FootnotePromptResponseRagModule", ] diff --git a/griptape/engines/rag/modules/base_rag_module.py b/griptape/engines/rag/modules/base_rag_module.py index 829a24565..01d6d1b1d 100644 --- a/griptape/engines/rag/modules/base_rag_module.py +++ b/griptape/engines/rag/modules/base_rag_module.py @@ -1,5 +1,6 @@ from __future__ import annotations +import uuid from abc import ABC from concurrent import futures from typing import TYPE_CHECKING, Any, Callable, Optional @@ -14,7 +15,9 @@ @define(kw_only=True) class BaseRagModule(ABC): - name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) + name: str = field( + default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True + ) futures_executor_fn: Callable[[], futures.Executor] = field( default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), ) diff --git a/griptape/engines/rag/modules/response/base_response_rag_module.py b/griptape/engines/rag/modules/response/base_response_rag_module.py index 30ab82201..1bd3ddeb7 100644 --- a/griptape/engines/rag/modules/response/base_response_rag_module.py +++ b/griptape/engines/rag/modules/response/base_response_rag_module.py @@ -2,6 +2,7 @@ from attrs import define +from griptape.artifacts import BaseArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import BaseRagModule @@ -9,4 +10,4 @@ @define(kw_only=True) class BaseResponseRagModule(BaseRagModule, ABC): @abstractmethod - def run(self, context: RagContext) -> RagContext: ... + def run(self, context: RagContext) -> BaseArtifact: ... diff --git a/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py b/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py deleted file mode 100644 index d2d546213..000000000 --- a/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -from attrs import define, field - -from griptape.engines.rag.modules import BaseBeforeResponseRagModule -from griptape.utils import J2 - -if TYPE_CHECKING: - from griptape.engines.rag import RagContext - - -@define(kw_only=True) -class MetadataBeforeResponseRagModule(BaseBeforeResponseRagModule): - metadata: Optional[str] = field(default=None) - - def run(self, context: RagContext) -> RagContext: - context_metadata = self.get_context_param(context, "metadata") - metadata = self.metadata if context_metadata is None else context_metadata - - if metadata is not None: - context.before_query.append(J2("engines/rag/modules/response/metadata/system.j2").render(metadata=metadata)) - - return context diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 1cdb9a6f0..99bdf5f5e 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -1,27 +1,30 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional from attrs import Factory, define, field from griptape.artifacts.text_artifact import TextArtifact from griptape.engines.rag.modules import BaseResponseRagModule +from griptape.mixins import RuleMixin from griptape.utils import J2 if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact from griptape.drivers import BasePromptDriver from griptape.engines.rag import RagContext @define(kw_only=True) -class PromptResponseRagModule(BaseResponseRagModule): - answer_token_offset: int = field(default=400) +class PromptResponseRagModule(BaseResponseRagModule, RuleMixin): prompt_driver: BasePromptDriver = field() + answer_token_offset: int = field(default=400) + metadata: Optional[str] = field(default=None) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) - def run(self, context: RagContext) -> RagContext: + def run(self, context: RagContext) -> BaseArtifact: query = context.query tokenizer = self.prompt_driver.tokenizer included_chunks = [] @@ -45,15 +48,17 @@ def run(self, context: RagContext) -> RagContext: output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact() if isinstance(output, TextArtifact): - context.output = output + return output else: raise ValueError("Prompt driver did not return a TextArtifact") - return context - def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: - return J2("engines/rag/modules/response/prompt/system.j2").render( - text_chunks=[c.to_text() for c in artifacts], - before_system_prompt="\n\n".join(context.before_query), - after_system_prompt="\n\n".join(context.after_query), - ) + params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} + + if len(self.all_rulesets) > 0: + params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets) + + if self.metadata is not None: + params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata) + + return J2("engines/rag/modules/response/prompt/system.j2").render(**params) diff --git a/griptape/engines/rag/modules/response/rulesets_before_response_rag_module.py b/griptape/engines/rag/modules/response/rulesets_before_response_rag_module.py deleted file mode 100644 index 81b8410ce..000000000 --- a/griptape/engines/rag/modules/response/rulesets_before_response_rag_module.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from attrs import define, field - -from griptape.engines.rag.modules import BaseBeforeResponseRagModule -from griptape.utils import J2 - -if TYPE_CHECKING: - from griptape.engines.rag import RagContext - from griptape.rules import Ruleset - - -@define -class RulesetsBeforeResponseRagModule(BaseBeforeResponseRagModule): - rulesets: list[Ruleset] = field(kw_only=True) - - def run(self, context: RagContext) -> RagContext: - context.before_query.append(J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)) - - return context diff --git a/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py b/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py index fd57b3905..35da0592b 100644 --- a/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py +++ b/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py @@ -1,13 +1,11 @@ from attrs import define -from griptape.artifacts import ListArtifact +from griptape.artifacts import BaseArtifact, ListArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import BaseResponseRagModule @define(kw_only=True) class TextChunksResponseRagModule(BaseResponseRagModule): - def run(self, context: RagContext) -> RagContext: - context.output = ListArtifact(context.text_chunks) - - return context + def run(self, context: RagContext) -> BaseArtifact: + return ListArtifact(context.text_chunks) diff --git a/griptape/engines/rag/rag_context.py b/griptape/engines/rag/rag_context.py index 3146070c2..1ddfbb1b0 100644 --- a/griptape/engines/rag/rag_context.py +++ b/griptape/engines/rag/rag_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from attrs import define, field @@ -22,7 +22,7 @@ class RagContext(SerializableMixin): before_query: An optional list of strings to add before the query in response modules. after_query: An optional list of strings to add after the query in response modules. text_chunks: A list of text chunks to pass around from the retrieval stage to the response stage. - output: Final output from the response stage. + outputs: List of outputs from the response stage. """ query: str = field(metadata={"serializable": True}) @@ -30,7 +30,7 @@ class RagContext(SerializableMixin): before_query: list[str] = field(factory=list, metadata={"serializable": True}) after_query: list[str] = field(factory=list, metadata={"serializable": True}) text_chunks: list[TextArtifact] = field(factory=list, metadata={"serializable": True}) - output: Optional[BaseArtifact] = field(default=None, metadata={"serializable": True}) + outputs: list[BaseArtifact] = field(factory=list, metadata={"serializable": True}) def get_references(self) -> list[Reference]: return utils.references_from_artifacts(self.text_chunks) diff --git a/griptape/engines/rag/stages/query_rag_stage.py b/griptape/engines/rag/stages/query_rag_stage.py index 97a6c2e2d..ebde3170c 100644 --- a/griptape/engines/rag/stages/query_rag_stage.py +++ b/griptape/engines/rag/stages/query_rag_stage.py @@ -23,7 +23,7 @@ def modules(self) -> Sequence[BaseRagModule]: return self.query_modules def run(self, context: RagContext) -> RagContext: - logging.info("QueryStage: running %s query generation modules sequentially", len(self.query_modules)) + logging.info("QueryRagStage: running %s query generation modules sequentially", len(self.query_modules)) [qm.run(context) for qm in self.query_modules] diff --git a/griptape/engines/rag/stages/response_rag_stage.py b/griptape/engines/rag/stages/response_rag_stage.py index b63b5bc21..4bc0b2be5 100644 --- a/griptape/engines/rag/stages/response_rag_stage.py +++ b/griptape/engines/rag/stages/response_rag_stage.py @@ -5,13 +5,12 @@ from attrs import define, field +from griptape import utils from griptape.engines.rag.stages import BaseRagStage if TYPE_CHECKING: from griptape.engines.rag import RagContext from griptape.engines.rag.modules import ( - BaseAfterResponseRagModule, - BaseBeforeResponseRagModule, BaseRagModule, BaseResponseRagModule, ) @@ -19,35 +18,22 @@ @define(kw_only=True) class ResponseRagStage(BaseRagStage): - before_response_modules: list[BaseBeforeResponseRagModule] = field(factory=list) - response_module: BaseResponseRagModule = field() - after_response_modules: list[BaseAfterResponseRagModule] = field(factory=list) + response_modules: list[BaseResponseRagModule] = field() @property def modules(self) -> list[BaseRagModule]: ms = [] - ms.extend(self.before_response_modules) - ms.extend(self.after_response_modules) - - if self.response_module is not None: - ms.append(self.response_module) + ms.extend(self.response_modules) return ms def run(self, context: RagContext) -> RagContext: - logging.info("GenerationStage: running %s before modules sequentially", len(self.before_response_modules)) - - for generator in self.before_response_modules: - context = generator.run(context) - - logging.info("GenerationStage: running generation module") - - context = self.response_module.run(context) + logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules)) - logging.info("GenerationStage: running %s after modules sequentially", len(self.after_response_modules)) + with self.futures_executor_fn() as executor: + results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules]) - for generator in self.after_response_modules: - context = generator.run(context) + context.outputs = results return context diff --git a/griptape/engines/rag/stages/retrieval_rag_stage.py b/griptape/engines/rag/stages/retrieval_rag_stage.py index 50b84abfc..fa618a7ff 100644 --- a/griptape/engines/rag/stages/retrieval_rag_stage.py +++ b/griptape/engines/rag/stages/retrieval_rag_stage.py @@ -33,7 +33,7 @@ def modules(self) -> list[BaseRagModule]: return ms def run(self, context: RagContext) -> RagContext: - logging.info("RetrievalStage: running %s retrieval modules in parallel", len(self.retrieval_modules)) + logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules)) with self.futures_executor_fn() as executor: results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.retrieval_modules]) @@ -47,7 +47,7 @@ def run(self, context: RagContext) -> RagContext: chunks_after_dedup = len(results) logging.info( - "RetrievalStage: deduplicated %s " "chunks (%s - %s)", + "RetrievalRagStage: deduplicated %s " "chunks (%s - %s)", chunks_before_dedup - chunks_after_dedup, chunks_before_dedup, chunks_after_dedup, @@ -56,7 +56,7 @@ def run(self, context: RagContext) -> RagContext: context.text_chunks = [a for a in results if isinstance(a, TextArtifact)] if self.rerank_module: - logging.info("RetrievalStage: running rerank module on %s chunks", chunks_after_dedup) + logging.info("RetrievalRagStage: running rerank module on %s chunks", chunks_after_dedup) context.text_chunks = [a for a in self.rerank_module.run(context) if isinstance(a, TextArtifact)] diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 8e66c5aba..62a517bc9 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -52,7 +52,7 @@ def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifac if self.retrieval_rag_module_name is None: raise ValueError("retrieval_rag_module_name is not set") - result = self.rag_engine.process( + outputs = self.rag_engine.process( RagContext( query=query, module_configs={ @@ -64,9 +64,9 @@ def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifac }, }, ), - ).output + ).outputs - if result is None: - return InfoArtifact("Empty output") + if len(outputs) > 0: + return outputs[0] else: - return result + return InfoArtifact("Empty output") diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index d68457ebc..8f095dfeb 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -22,9 +22,7 @@ from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( - MetadataBeforeResponseRagModule, PromptResponseRagModule, - RulesetsBeforeResponseRagModule, VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage @@ -180,11 +178,9 @@ def default_rag_engine(self) -> RagEngine: retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=self.config.vector_store_driver)], ), response_stage=ResponseRagStage( - before_response_modules=[ - RulesetsBeforeResponseRagModule(rulesets=self.rulesets), - MetadataBeforeResponseRagModule(), + response_modules=[ + PromptResponseRagModule(prompt_driver=self.config.prompt_driver, rulesets=self.rulesets) ], - response_module=PromptResponseRagModule(prompt_driver=self.config.prompt_driver), ), ) diff --git a/griptape/tasks/rag_task.py b/griptape/tasks/rag_task.py index 3f88f34d1..2f44fdfa4 100644 --- a/griptape/tasks/rag_task.py +++ b/griptape/tasks/rag_task.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact from griptape.tasks import BaseTextInputTask if TYPE_CHECKING: @@ -29,9 +29,9 @@ def rag_engine(self, value: RagEngine) -> None: self._rag_engine = value def run(self) -> BaseArtifact: - result = self.rag_engine.process_query(self.input.to_text()).output + outputs = self.rag_engine.process_query(self.input.to_text()).outputs - if result is None: - return ErrorArtifact("empty output") + if len(outputs) > 0: + return ListArtifact(outputs) else: - return result + return ErrorArtifact("empty output") diff --git a/griptape/templates/engines/rag/modules/response/prompt/system.j2 b/griptape/templates/engines/rag/modules/response/prompt/system.j2 index 1fa9d8c12..38b0297d5 100644 --- a/griptape/templates/engines/rag/modules/response/prompt/system.j2 +++ b/griptape/templates/engines/rag/modules/response/prompt/system.j2 @@ -1,6 +1,10 @@ You are an expert Q&A system. Always answer the question using the provided context information, and not prior knowledge. Always be truthful. Don't make up facts. You can answer questions by searching through text chunks. -{% if before_system_prompt %} -{{ before_system_prompt }} +{% if rulesets %} +{{ rulesets }} + +{% endif %} +{% if metadata %} +{{ metadata }} {% endif %} Use the following list of text chunks to respond. If there are no text chunks available or text chunks don't have relevant information respond with "I could not find an answer." diff --git a/griptape/tools/rag_client/tool.py b/griptape/tools/rag_client/tool.py index bbdef8159..613e254af 100644 --- a/griptape/tools/rag_client/tool.py +++ b/griptape/tools/rag_client/tool.py @@ -5,7 +5,7 @@ from attrs import define, field from schema import Literal, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -35,11 +35,12 @@ def search(self, params: dict) -> BaseArtifact: query = params["values"]["query"] try: - result = self.rag_engine.process_query(query) + outputs = self.rag_engine.process_query(query).outputs - if result.output is None: - return ErrorArtifact("query output is empty") + if len(outputs) > 0: + return ListArtifact(outputs) else: - return result.output + return ErrorArtifact("query output is empty") + except Exception as e: return ErrorArtifact(f"error querying: {e}") diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index 385cf0c04..4d0aad139 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -13,7 +13,7 @@ def module(self): return FootnotePromptResponseRagModule(prompt_driver=MockPromptDriver()) def test_run(self, module): - assert module.run(RagContext(query="test")).output.value == "mock output" + assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): system_message = module.default_system_template_generator( diff --git a/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py deleted file mode 100644 index 9519c8017..000000000 --- a/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py +++ /dev/null @@ -1,21 +0,0 @@ -from griptape.engines.rag import RagContext -from griptape.engines.rag.modules import MetadataBeforeResponseRagModule - - -class TestMetadataBeforeResponseRagModule: - def test_run(self): - module = MetadataBeforeResponseRagModule(name="foo") - - assert ( - "foo" in module.run(RagContext(module_configs={"foo": {"metadata": "foo"}}, query="test")).before_query[0] - ) - - def test_run_with_override(self): - module = MetadataBeforeResponseRagModule(name="foo", metadata="bar") - - assert ( - "bar" - in module.run( - RagContext(module_configs={"foo": {"query_params": {"metadata": "foo"}}}, query="test") - ).before_query[0] - ) diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index 2f8a912e2..cc8d35f0e 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -3,20 +3,25 @@ from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import PromptResponseRagModule +from griptape.rules import Rule, Ruleset from tests.mocks.mock_prompt_driver import MockPromptDriver class TestPromptResponseRagModule: @pytest.fixture() def module(self): - return PromptResponseRagModule(prompt_driver=MockPromptDriver()) + return PromptResponseRagModule( + prompt_driver=MockPromptDriver(), + rulesets=[Ruleset(name="test", rules=[Rule("*RULESET*")])], + metadata="*META*", + ) def test_run(self, module): - assert module.run(RagContext(query="test")).output.value == "mock output" + assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): system_message = module.default_system_template_generator( - RagContext(query="test", before_query=["*RULESET*", "*META*"], after_query=[]), + RagContext(query="test"), artifacts=[TextArtifact("*TEXT SEGMENT 1*"), TextArtifact("*TEXT SEGMENT 2*")], ) diff --git a/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py deleted file mode 100644 index bc85cf266..000000000 --- a/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py +++ /dev/null @@ -1,10 +0,0 @@ -from griptape.engines.rag import RagContext -from griptape.engines.rag.modules import RulesetsBeforeResponseRagModule -from griptape.rules import Rule, Ruleset - - -class TestRulesetsBeforeResponseRagModule: - def test_run(self): - module = RulesetsBeforeResponseRagModule(rulesets=[Ruleset(name="test ruleset", rules=[Rule("test rule")])]) - - assert "test rule" in module.run(RagContext(query="test")).before_query[0] diff --git a/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py index ae4410b2c..05b8042e6 100644 --- a/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py @@ -13,4 +13,4 @@ def module(self): def test_run(self, module): text_chunks = [TextArtifact("foo"), TextArtifact("bar")] - assert module.run(RagContext(query="test", text_chunks=text_chunks)).output.value == text_chunks + assert module.run(RagContext(query="test", text_chunks=text_chunks)).value == text_chunks diff --git a/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py index 9fecc3c0e..96a91280e 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py @@ -40,22 +40,18 @@ def test_run_with_namespace(self): def test_run_with_namespace_overrides(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) module = VectorStoreRetrievalRagModule( - vector_store_driver=vector_store_driver, query_params={"namespace": "test"} + name="TestModule", vector_store_driver=vector_store_driver, query_params={"namespace": "test"} ) vector_store_driver.upsert_text_artifact(TextArtifact("foobar1"), namespace="test") vector_store_driver.upsert_text_artifact(TextArtifact("foobar2"), namespace="test") result1 = module.run( - RagContext( - query="test", module_configs={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "empty"}}} - ) + RagContext(query="test", module_configs={"TestModule": {"query_params": {"namespace": "empty"}}}) ) result2 = module.run( - RagContext( - query="test", module_configs={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "test"}}} - ) + RagContext(query="test", module_configs={"TestModule": {"query_params": {"namespace": "test"}}}) ) assert len(result1) == 0 diff --git a/tests/unit/engines/rag/test_rag_engine.py b/tests/unit/engines/rag/test_rag_engine.py index c3d728bb3..964a52650 100644 --- a/tests/unit/engines/rag/test_rag_engine.py +++ b/tests/unit/engines/rag/test_rag_engine.py @@ -19,7 +19,9 @@ def engine(self): ) ] ), - response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver())), + response_stage=ResponseRagStage( + response_modules=[PromptResponseRagModule(prompt_driver=MockPromptDriver())] + ), ) def test_module_name_uniqueness(self): @@ -45,7 +47,7 @@ def test_module_name_uniqueness(self): ) def test_process_query(self, engine): - assert engine.process_query("test").output.value == "mock output" + assert engine.process_query("test").outputs[0].value == "mock output" def test_process(self, engine): - assert engine.process(RagContext(query="test")).output.value == "mock output" + assert engine.process(RagContext(query="test")).outputs[0].value == "mock output" diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index a09ad0f9a..e3d9034c4 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -238,7 +238,7 @@ def test_task_memory_defaults(self): storage = list(agent.task_memory.artifact_storages.values())[0] assert isinstance(storage, TextArtifactStorage) - assert storage.rag_engine.response_stage.response_module.prompt_driver == prompt_driver + assert storage.rag_engine.response_stage.response_modules[0].prompt_driver == prompt_driver assert ( storage.rag_engine.retrieval_stage.retrieval_modules[0].vector_store_driver.embedding_driver == embedding_driver diff --git a/tests/unit/tasks/test_rag_task.py b/tests/unit/tasks/test_rag_task.py index b205d385a..dc8603a2a 100644 --- a/tests/unit/tasks/test_rag_task.py +++ b/tests/unit/tasks/test_rag_task.py @@ -15,7 +15,7 @@ def task(self): input="test", rag_engine=RagEngine( response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver()) + response_modules=[PromptResponseRagModule(prompt_driver=MockPromptDriver())] ) ), ) diff --git a/tests/unit/tools/test_rag_client.py b/tests/unit/tools/test_rag_client.py index 9c6497b02..60a0df722 100644 --- a/tests/unit/tools/test_rag_client.py +++ b/tests/unit/tools/test_rag_client.py @@ -10,4 +10,4 @@ def test_search(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) tool = RagClient(description="Test", rag_engine=rag_engine(MockPromptDriver(), vector_store_driver)) - assert tool.search({"values": {"query": "test"}}).value == "mock output" + assert tool.search({"values": {"query": "test"}}).value[0].value == "mock output" diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index bad7f0d79..154f63ac4 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -34,5 +34,5 @@ def rag_engine(prompt_driver, vector_store_driver): retrieval_stage=RetrievalRagStage( retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=vector_store_driver)] ), - response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=prompt_driver)), + response_stage=ResponseRagStage(response_modules=[PromptResponseRagModule(prompt_driver=prompt_driver)]), )