Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 4, 2024
1 parent eae1bc1 commit 47381bd
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
from .event_listener.webhook_event_listener_driver import WebhookEventListenerDriver
from .event_listener.aws_iot_core_event_listener_driver import AwsIotCoreEventListenerDriver
from .event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver
from .event_listener.pusher_event_listener_driver import PusherEventListenerDriver

from .file_manager.base_file_manager_driver import BaseFileManagerDriver
from .file_manager.local_file_manager_driver import LocalFileManagerDriver
Expand Down Expand Up @@ -189,6 +190,7 @@
"WebhookEventListenerDriver",
"AwsIotCoreEventListenerDriver",
"GriptapeCloudEventListenerDriver",
"PusherEventListenerDriver",
"BaseFileManagerDriver",
"LocalFileManagerDriver",
"AmazonS3FileManagerDriver",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from pytest import fixture
from tests.mocks.mock_event import MockEvent
from griptape.drivers.event_listener.pusher_event_listener_driver import PusherEventListenerDriver
from griptape.drivers import PusherEventListenerDriver
from unittest.mock import Mock


class TestPusherEventListenerDriver:
@fixture(autouse=True)
def mock_post(self, mocker):
mock_pusher_client = mocker.patch("pusher.Pusher")
mock_pusher_client.return_value.trigger.return_value = Mock()
mock_pusher_client.return_value.trigger_batch.return_value = Mock()

return mock_pusher_client

@fixture()
def driver(self):
return PusherEventListenerDriver(
Expand All @@ -21,5 +30,13 @@ def test_init(self, driver):
def test_try_publish_event_payload(self, driver):
driver.try_publish_event_payload(MockEvent().to_dict())

assert driver.pusher_client.trigger.called_with(
channels="test-channel", event_name="test-event", data=MockEvent().to_dict()
)

def test_try_publish_event_payload_batch(self, driver):
driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)])

assert driver.pusher_client.trigger_batch.called_with(
[{"channel": "test-channel", "name": "test-event", "data": MockEvent().to_dict()} for _ in range(3)]
)

0 comments on commit 47381bd

Please sign in to comment.