Skip to content

Commit

Permalink
Change structure to structure_factory
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed May 3, 2024
1 parent 9f0b78e commit fd616f3
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/griptape-framework/drivers/structure-run-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ joke_coordinator = Pipeline(
tasks=[
StructureRunTask(
driver=LocalStructureRunDriver(
structure=joke_teller,
structure_factory=lambda: joke_teller,
),
),
StructureRunTask(
"Rewrite this joke: {{ parent_output }}",
driver=LocalStructureRunDriver(
structure=joke_rewriter,
structure_factory=lambda: joke_rewriter,
),
),
]
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ team = Pipeline(
Identify key trends, breakthrough technologies, and potential industry impacts.""",
),
driver=LocalStructureRunDriver(
structure=researcher,
structure_factory=lambda: researcher,
),
),
StructureRunTask(
Expand All @@ -795,7 +795,7 @@ team = Pipeline(
"{{parent_output}}",
),
driver=LocalStructureRunDriver(
structure=writer,
structure_factory=lambda: writer,
),
),
],
Expand Down
10 changes: 5 additions & 5 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

from attrs import define, field

Expand All @@ -12,12 +12,12 @@

@define
class LocalStructureRunDriver(BaseStructureRunDriver):
structure: Structure = field(kw_only=True)
structure_factory: Callable[[], Structure] = field(kw_only=True)

def try_run(self, *args: BaseArtifact) -> BaseArtifact:
self.structure.run(*[arg.value for arg in args])
structure_factory = self.structure_factory().run(*[arg.value for arg in args])

if self.structure.output_task.output is not None:
return self.structure.output_task.output
if structure_factory.output_task.output is not None:
return structure_factory.output_task.output
else:
return InfoArtifact("No output found in response")
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestLocalStructureRunDriver:
@pytest.fixture
def driver(self):
agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output"))
driver = LocalStructureRunDriver(structure=agent)
driver = LocalStructureRunDriver(structure_factory=lambda: agent)

return driver

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tasks/test_structure_run_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class TestStructureRunTask:
def test_run(self):
agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output"))
pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output"))
driver = LocalStructureRunDriver(structure=agent)
driver = LocalStructureRunDriver(structure_factory=lambda: agent)

task = StructureRunTask(driver=driver)

Expand Down
4 changes: 3 additions & 1 deletion tests/unit/tools/test_structure_run_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ def client(self):
driver = MockPromptDriver()
agent = Agent(prompt_driver=driver)

return StructureRunClient(description="foo bar", driver=LocalStructureRunDriver(structure=agent))
return StructureRunClient(
description="foo bar", driver=LocalStructureRunDriver(structure_factory=lambda: agent)
)

def test_run_structure(self, client):
assert client.run_structure({"values": {"args": "foo bar"}}).value == "mock output"
Expand Down

0 comments on commit fd616f3

Please sign in to comment.