Skip to content

Commit

Permalink
Flatten rag tool outputs (#1066)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Aug 16, 2024
1 parent 4d71eaf commit 54efe3b
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions griptape/tools/rag/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attrs import define, field
from schema import Literal, Schema

from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.tools import BaseTool
from griptape.utils.decorators import activity

Expand All @@ -31,11 +31,18 @@ class RagTool(BaseTool):
"schema": Schema({Literal("query", description="A natural language search query"): str}),
},
)
def search(self, params: dict) -> BaseArtifact:
def search(self, params: dict) -> ListArtifact | ErrorArtifact:
query = params["values"]["query"]

try:
outputs = self.rag_engine.process_query(query).outputs
artifacts = self.rag_engine.process_query(query).outputs

outputs = []
for artifact in artifacts:
if isinstance(artifact, ListArtifact):
outputs.extend(artifact.value)
else:
outputs.append(artifact)

if len(outputs) > 0:
return ListArtifact(outputs)
Expand Down

0 comments on commit 54efe3b

Please sign in to comment.