Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 23, 2024
1 parent 9f00ec3 commit 56b8fff
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
3 changes: 2 additions & 1 deletion griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Optional

from attrs import define, field
from schema import Schema

from griptape.artifacts import (
ActionArtifact,
Expand All @@ -25,6 +24,8 @@
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from schema import Schema

from griptape.tools import BaseTool


Expand Down
6 changes: 3 additions & 3 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:

if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_mode == "tool":
structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["toolConfig"] = {
"toolChoice": {"any": {}},
}
if structured_ouptut_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_ouptut_tool)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)
else:
raise ValueError(f"Unsupported native structured output mode: {self.native_structured_output_mode}")

Expand Down
8 changes: 4 additions & 4 deletions griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import NOTHING, Attribute, Factory, NothingType, define, field
from schema import Schema

from griptape import utils
from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact
Expand All @@ -13,13 +14,12 @@
from griptape.memory.structure import Run
from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin
from griptape.mixins.rule_mixin import RuleMixin
from griptape.rules import Ruleset
from griptape.rules import JsonSchemaRule, Ruleset
from griptape.tasks import ActionsSubtask, BaseTask
from griptape.tools import StructuredOutputTool
from griptape.utils import J2

if TYPE_CHECKING:
from schema import Schema

from griptape.drivers import BasePromptDriver
from griptape.memory import TaskMemory
from griptape.memory.structure.base_conversation_memory import BaseConversationMemory
Expand Down Expand Up @@ -158,7 +158,7 @@ def prompt_stack(self) -> PromptStack:

if memory is not None:
# inserting at index 1 to place memory right after system prompt
memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0)
memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_contents else 0)

return stack

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tasks/test_actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from griptape.artifacts.json_artifact import JsonArtifact
from griptape.common import ToolAction
from griptape.structures import Agent
from griptape.tasks import ActionsSubtask, PromptTask
from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask
from tests.mocks.mock_tool.tool import MockTool


Expand Down

0 comments on commit 56b8fff

Please sign in to comment.