Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed May 2, 2024
1 parent ed82af3 commit 96005fe
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 16 deletions.
21 changes: 19 additions & 2 deletions docs/griptape-framework/drivers/structure-run-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ When combined with the [StructureRunTask](../../reference/griptape/tasks/structu

## Local Structure Run Driver

The [LocalStructureRunDriver](../../reference/griptape/drivers/structure-run/local-structure-run-driver.md) is used to run Griptape Structures in the same runtime environment as the code that is running the Structure.

```python
from griptape.drivers import LocalStructureRunDriver
from griptape.rules import Rule
Expand Down Expand Up @@ -47,11 +49,15 @@ joke_coordinator.run("Tell me a joke")

## Griptape Cloud Structure Run Driver

The [GriptapeCloudStructureRunDriver](../../reference/griptape/drivers/structure-run/griptape-cloud-structure-run-driver.md) is used to run Griptape Structures in the Griptape Cloud.


```python
import os

from griptape.drivers import GriptapeCloudStructureRunDriver, LocalStructureRunDriver
from griptape.structures import Pipeline, Agent
from griptape.rules import Rule
from griptape.tasks import StructureRunTask

base_url = os.environ["GRIPTAPE_CLOUD_BASE_URL"]
Expand All @@ -62,8 +68,19 @@ structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"]
pipeline = Pipeline(
tasks=[
StructureRunTask(
"Think of a question related to RAG.",
driver=LocalStructureRunDriver(structure=Agent()),
"Think of a question related to Retrieval Augmented Generation.",
driver=LocalStructureRunDriver(
structure=Agent(
rules=[
Rule(
value="You are an expert in Retrieval Augmented Generation.",
),
Rule(
value="Only output your answer, no other information.",
),
]
)
),
),
StructureRunTask(
"{{ parent_output }}",
Expand Down
7 changes: 2 additions & 5 deletions griptape/drivers/structure_run/base_structure_run_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@

from attrs import define

from griptape.artifacts import BaseArtifact, ErrorArtifact
from griptape.artifacts import BaseArtifact


@define
class BaseStructureRunDriver(ABC):
def run(self, *args: BaseArtifact) -> BaseArtifact:
try:
return self.try_run(*args)
except Exception as e:
return ErrorArtifact(str(e))
return self.try_run(*args)

@abstractmethod
def try_run(self, *args: BaseArtifact) -> BaseArtifact:
Expand Down
4 changes: 2 additions & 2 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
class LocalStructureRunDriver(BaseStructureRunDriver):
structure: Structure = field(kw_only=True)

def try_run(self, *args) -> BaseArtifact:
self.structure.run(*args)
def try_run(self, *args: BaseArtifact) -> BaseArtifact:
self.structure.run(*[arg.value for arg in args])

if self.structure.output_task.output is not None:
return self.structure.output_task.output
Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/structure_run_client/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ class StructureRunClient(BaseTool):
}
)
def run_structure(self, params: dict) -> BaseArtifact:
args: str = params["values"]["args"]
args: list[str] = params["values"]["args"]

return self.driver.run(*[TextArtifact(arg) for arg in args])
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@ def driver(self, mocker):
mocker.patch("requests.post", return_value=mock_response)

mock_response = mocker.Mock()
mock_response.json.return_value = {"description": "fizz buzz", "output": "fooey booey", "status": "SUCCEEDED"}
mock_response.json.return_value = {
"description": "fizz buzz",
"output": TextArtifact("foo bar").to_dict(),
"status": "SUCCEEDED",
}
mocker.patch("requests.get", return_value=mock_response)

return GriptapeCloudStructureRunDriver(base_url="https://api.griptape.ai", api_key="foo bar", structure_id="1")
return GriptapeCloudStructureRunDriver(
base_url="https://cloud-foo.griptape.ai", api_key="foo bar", structure_id="1"
)

def test_run(self, driver):
assert isinstance(driver.run("foo bar"), TextArtifact)
result = driver.run(TextArtifact("foo bar"))
assert isinstance(result, TextArtifact)
assert result.value == "foo bar"
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import pytest
from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver
from griptape.tools import GriptapeStructureRunClient
from griptape.tools import StructureRunClient
from griptape.structures import Agent
from tests.mocks.mock_prompt_driver import MockPromptDriver


class TestGriptapeStructureRunClient:
class TestStructureRunClient:
@pytest.fixture
def client(self):
driver = MockPromptDriver()
agent = Agent(prompt_driver=driver)

return GriptapeStructureRunClient(description="foo bar", driver=LocalStructureRunDriver(structure=agent))
return StructureRunClient(description="foo bar", driver=LocalStructureRunDriver(structure=agent))

def test_run_structure(self, client):
assert client.run_structure({"values": {"args": "foo bar"}}).value == "mock output"
Expand Down

0 comments on commit 96005fe

Please sign in to comment.