diff --git a/griptape/tools/rag/tool.py b/griptape/tools/rag/tool.py index 8608493d1..aab52e6c0 100644 --- a/griptape/tools/rag/tool.py +++ b/griptape/tools/rag/tool.py @@ -5,7 +5,7 @@ from attrs import define, field from schema import Literal, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact +from griptape.artifacts import ErrorArtifact, ListArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -31,11 +31,18 @@ class RagTool(BaseTool): "schema": Schema({Literal("query", description="A natural language search query"): str}), }, ) - def search(self, params: dict) -> BaseArtifact: + def search(self, params: dict) -> ListArtifact | ErrorArtifact: query = params["values"]["query"] try: - outputs = self.rag_engine.process_query(query).outputs + artifacts = self.rag_engine.process_query(query).outputs + + outputs = [] + for artifact in artifacts: + if isinstance(artifact, ListArtifact): + outputs.extend(artifact.value) + else: + outputs.append(artifact) if len(outputs) > 0: return ListArtifact(outputs)