Skip to content

Commit

Permalink
Fix bad rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 7, 2024
1 parent 97564a2 commit 4220e0f
Show file tree
Hide file tree
Showing 8 changed files with 9 additions and 13 deletions.
9 changes: 5 additions & 4 deletions griptape/config/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from attrs import define

from griptape.config.base_driver_config import BaseDriverConfig
from griptape.config.events_config import EventsConfig
from griptape.mixins.serializable_mixin import SerializableMixin

from .base_driver_config import BaseDriverConfig
from .logging_config import LoggingConfig

@define

@define(kw_only=True)
class BaseConfig(SerializableMixin, ABC):
drivers: BaseDriverConfig
events: EventsConfig
logging: LoggingConfig
2 changes: 0 additions & 2 deletions griptape/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

from .base_config import BaseConfig
from .base_driver_config import BaseDriverConfig
from .events_config import EventsConfig
from .logging_config import LoggingConfig
from .openai_driver_config import OpenAiDriverConfig


@define
class _Config(BaseConfig):
drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True)
events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True)
logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True)


Expand Down
2 changes: 0 additions & 2 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def __process_run(self, prompt_stack: PromptStack) -> Message:
return result

def __process_stream(self, prompt_stack: PromptStack) -> Message:
from griptape.config import Config

delta_contents: dict[int, list[BaseDeltaMessageContent]] = {}
usage = DeltaMessage.Usage()

Expand Down
1 change: 1 addition & 0 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from griptape import utils
from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact
from griptape.common import ToolAction
from griptape.config import Config
from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent
from griptape.mixins import ActionsSubtaskOriginMixin
from griptape.tasks import BaseTask
Expand Down
1 change: 1 addition & 0 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from attrs import Factory, define, field

from griptape.artifacts import ErrorArtifact
from griptape.config import Config
from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent

if TYPE_CHECKING:
Expand Down
2 changes: 0 additions & 2 deletions griptape/utils/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ def run(self, *args) -> Iterator[TextArtifact]:
t.join()

def _run_structure(self, *args) -> None:
from griptape.config import Config

def event_handler(event: BaseEvent) -> None:
self._event_queue.put(event)

Expand Down
3 changes: 1 addition & 2 deletions tests/unit/drivers/prompt/test_base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def test_run_via_pipeline_retries_failure(self, mock_config):

def test_run_via_pipeline_publishes_events(self, mocker):
mock_publish_event = mocker.patch.object(_EventBus, "publish_event")
driver = MockPromptDriver()
pipeline = Pipeline(prompt_driver=driver)
pipeline = Pipeline()
pipeline.add_task(PromptTask("test"))

pipeline.run()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tasks/test_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def task(self):
agent = Agent(
tools=[MockTool()],
)
Config.event_listeners = [EventListener(handler=Mock())]
EventBus.event_listeners = [EventListener(handler=Mock())]

agent.add_task(MockTask("foobar", max_meta_memory_entries=2))

Expand Down

0 comments on commit 4220e0f

Please sign in to comment.