Skip to content

Commit

Permalink
Rename factory fn
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed May 3, 2024
1 parent d3e2d1b commit 0196143
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions docs/griptape-framework/drivers/structure-run-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ joke_coordinator = Pipeline(
tasks=[
StructureRunTask(
driver=LocalStructureRunDriver(
structure_factory=lambda: joke_teller,
structure_factory_fn=lambda: joke_teller,
),
),
StructureRunTask(
"Rewrite this joke: {{ parent_output }}",
driver=LocalStructureRunDriver(
structure_factory=lambda: joke_rewriter,
structure_factory_fn=lambda: joke_rewriter,
),
),
]
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ team = Pipeline(
Identify key trends, breakthrough technologies, and potential industry impacts.""",
),
driver=LocalStructureRunDriver(
structure_factory=lambda: researcher,
structure_factory_fn=lambda: researcher,
),
),
StructureRunTask(
Expand All @@ -795,7 +795,7 @@ team = Pipeline(
"{{parent_output}}",
),
driver=LocalStructureRunDriver(
structure_factory=lambda: writer,
structure_factory_fn=lambda: writer,
),
),
],
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@

@define
class LocalStructureRunDriver(BaseStructureRunDriver):
structure_factory: Callable[[], Structure] = field(kw_only=True)
structure_factory_fn: Callable[[], Structure] = field(kw_only=True)

def try_run(self, *args: BaseArtifact) -> BaseArtifact:
structure_factory = self.structure_factory().run(*[arg.value for arg in args])
structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args])

if structure_factory.output_task.output is not None:
return structure_factory.output_task.output
if structure_factory_fn.output_task.output is not None:
return structure_factory_fn.output_task.output
else:
return InfoArtifact("No output found in response")
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestLocalStructureRunDriver:
@pytest.fixture
def driver(self):
agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output"))
driver = LocalStructureRunDriver(structure_factory=lambda: agent)
driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent)

return driver

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tasks/test_structure_run_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class TestStructureRunTask:
def test_run(self):
agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output"))
pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output"))
driver = LocalStructureRunDriver(structure_factory=lambda: agent)
driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent)

task = StructureRunTask(driver=driver)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tools/test_structure_run_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def client(self):
agent = Agent(prompt_driver=driver)

return StructureRunClient(
description="foo bar", driver=LocalStructureRunDriver(structure_factory=lambda: agent)
description="foo bar", driver=LocalStructureRunDriver(structure_factory_fn=lambda: agent)
)

def test_run_structure(self, client):
Expand Down

0 comments on commit 0196143

Please sign in to comment.