From a83fe51289f73659ad10117000ad4d8472268ab9 Mon Sep 17 00:00:00 2001
From: Andrew French
Date: Wed, 15 May 2024 12:17:59 -0700
Subject: [PATCH] Release v0.25.1 (#785)

Co-authored-by: Collin Dutter
Co-authored-by: dylanholmes <4370153+dylanholmes@users.noreply.github.com>
Co-authored-by: Zach Giordano <32624672+zachgiordano@users.noreply.github.com>
---
 CHANGELOG.md                                  |  15 +++
 .../drivers/vector-store-drivers.md           |   2 +-
 griptape/config/openai_structure_config.py    |   2 +-
 .../amazon_sqs_event_listener_driver.py       |   8 ++
 .../aws_iot_core_event_listener_driver.py     |   3 +
 .../base_event_listener_driver.py             |  32 +++--
 .../griptape_cloud_event_listener_driver.py   |   6 +
 .../webhook_event_listener_driver.py          |   4 +
 .../vector/redis_vector_store_driver.py       |  12 +-
 griptape/events/base_event.py                 |   6 +-
 griptape/events/event_listener.py             |   6 +-
 griptape/structures/structure.py              |   7 +-
 griptape/tasks/actions_subtask.py             | 114 +++++++++---------
 griptape/tokenizers/openai_tokenizer.py       |   7 +-
 poetry.lock                                   |  79 ++++++------
 pyproject.toml                                |   4 +-
 tests/mocks/mock_event.py                     |   2 +-
 tests/mocks/mock_event_listener_driver.py     |   5 +-
 .../config/test_openai_structure_config.py    |  10 +-
 .../test_amazon_sqs_event_listener_driver.py  |   3 +
 .../test_aws_iot_event_listener_driver.py     |   3 +
 .../test_base_event_listener_driver.py        |  25 ++++
 ...st_griptape_cloud_event_listener_driver.py |  11 ++
 .../test_webhook_event_listener_driver.py     |  11 ++
 .../vector/test_redis_vector_store_driver.py  |  80 +++++++++---
 tests/unit/events/test_event_listener.py      |   4 +-
 tests/unit/tasks/test_actions_subtask.py      |  26 +++-
 27 files changed, 344 insertions(+), 143 deletions(-)
 create mode 100644 tests/unit/drivers/event_listener/test_base_event_listener_driver.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e31647b7b..519708d18 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+## [0.25.1] - 2024-05-15
+
+### Fixed
+- Honor `namespace` in `RedisVectorStoreDriver.query()`.
+- Correctly set the `meta`, `score`, and `vector` fields of query results returned from `RedisVectorStoreDriver.query()`.
+- Standardize behavior between an omitted and an empty actions list when initializing `ActionsSubtask`.
+
+### Added
+- Optional event batching on Event Listener Drivers.
+- `id` field to all events.
+
+### Changed
+- Default behavior of Event Listener Drivers to batch events.
+- Default behavior of `OpenAiStructureConfig` to use `gpt-4o` as the `prompt_driver` model.
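+
+Event batching in practice: a minimal sketch. The `PrintingEventListenerDriver` subclass below is hypothetical and for illustration only; the `batched` and `batch_size` fields and the `try_publish_event_payload_batch` hook are what this release adds:
+
+```python
+# Hypothetical driver subclass, for illustration only.
+from attr import define
+
+from griptape.drivers import BaseEventListenerDriver
+
+
+@define
+class PrintingEventListenerDriver(BaseEventListenerDriver):
+    def try_publish_event_payload(self, event_payload: dict) -> None:
+        # Called once per event when batched=False.
+        print("event:", event_payload["id"])
+
+    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
+        # Called with up to batch_size payloads when batched=True.
+        print("batch ids:", [payload["id"] for payload in event_payload_batch])
+
+
+# Batching is now the default: events accumulate until batch_size is reached,
+# or until publish_event(..., flush=True) forces delivery of a partial batch.
+driver = PrintingEventListenerDriver(batched=True, batch_size=10)
+```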
+ ## [0.25.0] - 2024-05-06 ### Added diff --git a/docs/griptape-framework/drivers/vector-store-drivers.md b/docs/griptape-framework/drivers/vector-store-drivers.md index e2e85d2c9..73a416c84 100644 --- a/docs/griptape-framework/drivers/vector-store-drivers.md +++ b/docs/griptape-framework/drivers/vector-store-drivers.md @@ -316,7 +316,7 @@ print(result) The format for creating a vector index should be similar to the following: ``` -FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA tag TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE +FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA namespace TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE ``` ## OpenSearch Vector Store Driver diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index 283fca2d1..64c32ecec 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -24,7 +24,7 @@ class OpenAiStructureConfig(BaseStructureConfig): global_drivers: StructureGlobalDriversConfig = field( default=Factory( lambda: StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"), image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"), embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"), diff --git a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py index 24e3c9e1e..1c8132b67 100644 --- a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py +++ b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py @@ -20,3 +20,11 @@ class AmazonSqsEventListenerDriver(BaseEventListenerDriver): def try_publish_event_payload(self, event_payload: dict) -> None: self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload)) + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + entries = [ + {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)} + for event_payload in event_payload_batch + ] + + self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries) diff --git a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py index 302fd91d5..c4fd72084 100644 --- a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py +++ b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py @@ -21,3 +21,6 @@ class AwsIotCoreEventListenerDriver(BaseEventListenerDriver): def try_publish_event_payload(self, event_payload: dict) -> None: self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload)) + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch)) diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py index eec0fe320..8e7f827e9 100644 --- a/griptape/drivers/event_listener/base_event_listener_driver.py +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -14,19 +14,37 @@ @define class BaseEventListenerDriver(ABC): futures_executor: futures.Executor = 
field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) + batched: bool = field(default=True, kw_only=True) + batch_size: int = field(default=10, kw_only=True) - def publish_event(self, event: BaseEvent | dict) -> None: - if isinstance(event, dict): - self.futures_executor.submit(self._safe_try_publish_event_payload, event) - else: - self.futures_executor.submit(self._safe_try_publish_event_payload, event.to_dict()) + _batch: list[dict] = field(default=Factory(list), kw_only=True) + + @property + def batch(self) -> list[dict]: + return self._batch + + def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None: + self.futures_executor.submit(self._safe_try_publish_event, event, flush) @abstractmethod def try_publish_event_payload(self, event_payload: dict) -> None: ... - def _safe_try_publish_event_payload(self, event_payload: dict) -> None: + @abstractmethod + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + ... + + def _safe_try_publish_event(self, event: BaseEvent | dict, flush: bool) -> None: try: - self.try_publish_event_payload(event_payload) + event_payload = event if isinstance(event, dict) else event.to_dict() + + if self.batched: + self._batch.append(event_payload) + if len(self.batch) >= self.batch_size or flush: + self.try_publish_event_payload_batch(self.batch) + self._batch = [] + return + else: + self.try_publish_event_payload(event_payload) except Exception as e: logger.error(e) diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py index 461f06be9..2c4149ae7 100644 --- a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -41,3 +41,9 @@ def try_publish_event_payload(self, event_payload: dict) -> None: response = requests.post(url=url, json=event_payload, headers=self.headers) response.raise_for_status() + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events") + + response = requests.post(url=url, json=event_payload_batch, headers=self.headers) + response.raise_for_status() diff --git a/griptape/drivers/event_listener/webhook_event_listener_driver.py b/griptape/drivers/event_listener/webhook_event_listener_driver.py index 3803c86b6..242e5428a 100644 --- a/griptape/drivers/event_listener/webhook_event_listener_driver.py +++ b/griptape/drivers/event_listener/webhook_event_listener_driver.py @@ -15,3 +15,7 @@ class WebhookEventListenerDriver(BaseEventListenerDriver): def try_publish_event_payload(self, event_payload: dict) -> None: response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers) response.raise_for_status() + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + response = requests.post(url=self.webhook_url, json=event_payload_batch, headers=self.headers) + response.raise_for_status() diff --git a/griptape/drivers/vector/redis_vector_store_driver.py b/griptape/drivers/vector/redis_vector_store_driver.py index db99725a3..3772818ab 100644 --- a/griptape/drivers/vector/redis_vector_store_driver.py +++ b/griptape/drivers/vector/redis_vector_store_driver.py @@ -64,6 +64,9 @@ def upsert_vector( mapping["vector"] = np.array(vector, dtype=np.float32).tobytes() mapping["vec_string"] = bytes_vector + if namespace: 
+ mapping["namespace"] = namespace + if meta: mapping["metadata"] = json.dumps(meta) @@ -120,8 +123,9 @@ def query( vector = self.embedding_driver.embed_string(query) + filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*" query_expression = ( - Query(f"*=>[KNN {count or 10} @vector $vector as score]") + Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]") .sort_by("score") .return_fields("id", "score", "metadata", "vec_string") .paging(0, count or 10) @@ -134,15 +138,15 @@ def query( query_results = [] for document in results: - metadata = getattr(document, "metadata", None) + metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None namespace = document.id.split(":")[0] if ":" in document.id else None vector_id = document.id.split(":")[1] if ":" in document.id else document.id - vector_float_list = json.loads(document["vec_string"]) if include_vectors else None + vector_float_list = json.loads(document.vec_string) if include_vectors else None query_results.append( BaseVectorStoreDriver.QueryResult( id=vector_id, vector=vector_float_list, - score=float(document["score"]), + score=float(document.score), meta=metadata, namespace=namespace, ) diff --git a/griptape/events/base_event.py b/griptape/events/base_event.py index d32defe96..48a48890e 100644 --- a/griptape/events/base_event.py +++ b/griptape/events/base_event.py @@ -1,11 +1,15 @@ from __future__ import annotations + import time +import uuid from abc import ABC -from attr import define, field, Factory + +from attr import Factory, define, field from griptape.mixins import SerializableMixin @define class BaseEvent(SerializableMixin, ABC): + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) timestamp: float = field(default=Factory(lambda: time.time()), kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index a6b692d4d..44d7b2d85 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -13,13 +13,13 @@ class EventListener: event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True) driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) - def publish_event(self, event: BaseEvent) -> None: + def publish_event(self, event: BaseEvent, flush: bool = False) -> None: event_types = self.event_types if event_types is None or type(event) in event_types: event_payload = self.handler(event) if self.driver is not None: if event_payload is not None and isinstance(event_payload, dict): - self.driver.publish_event(event_payload) + self.driver.publish_event(event_payload, flush=flush) else: - self.driver.publish_event(event) + self.driver.publish_event(event, flush=flush) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index ef9205db9..9cd28ab67 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -249,9 +249,9 @@ def remove_event_listener(self, event_listener: EventListener) -> None: else: raise ValueError("Event Listener not found.") - def publish_event(self, event: BaseEvent) -> None: + def publish_event(self, event: BaseEvent, flush: bool = False) -> None: for event_listener in self.event_listeners: - event_listener.publish_event(event) + event_listener.publish_event(event, flush) def context(self, task: BaseTask) -> dict[str, Any]: return {"args": self.execution_args, "structure": self} @@ -269,7 +269,8 @@ 
def after_run(self) -> None:
             structure_id=self.id,
             output_task_input=self.output_task.input,
             output_task_output=self.output_task.output,
-            )
+            ),
+            flush=True,
         )
 
     @abstractmethod
diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py
index 6ed1bff5d..ef2c2ce6f 100644
--- a/griptape/tasks/actions_subtask.py
+++ b/griptape/tasks/actions_subtask.py
@@ -198,69 +198,73 @@ def __init_from_prompt(self, value: str) -> None:
         if self.thought is None and len(thought_matches) > 0:
             self.thought = thought_matches[-1]
 
-        if len(actions_matches) > 0:
-            try:
-                data = actions_matches[-1]
-                actions_list: list = json.loads(data, strict=False)
+        self.__parse_actions(actions_matches)
 
+        # If there are no actions to take but an answer is provided, set the answer as the output.
+        if len(self.actions) == 0 and self.output is None and len(answer_matches) > 0:
+            self.output = TextArtifact(answer_matches[-1])
+
+    def __parse_actions(self, actions_matches: list[str]) -> None:
+        if len(actions_matches) == 0:
+            return
+
+        try:
+            data = actions_matches[-1]
+            actions_list: list = json.loads(data, strict=False)
+
+            if isinstance(self.origin_task, ActionsSubtaskOriginMixin):
+                self.origin_task.actions_schema().validate(actions_list)
+
+            for action_object in actions_list:
+                # Load action tag; throw exception if the key is not present
+                action_tag = action_object["tag"]
+
+                # Load action name; throw exception if the key is not present
+                action_name = action_object["name"]
+
+                # Load action path; throw exception if the key is not present
+                action_path = action_object["path"]
+
+                # Load optional input value; don't throw exceptions if key is not present
+                if "input" in action_object:
+                    # The schema library has a bug, where something like `Or(str, None)` doesn't get
+                    # correctly translated into JSON schema. For some optional input fields LLMs sometimes
+                    # still provide null value, which trips up the validator. The temporary solution that
+                    # works is to strip all key-values where value is null.
+                    action_input = remove_null_values_in_dict_recursively(action_object["input"])
+                else:
+                    action_input = {}
+
+                # Load the action itself
                 if isinstance(self.origin_task, ActionsSubtaskOriginMixin):
-                    self.origin_task.actions_schema().validate(actions_list)
-
-                if not actions_list:
-                    raise schema.SchemaError("Array item count 0 is less than minimum count of 1.")
-
-                for action_object in actions_list:
-                    # Load action name; throw exception if the key is not present
-                    action_tag = action_object["tag"]
-
-                    # Load action name; throw exception if the key is not present
-                    action_name = action_object["name"]
-
-                    # Load action method; throw exception if the key is not present
-                    action_path = action_object["path"]
-
-                    # Load optional input value; don't throw exceptions if key is not present
-                    if "input" in action_object:
-                        # The schema library has a bug, where something like `Or(str, None)` doesn't get
-                        # correctly translated into JSON schema. For some optional input fields LLMs sometimes
-                        # still provide null value, which trips up the validator. The temporary solution that
-                        # works is to strip all key-values where value is null.
-                        action_input = remove_null_values_in_dict_recursively(action_object["input"])
-                    else:
-                        action_input = {}
-
-                    # Load the action itself
-                    if isinstance(self.origin_task, ActionsSubtaskOriginMixin):
-                        tool = self.origin_task.find_tool(action_name)
-                    else:
-                        raise Exception(
-                            "ActionSubtask must be attached to a Task that implements ActionSubtaskOriginMixin."
-                        )
-
-                    new_action = ActionsSubtask.Action(
-                        tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool
+                    tool = self.origin_task.find_tool(action_name)
+                else:
+                    raise Exception(
+                        "ActionsSubtask must be attached to a Task that implements ActionsSubtaskOriginMixin."
                     )
-                    if new_action.tool:
-                        if new_action.input:
-                            self.__validate_action(new_action)
+                new_action = ActionsSubtask.Action(
+                    tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool
                 )
-                    # Don't forget to add it to the subtask actions list!
-                    self.actions.append(new_action)
-            except SyntaxError as e:
-                self.structure.logger.error(f"Subtask {self.origin_task.id}\nSyntax error: {e}")
+                if new_action.tool:
+                    if new_action.input:
+                        self.__validate_action(new_action)
-                self.actions.append(self.__error_to_action(f"syntax error: {e}"))
-            except schema.SchemaError as e:
-                self.structure.logger.error(f"Subtask {self.origin_task.id}\nInvalid action JSON: {e}")
+                # Don't forget to add it to the subtask actions list!
+                self.actions.append(new_action)
+        except SyntaxError as e:
+            self.structure.logger.error(f"Subtask {self.origin_task.id}\nSyntax error: {e}")
-                self.actions.append(self.__error_to_action(f"Action JSON validation error: {e}"))
-            except Exception as e:
-                self.structure.logger.error(f"Subtask {self.origin_task.id}\nError parsing tool action: {e}")
+            self.actions.append(self.__error_to_action(f"syntax error: {e}"))
+        except schema.SchemaError as e:
+            self.structure.logger.error(f"Subtask {self.origin_task.id}\nInvalid action JSON: {e}")
-                self.actions.append(self.__error_to_action(f"Action input parsing error: {e}"))
-        elif self.output is None and len(answer_matches) > 0:
-            self.output = TextArtifact(answer_matches[-1])
+            self.actions.append(self.__error_to_action(f"Action JSON validation error: {e}"))
+        except Exception as e:
+            self.structure.logger.error(f"Subtask {self.origin_task.id}\nError parsing tool action: {e}")
+
+            self.actions.append(self.__error_to_action(f"Action input parsing error: {e}"))
 
     def __error_to_action(self, error: str) -> Action:
         return ActionsSubtask.Action(tag="error", name="error", input={"error": error})
diff --git a/griptape/tokenizers/openai_tokenizer.py b/griptape/tokenizers/openai_tokenizer.py
index 08c334e0c..dda8bfe15 100644
--- a/griptape/tokenizers/openai_tokenizer.py
+++ b/griptape/tokenizers/openai_tokenizer.py
@@ -10,7 +10,7 @@ class OpenAiTokenizer(BaseTokenizer):
     DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "gpt-3.5-turbo-instruct"
     DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo"
-    DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4"
+    DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4o"
     DEFAULT_ENCODING = "cl100k_base"
     DEFAULT_MAX_TOKENS = 2049
     DEFAULT_MAX_OUTPUT_TOKENS = 4096
@@ -18,6 +18,7 @@ class OpenAiTokenizer(BaseTokenizer):
     # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
     MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {
+        "gpt-4o": 128000,
         "gpt-4-1106": 128000,
         "gpt-4-32k": 32768,
         "gpt-4": 8192,
@@ -85,6 +86,7 @@ def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> int:
             "gpt-4-32k-0314",
             "gpt-4-0613",
             "gpt-4-32k-0613",
+            "gpt-4o-2024-05-13",
         }:
             tokens_per_message = 3
             tokens_per_name = 1
@@ -96,6 +98,9 @@ def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> int:
         elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
             logging.info("gpt-3.5-turbo may update over time. 
Returning num tokens assuming gpt-3.5-turbo-0613.") return self.count_tokens(text, model="gpt-3.5-turbo-0613") + elif "gpt-4o" in model: + logging.info("gpt-4o may update over time. Returning num tokens assuming gpt-4o-2024-05-13.") + return self.count_tokens(text, model="gpt-4o-2024-05-13") elif "gpt-4" in model: logging.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") return self.count_tokens(text, model="gpt-4-0613") diff --git a/poetry.lock b/poetry.lock index 072338a59..14f5e2bff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3292,6 +3292,7 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -3300,6 +3301,8 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -4691,47 +4694,47 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"] [[package]] name = "tiktoken" -version = "0.5.2" +version = "0.7.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" optional = false python-versions = ">=3.8" files = [ - {file = "tiktoken-0.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c4e654282ef05ec1bd06ead22141a9a1687991cef2c6a81bdd1284301abc71d"}, - {file = "tiktoken-0.5.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7b3134aa24319f42c27718c6967f3c1916a38a715a0fa73d33717ba121231307"}, - {file = 
"tiktoken-0.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6092e6e77730929c8c6a51bb0d7cfdf1b72b63c4d033d6258d1f2ee81052e9e5"}, - {file = "tiktoken-0.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72ad8ae2a747622efae75837abba59be6c15a8f31b4ac3c6156bc56ec7a8e631"}, - {file = "tiktoken-0.5.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:51cba7c8711afa0b885445f0637f0fcc366740798c40b981f08c5f984e02c9d1"}, - {file = "tiktoken-0.5.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3d8c7d2c9313f8e92e987d585ee2ba0f7c40a0de84f4805b093b634f792124f5"}, - {file = "tiktoken-0.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:692eca18c5fd8d1e0dde767f895c17686faaa102f37640e884eecb6854e7cca7"}, - {file = "tiktoken-0.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:138d173abbf1ec75863ad68ca289d4da30caa3245f3c8d4bfb274c4d629a2f77"}, - {file = "tiktoken-0.5.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7388fdd684690973fdc450b47dfd24d7f0cbe658f58a576169baef5ae4658607"}, - {file = "tiktoken-0.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a114391790113bcff670c70c24e166a841f7ea8f47ee2fe0e71e08b49d0bf2d4"}, - {file = "tiktoken-0.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca96f001e69f6859dd52926d950cfcc610480e920e576183497ab954e645e6ac"}, - {file = "tiktoken-0.5.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:15fed1dd88e30dfadcdd8e53a8927f04e1f6f81ad08a5ca824858a593ab476c7"}, - {file = "tiktoken-0.5.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f8e692db5756f7ea8cb0cfca34638316dcf0841fb8469de8ed7f6a015ba0b0"}, - {file = "tiktoken-0.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:bcae1c4c92df2ffc4fe9f475bf8148dbb0ee2404743168bbeb9dcc4b79dc1fdd"}, - {file = "tiktoken-0.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b76a1e17d4eb4357d00f0622d9a48ffbb23401dcf36f9716d9bd9c8e79d421aa"}, - {file = "tiktoken-0.5.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:01d8b171bb5df4035580bc26d4f5339a6fd58d06f069091899d4a798ea279d3e"}, - {file = "tiktoken-0.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42adf7d4fb1ed8de6e0ff2e794a6a15005f056a0d83d22d1d6755a39bffd9e7f"}, - {file = "tiktoken-0.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c3f894dbe0adb44609f3d532b8ea10820d61fdcb288b325a458dfc60fefb7db"}, - {file = "tiktoken-0.5.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:58ccfddb4e62f0df974e8f7e34a667981d9bb553a811256e617731bf1d007d19"}, - {file = "tiktoken-0.5.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58902a8bad2de4268c2a701f1c844d22bfa3cbcc485b10e8e3e28a050179330b"}, - {file = "tiktoken-0.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:5e39257826d0647fcac403d8fa0a474b30d02ec8ffc012cfaf13083e9b5e82c5"}, - {file = "tiktoken-0.5.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8bde3b0fbf09a23072d39c1ede0e0821f759b4fa254a5f00078909158e90ae1f"}, - {file = "tiktoken-0.5.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2ddee082dcf1231ccf3a591d234935e6acf3e82ee28521fe99af9630bc8d2a60"}, - {file = "tiktoken-0.5.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35c057a6a4e777b5966a7540481a75a31429fc1cb4c9da87b71c8b75b5143037"}, - {file = "tiktoken-0.5.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c4a049b87e28f1dc60509f8eb7790bc8d11f9a70d99b9dd18dfdd81a084ffe6"}, - {file = 
"tiktoken-0.5.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5bf5ce759089f4f6521ea6ed89d8f988f7b396e9f4afb503b945f5c949c6bec2"}, - {file = "tiktoken-0.5.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0c964f554af1a96884e01188f480dad3fc224c4bbcf7af75d4b74c4b74ae0125"}, - {file = "tiktoken-0.5.2-cp38-cp38-win_amd64.whl", hash = "sha256:368dd5726d2e8788e47ea04f32e20f72a2012a8a67af5b0b003d1e059f1d30a3"}, - {file = "tiktoken-0.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a2deef9115b8cd55536c0a02c0203512f8deb2447f41585e6d929a0b878a0dd2"}, - {file = "tiktoken-0.5.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2ed7d380195affbf886e2f8b92b14edfe13f4768ff5fc8de315adba5b773815e"}, - {file = "tiktoken-0.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c76fce01309c8140ffe15eb34ded2bb94789614b7d1d09e206838fc173776a18"}, - {file = "tiktoken-0.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60a5654d6a2e2d152637dd9a880b4482267dfc8a86ccf3ab1cec31a8c76bfae8"}, - {file = "tiktoken-0.5.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:41d4d3228e051b779245a8ddd21d4336f8975563e92375662f42d05a19bdff41"}, - {file = "tiktoken-0.5.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c1cdec2c92fcde8c17a50814b525ae6a88e8e5b02030dc120b76e11db93f13"}, - {file = "tiktoken-0.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:84ddb36faedb448a50b246e13d1b6ee3437f60b7169b723a4b2abad75e914f3e"}, - {file = "tiktoken-0.5.2.tar.gz", hash = "sha256:f54c581f134a8ea96ce2023ab221d4d4d81ab614efa0b2fbce926387deb56c80"}, + {file = "tiktoken-0.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485f3cc6aba7c6b6ce388ba634fbba656d9ee27f766216f45146beb4ac18b25f"}, + {file = "tiktoken-0.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e54be9a2cd2f6d6ffa3517b064983fb695c9a9d8aa7d574d1ef3c3f931a99225"}, + {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79383a6e2c654c6040e5f8506f3750db9ddd71b550c724e673203b4f6b4b4590"}, + {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d4511c52caacf3c4981d1ae2df85908bd31853f33d30b345c8b6830763f769c"}, + {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:13c94efacdd3de9aff824a788353aa5749c0faee1fbe3816df365ea450b82311"}, + {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8e58c7eb29d2ab35a7a8929cbeea60216a4ccdf42efa8974d8e176d50c9a3df5"}, + {file = "tiktoken-0.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:21a20c3bd1dd3e55b91c1331bf25f4af522c525e771691adbc9a69336fa7f702"}, + {file = "tiktoken-0.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:10c7674f81e6e350fcbed7c09a65bca9356eaab27fb2dac65a1e440f2bcfe30f"}, + {file = "tiktoken-0.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:084cec29713bc9d4189a937f8a35dbdfa785bd1235a34c1124fe2323821ee93f"}, + {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:811229fde1652fedcca7c6dfe76724d0908775b353556d8a71ed74d866f73f7b"}, + {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b6e7dc2e7ad1b3757e8a24597415bafcfb454cebf9a33a01f2e6ba2e663992"}, + {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1063c5748be36344c7e18c7913c53e2cca116764c2080177e57d62c7ad4576d1"}, + {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:20295d21419bfcca092644f7e2f2138ff947a6eb8cfc732c09cc7d76988d4a89"}, + {file = "tiktoken-0.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:959d993749b083acc57a317cbc643fb85c014d055b2119b739487288f4e5d1cb"}, + {file = "tiktoken-0.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:71c55d066388c55a9c00f61d2c456a6086673ab7dec22dd739c23f77195b1908"}, + {file = "tiktoken-0.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09ed925bccaa8043e34c519fbb2f99110bd07c6fd67714793c21ac298e449410"}, + {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03c6c40ff1db0f48a7b4d2dafeae73a5607aacb472fa11f125e7baf9dce73704"}, + {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20b5c6af30e621b4aca094ee61777a44118f52d886dbe4f02b70dfe05c15350"}, + {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d427614c3e074004efa2f2411e16c826f9df427d3c70a54725cae860f09e4bf4"}, + {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8c46d7af7b8c6987fac9b9f61041b452afe92eb087d29c9ce54951280f899a97"}, + {file = "tiktoken-0.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:0bc603c30b9e371e7c4c7935aba02af5994a909fc3c0fe66e7004070858d3f8f"}, + {file = "tiktoken-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2398fecd38c921bcd68418675a6d155fad5f5e14c2e92fcf5fe566fa5485a858"}, + {file = "tiktoken-0.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8f5f6afb52fb8a7ea1c811e435e4188f2bef81b5e0f7a8635cc79b0eef0193d6"}, + {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:861f9ee616766d736be4147abac500732b505bf7013cfaf019b85892637f235e"}, + {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54031f95c6939f6b78122c0aa03a93273a96365103793a22e1793ee86da31685"}, + {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:fffdcb319b614cf14f04d02a52e26b1d1ae14a570f90e9b55461a72672f7b13d"}, + {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c72baaeaefa03ff9ba9688624143c858d1f6b755bb85d456d59e529e17234769"}, + {file = "tiktoken-0.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:131b8aeb043a8f112aad9f46011dced25d62629091e51d9dc1adbf4a1cc6aa98"}, + {file = "tiktoken-0.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cabc6dc77460df44ec5b879e68692c63551ae4fae7460dd4ff17181df75f1db7"}, + {file = "tiktoken-0.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8d57f29171255f74c0aeacd0651e29aa47dff6f070cb9f35ebc14c82278f3b25"}, + {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ee92776fdbb3efa02a83f968c19d4997a55c8e9ce7be821ceee04a1d1ee149c"}, + {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e215292e99cb41fbc96988ef62ea63bb0ce1e15f2c147a61acc319f8b4cbe5bf"}, + {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a81bac94769cab437dd3ab0b8a4bc4e0f9cf6835bcaa88de71f39af1791727a"}, + {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d6d73ea93e91d5ca771256dfc9d1d29f5a554b83821a1dc0891987636e0ae226"}, + {file = "tiktoken-0.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:2bcb28ddf79ffa424f171dfeef9a4daff61a94c631ca6813f43967cb263b83b9"}, + {file = "tiktoken-0.7.0.tar.gz", hash = "sha256:1077266e949c24e0291f6c350433c6f0971365ece2b173a23bc3b9f9defef6b6"}, ] [package.dependencies] @@ -5389,4 +5392,4 @@ 
loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "7aa1485db323176c7b372efd3483d060c469d18fdf0c6ed172bb3a82d4ab238b" +content-hash = "032d9be0951a4048eb10716e492ea1f91e978ceb1ee24e5d13f3251f37cefb0a" diff --git a/pyproject.toml b/pyproject.toml index 105591208..222ee3b8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "griptape" -version = "0.25.0" +version = "0.25.1" description = "Modular Python framework for LLM workflows, tools, memory, and data." authors = ["Griptape "] license = "Apache 2.0" @@ -18,7 +18,7 @@ attrs = ">=22" jinja2 = ">=3.1.3" marshmallow = ">=3" marshmallow-enum = ">=1.5" -tiktoken = ">=0.3" +tiktoken = ">=0.7" rich = ">=13" schema = ">=0.7" pyyaml = ">=6" diff --git a/tests/mocks/mock_event.py b/tests/mocks/mock_event.py index 651cf3ece..2b9d9ade3 100644 --- a/tests/mocks/mock_event.py +++ b/tests/mocks/mock_event.py @@ -3,4 +3,4 @@ class MockEvent(BaseEvent): def to_dict(self) -> dict: - return {"timestamp": self.timestamp} + return {"timestamp": self.timestamp, "id": self.id} diff --git a/tests/mocks/mock_event_listener_driver.py b/tests/mocks/mock_event_listener_driver.py index 3e0c173ca..dd54eeb73 100644 --- a/tests/mocks/mock_event_listener_driver.py +++ b/tests/mocks/mock_event_listener_driver.py @@ -6,4 +6,7 @@ @define class MockEventListenerDriver(BaseEventListenerDriver): def try_publish_event_payload(self, event_payload: dict) -> None: - ... + pass + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + pass diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index 60eeed091..a6df9330c 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -19,7 +19,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -72,7 +72,7 @@ def test_to_dict(self, config): "prompt_driver": { "base_url": None, "type": "OpenAiChatPromptDriver", - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -98,7 +98,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -113,7 +113,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -129,7 +129,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, diff --git a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py index e0a9e7b7c..706831d67 100644 --- a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py @@ -29,3 +29,6 @@ def test_init(self, driver): def test_try_publish_event_payload(self, driver): driver.try_publish_event_payload(MockEvent().to_dict()) + + def test_try_publish_event_payload_batch(self, driver): + 
driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)])
diff --git a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py
index cd50ac82d..9a5fe9ec0 100644
--- a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py
+++ b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py
@@ -23,3 +23,6 @@ def test_init(self, driver):
 
     def test_try_publish_event_payload(self, driver):
         driver.try_publish_event_payload(MockEvent().to_dict())
+
+    def test_try_publish_event_payload_batch(self, driver):
+        driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)])
diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py
new file mode 100644
index 000000000..6d33dd2a0
--- /dev/null
+++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py
@@ -0,0 +1,25 @@
+from tests.mocks.mock_event import MockEvent
+from tests.mocks.mock_event_listener_driver import MockEventListenerDriver
+
+
+class TestBaseEventListenerDriver:
+    def test__safe_try_publish_event(self):
+        driver = MockEventListenerDriver(batched=False)
+
+        for _ in range(4):
+            driver._safe_try_publish_event(MockEvent().to_dict(), flush=False)
+        assert len(driver.batch) == 0
+
+    def test__safe_try_publish_event_batch(self):
+        driver = MockEventListenerDriver(batched=True)
+
+        for _ in range(3):
+            driver._safe_try_publish_event(MockEvent().to_dict(), flush=False)
+        assert len(driver.batch) == 3
+
+    def test__safe_try_publish_event_batch_flush(self):
+        driver = MockEventListenerDriver(batched=True)
+
+        for _ in range(3):
+            driver._safe_try_publish_event(MockEvent().to_dict(), flush=True)
+        assert len(driver.batch) == 0
diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py
index 51f29ff71..d27f09ec8 100644
--- a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py
+++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py
@@ -39,6 +39,17 @@ def test_try_publish_event_payload(self, mock_post, driver):
             headers={"Authorization": "Bearer foo bar"},
         )
 
+    def test_try_publish_event_payload_batch(self, mock_post, driver):
+        event_payloads = [MockEvent().to_dict() for _ in range(3)]
+
+        driver.try_publish_event_payload_batch(event_payloads)
+
+        mock_post.assert_called_once_with(
+            url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events",
+            json=event_payloads,
+            headers={"Authorization": "Bearer foo bar"},
+        )
+
     def test_no_structure_run_id(self):
         with pytest.raises(ValueError):
             GriptapeCloudEventListenerDriver(api_key="foo bar")
diff --git a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py
index f3f872c0a..50021cbe3 100644
--- a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py
+++ b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py
@@ -23,3 +23,14 @@ def test_try_publish_event_payload(self, mock_post):
         mock_post.assert_called_once_with(
             url="foo bar", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"}
         )
+
+    def test_try_publish_event_payload_batch(self, mock_post):
+        driver = WebhookEventListenerDriver(webhook_url="foo bar", headers={"Authorization": "Bearer foo bar"})
+
+        event_payloads = [MockEvent().to_dict() for _ in range(3)]
+
+        driver.try_publish_event_payload_batch(event_payloads)
+
+        mock_post.assert_called_once_with(
+            url="foo bar", json=event_payloads, headers={"Authorization": "Bearer foo bar"}
+        )
diff --git a/tests/unit/drivers/vector/test_redis_vector_store_driver.py b/tests/unit/drivers/vector/test_redis_vector_store_driver.py
index b5dfa2832..3c98180e7 100644
--- a/tests/unit/drivers/vector/test_redis_vector_store_driver.py
+++ b/tests/unit/drivers/vector/test_redis_vector_store_driver.py
@@ -1,3 +1,4 @@
+from unittest.mock import MagicMock
 import pytest
 import redis
 from tests.mocks.mock_embedding_driver import MockEmbeddingDriver
@@ -6,19 +7,21 @@ class TestRedisVectorStorageDriver:
     @pytest.fixture(autouse=True)
-    def mock_redis(self, mocker):
-        fake_hgetall_response = {b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@", b"metadata": b'{"foo": "bar"}'}
+    def mock_client(self, mocker):
+        return mocker.patch("redis.Redis").return_value
 
-        mocker.patch.object(redis.StrictRedis, "hset", return_value=None)
-        mocker.patch.object(redis.StrictRedis, "hgetall", return_value=fake_hgetall_response)
-        mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"some_namespace:some_vector_id"])
-
-        fake_redisearch = mocker.MagicMock()
-        fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[]))
-        fake_redisearch.info = mocker.MagicMock(side_effect=Exception("Index not found"))
-        fake_redisearch.create_index = mocker.MagicMock(return_value=None)
+    @pytest.fixture
+    def mock_keys(self, mock_client):
+        mock_client.keys.return_value = [b"some_vector_id"]
+        return mock_client.keys
 
-        mocker.patch.object(redis.StrictRedis, "ft", return_value=fake_redisearch)
+    @pytest.fixture
+    def mock_hgetall(self, mock_client):
+        mock_client.hgetall.return_value = {
+            b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@",
+            b"metadata": b'{"foo": "bar"}',
+        }
+        return mock_client.hgetall
 
     @pytest.fixture
     def driver(self):
@@ -26,23 +29,70 @@ def driver(self):
             host="localhost", port=6379, index="test_index", db=0, embedding_driver=MockEmbeddingDriver()
         )
 
+    @pytest.fixture
+    def mock_search(self, mock_client):
+        mock_client.ft.return_value.search.return_value.docs = [
+            MagicMock(
+                id="some_namespace:some_vector_id",
+                score="0.456198036671",
+                metadata='{"foo": "bar"}',
+                vec_string="[1.0, 2.0, 3.0]",
+            )
+        ]
+        return mock_client.ft.return_value.search
+
     def test_upsert_vector(self, driver):
         assert (
             driver.upsert_vector([1.0, 2.0, 3.0], vector_id="some_vector_id", namespace="some_namespace")
             == "some_vector_id"
         )
 
-    def test_load_entry(self, driver):
+    def test_load_entry(self, driver, mock_hgetall):
+        entry = driver.load_entry("some_vector_id")
+        mock_hgetall.assert_called_once_with("some_vector_id")
+        assert entry.id == "some_vector_id"
+        assert entry.vector == [1.0, 2.0, 3.0]
+        assert entry.meta == {"foo": "bar"}
+
+    def test_load_entry_with_namespace(self, driver, mock_hgetall):
         entry = driver.load_entry("some_vector_id", namespace="some_namespace")
+        mock_hgetall.assert_called_once_with("some_namespace:some_vector_id")
         assert entry.id == "some_vector_id"
         assert entry.vector == [1.0, 2.0, 3.0]
         assert entry.meta == {"foo": "bar"}
 
-    def test_load_entries(self, driver):
+    def test_load_entries(self, driver, mock_keys, mock_hgetall):
+        entries = driver.load_entries()
+        mock_keys.assert_called_once_with("*")
+        mock_hgetall.assert_called_once_with("some_vector_id")
+        assert len(entries) == 1
+        assert entries[0].vector == [1.0, 2.0, 3.0]
+        assert entries[0].meta == {"foo": "bar"}
+
+    def 
test_load_entries_with_namespace(self, driver, mock_keys, mock_hgetall): entries = driver.load_entries(namespace="some_namespace") + mock_keys.assert_called_once_with("some_namespace:*") + mock_hgetall.assert_called_once_with("some_namespace:some_vector_id") assert len(entries) == 1 assert entries[0].vector == [1.0, 2.0, 3.0] assert entries[0].meta == {"foo": "bar"} - def test_query(self, driver): - assert driver.query("some_vector_id") == [] + def test_query(self, driver, mock_search): + results = driver.query("Some query") + mock_search.assert_called_once() + assert len(results) == 1 + assert results[0].namespace == "some_namespace" + assert results[0].id == "some_vector_id" + assert results[0].score == 0.456198036671 + assert results[0].meta == {"foo": "bar"} + assert results[0].vector is None + + def test_query_with_include_vectors(self, driver, mock_search): + results = driver.query("Some query", include_vectors=True) + mock_search.assert_called_once() + assert len(results) == 1 + assert results[0].namespace == "some_namespace" + assert results[0].id == "some_vector_id" + assert results[0].score == 0.456198036671 + assert results[0].meta == {"foo": "bar"} + assert results[0].vector == [1.0, 2.0, 3.0] diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index fcc9688ed..2f32837e0 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -114,7 +114,7 @@ def event_handler(_: BaseEvent): event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent]) event_listener.publish_event(mock_event) - mock_event_listener_driver.publish_event.assert_called_once_with(mock_event) + mock_event_listener_driver.publish_event.assert_called_once_with(mock_event, flush=False) def test_publish_transformed_event(self): mock_event_listener_driver = Mock() @@ -127,4 +127,4 @@ def event_handler(event: BaseEvent): event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent]) event_listener.publish_event(mock_event) - mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}) + mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False) diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index a1e697346..8b231d85f 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -60,6 +60,16 @@ def test_with_no_action_input(self): assert json_dict[0].get("input") is None def test_no_actions(self): + valid_input = "Thought: need to test\n" "<|Response|>: test observation\n" "Answer: test output" + + task = ToolkitTask(tools=[MockTool()]) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(valid_input)) + json_dict = json.loads(subtask.actions_to_json()) + + assert len(json_dict) == 0 + + def test_empty_actions(self): valid_input = "Thought: need to test\n" "Actions: []\n" "<|Response|>: test observation\n" "Answer: test output" task = ToolkitTask(tools=[MockTool()]) @@ -67,7 +77,17 @@ def test_no_actions(self): subtask = task.add_subtask(ActionsSubtask(valid_input)) json_dict = json.loads(subtask.actions_to_json()) + assert len(json_dict) == 0 + + def test_invalid_actions(self): + invalid_input = ( + "Thought: need to test\n" "Actions: [{,{]\n" "<|Response|>: test observation\n" "Answer: test output" + ) + + task = ToolkitTask(tools=[MockTool()]) + 
Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(invalid_input)) + json_dict = json.loads(subtask.actions_to_json()) + assert json_dict[0]["name"] == "error" - assert json_dict[0]["input"] == { - "error": "Action JSON validation error: Array item count 0 is less than minimum count of 1." - } + assert "Action input parsing error" in json_dict[0]["input"]["error"]
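
For context, a minimal usage sketch of the fixed Redis query path. This assumes an index created with the updated `FT.CREATE ... SCHEMA namespace TAG ...` schema from the docs change above; the host and index values are placeholders:

```python
from griptape.drivers import OpenAiEmbeddingDriver, RedisVectorStoreDriver

driver = RedisVectorStoreDriver(
    host="localhost",
    port=6379,
    index="idx:griptape",
    embedding_driver=OpenAiEmbeddingDriver(),
)

# With the namespace fix, this emits
#   (@namespace:{docs})=>[KNN 5 @vector $vector as score]
# instead of a wildcard search, so results stay inside the "docs" namespace.
results = driver.query("release notes", namespace="docs", count=5)

for result in results:
    # meta, score, and (with include_vectors=True) vector are now populated.
    print(result.id, result.score, result.meta)
```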