From 164ffde0a920266a74139966d9b6472464728da1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 27 Aug 2024 11:22:51 -0700 Subject: [PATCH] Refactor Artifacts --- .ignore | 0 CHANGELOG.md | 2 + docs/griptape-framework/data/artifacts.md | 7 --- griptape/artifacts/__init__.py | 17 ++++-- griptape/artifacts/action_artifact.py | 7 ++- griptape/artifacts/audio_artifact.py | 18 +++++-- griptape/artifacts/base_artifact.py | 22 ++------ griptape/artifacts/base_system_artifact.py | 10 ++++ griptape/artifacts/blob_artifact.py | 26 ++++----- griptape/artifacts/boolean_artifact.py | 31 ----------- griptape/artifacts/csv_row_artifact.py | 19 ++++--- griptape/artifacts/error_artifact.py | 7 +-- griptape/artifacts/generic_artifact.py | 4 +- griptape/artifacts/image_artifact.py | 30 +++++++---- griptape/artifacts/info_artifact.py | 7 +-- griptape/artifacts/json_artifact.py | 23 ++++---- griptape/artifacts/list_artifact.py | 24 +++++---- griptape/artifacts/media_artifact.py | 53 ------------------- griptape/artifacts/table_artifact.py | 41 ++++++++++++++ griptape/artifacts/text_artifact.py | 17 +++--- .../amazon_bedrock_image_generation_driver.py | 12 ++--- ...ngface_pipeline_image_generation_driver.py | 4 +- .../leonardo_image_generation_driver.py | 12 +++-- .../openai_image_generation_driver.py | 3 +- griptape/loaders/csv_loader.py | 21 +++----- griptape/loaders/dataframe_loader.py | 20 +++---- griptape/loaders/sql_loader.py | 20 +++---- .../media_artifact_file_output_mixin.py | 6 +-- griptape/tasks/base_image_generation_task.py | 5 +- griptape/templates/memory/tool.j2 | 2 +- griptape/tools/query/tool.py | 4 +- griptape/tools/sql/tool.py | 7 +-- tests/unit/artifacts/test_action_artifact.py | 4 -- tests/unit/artifacts/test_audio_artifact.py | 14 ++--- tests/unit/artifacts/test_base_artifact.py | 8 ++- .../artifacts/test_base_media_artifact.py | 30 ----------- tests/unit/artifacts/test_blob_artifact.py | 17 ++---- tests/unit/artifacts/test_boolean_artifact.py | 37 ------------- tests/unit/artifacts/test_csv_row_artifact.py | 6 --- tests/unit/artifacts/test_image_artifact.py | 13 +++-- tests/unit/artifacts/test_json_artifact.py | 8 +-- ...table_diffusion_image_generation_driver.py | 4 +- ...st_azure_openai_image_generation_driver.py | 9 ++-- .../test_leonardo_image_generation_driver.py | 4 +- .../test_openai_image_generation_driver.py | 10 ++-- tests/unit/loaders/test_audio_loader.py | 6 +-- tests/unit/loaders/test_csv_loader.py | 50 ++++++++++------- tests/unit/loaders/test_dataframe_loader.py | 22 ++++---- tests/unit/loaders/test_image_loader.py | 19 +++---- tests/unit/loaders/test_sql_loader.py | 16 +++--- .../test_inpainting_image_generation_tool.py | 8 +-- .../test_outpainting_image_variation_tool.py | 8 +-- .../test_prompt_image_generation_tool.py | 5 +- tests/unit/tools/test_sql_tool.py | 2 +- tests/unit/tools/test_text_to_speech_tool.py | 3 +- .../test_variation_image_generation_tool.py | 4 +- 56 files changed, 335 insertions(+), 453 deletions(-) create mode 100644 .ignore create mode 100644 griptape/artifacts/base_system_artifact.py delete mode 100644 griptape/artifacts/boolean_artifact.py delete mode 100644 griptape/artifacts/media_artifact.py create mode 100644 griptape/artifacts/table_artifact.py delete mode 100644 tests/unit/artifacts/test_base_media_artifact.py delete mode 100644 tests/unit/artifacts/test_boolean_artifact.py diff --git a/.ignore b/.ignore new file mode 100644 index 0000000000..e69de29bb2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 555306f90d..04ec6e3ed1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Parameter `meta: dict` on `BaseEvent`. +- `TableArtifact` for storing CSV data. ### Changed - **BREAKING**: Parameter `driver` on `BaseConversationMemory` renamed to `conversation_memory_driver`. @@ -15,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], Optional[dict]]`. - **BREAKING**: `BaseConversationMemoryDriver.store` now takes `runs: list[Run]` and `metadata: Optional[dict]` as input. - **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`. +- **BREAKING**: `CsvLoader` now returns a `TableArtifact` instead of a `list[CsvRowArtifact]`. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. ### Fixed diff --git a/docs/griptape-framework/data/artifacts.md b/docs/griptape-framework/data/artifacts.md index 8c4da02b32..634496c9a3 100644 --- a/docs/griptape-framework/data/artifacts.md +++ b/docs/griptape-framework/data/artifacts.md @@ -46,13 +46,6 @@ An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used An [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md) allows the Framework to interact with audio content. An Audio Artifact includes binary audio content as well as metadata like format, duration, and prompt and model information for audio returned generative models. It inherits from [BlobArtifact](#blob). -## Boolean - -A [BooleanArtifact](../../reference/griptape/artifacts/boolean_artifact.md) is used for passing boolean values around the framework. - -!!! info - Any object passed on init to `BooleanArtifact` will be coerced into a `bool` type. This might lead to unintended behavior: `BooleanArtifact("False").value is True`. Use [BooleanArtifact.parse_bool](../../reference/griptape/artifacts/boolean_artifact.md#griptape.artifacts.boolean_artifact.BooleanArtifact.parse_bool) to convert case-insensitive string literal values `"True"` and `"False"` into a `BooleanArtifact`: `BooleanArtifact.parse_bool("False").value is False`. - ## Generic A [GenericArtifact](../../reference/griptape/artifacts/generic_artifact.md) can be used as an escape hatch for passing any type of data around the framework. diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index f39bfea8d0..fae411d3a4 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -1,32 +1,39 @@ from .base_artifact import BaseArtifact + +from .base_system_artifact import BaseSystemArtifact from .error_artifact import ErrorArtifact from .info_artifact import InfoArtifact + from .text_artifact import TextArtifact from .json_artifact import JsonArtifact -from .blob_artifact import BlobArtifact -from .boolean_artifact import BooleanArtifact from .csv_row_artifact import CsvRowArtifact +from .table_artifact import TableArtifact + from .list_artifact import ListArtifact -from .media_artifact import MediaArtifact + +from .blob_artifact import BlobArtifact + from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact + from .action_artifact import ActionArtifact + from .generic_artifact import GenericArtifact __all__ = [ "BaseArtifact", + "BaseSystemArtifact", "ErrorArtifact", "InfoArtifact", "TextArtifact", "JsonArtifact", "BlobArtifact", - "BooleanArtifact", "CsvRowArtifact", "ListArtifact", - "MediaArtifact", "ImageArtifact", "AudioArtifact", "ActionArtifact", "GenericArtifact", + "TableArtifact", ] diff --git a/griptape/artifacts/action_artifact.py b/griptape/artifacts/action_artifact.py index a10653078a..8c438c6d98 100644 --- a/griptape/artifacts/action_artifact.py +++ b/griptape/artifacts/action_artifact.py @@ -5,15 +5,14 @@ from attrs import define, field from griptape.artifacts import BaseArtifact -from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.common import ToolAction @define() -class ActionArtifact(BaseArtifact, SerializableMixin): +class ActionArtifact(BaseArtifact): value: ToolAction = field(metadata={"serializable": True}) - def __add__(self, other: BaseArtifact) -> ActionArtifact: - raise NotImplementedError + def to_text(self) -> str: + return str(self.value) diff --git a/griptape/artifacts/audio_artifact.py b/griptape/artifacts/audio_artifact.py index 3dc67fa366..c11db6c46a 100644 --- a/griptape/artifacts/audio_artifact.py +++ b/griptape/artifacts/audio_artifact.py @@ -1,12 +1,20 @@ from __future__ import annotations -from attrs import define +from attrs import define, field -from griptape.artifacts import MediaArtifact +from griptape.artifacts import BaseArtifact @define -class AudioArtifact(MediaArtifact): - """AudioArtifact is a type of MediaArtifact representing audio.""" +class AudioArtifact(BaseArtifact): + """AudioArtifact is a type of Artifact representing audio.""" - media_type: str = "audio" + value: bytes = field(metadata={"serializable": True}) + format: str = field(kw_only=True, metadata={"serializable": True}) + + @property + def mime_type(self) -> str: + return f"audio/{self.format}" + + def to_text(self) -> str: + return f"Audio, format: {self.format}, size: {len(self.value)} bytes" diff --git a/griptape/artifacts/base_artifact.py b/griptape/artifacts/base_artifact.py index d1e0d34f4a..0e967dafc9 100644 --- a/griptape/artifacts/base_artifact.py +++ b/griptape/artifacts/base_artifact.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import uuid from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional @@ -25,22 +24,6 @@ class BaseArtifact(SerializableMixin, ABC): ) value: Any = field() - @classmethod - def value_to_bytes(cls, value: Any) -> bytes: - if isinstance(value, bytes): - return value - else: - return str(value).encode() - - @classmethod - def value_to_dict(cls, value: Any) -> dict: - dict_value = value if isinstance(value, dict) else json.loads(value) - - return dict(dict_value.items()) - - def to_text(self) -> str: - return str(self.value) - def __str__(self) -> str: return self.to_text() @@ -50,5 +33,8 @@ def __bool__(self) -> bool: def __len__(self) -> int: return len(self.value) + def to_bytes(self) -> bytes: + return self.to_text().encode("utf-8") + @abstractmethod - def __add__(self, other: BaseArtifact) -> BaseArtifact: ... + def to_text(self) -> str: ... diff --git a/griptape/artifacts/base_system_artifact.py b/griptape/artifacts/base_system_artifact.py new file mode 100644 index 0000000000..c71eff57c0 --- /dev/null +++ b/griptape/artifacts/base_system_artifact.py @@ -0,0 +1,10 @@ +from abc import ABC + +from griptape.artifacts import BaseArtifact + + +class BaseSystemArtifact(BaseArtifact, ABC): + """Base class for Artifacts specific to Griptape.""" + + def to_text(self) -> str: + return self.value diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 0c0dcc1223..75940eab82 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -1,26 +1,28 @@ from __future__ import annotations -import os.path -from typing import Optional +from typing import Any -from attrs import define, field +from attrs import Converter, define, field from griptape.artifacts import BaseArtifact @define class BlobArtifact(BaseArtifact): - value: bytes = field(converter=BaseArtifact.value_to_bytes, metadata={"serializable": True}) - dir_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + value: bytes = field( + converter=Converter(lambda value: BlobArtifact.value_to_bytes(value)), + metadata={"serializable": True}, + ) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) - - def __add__(self, other: BaseArtifact) -> BlobArtifact: - return BlobArtifact(self.value + other.value, name=self.name) - - @property - def full_path(self) -> str: - return os.path.join(self.dir_name, self.name) if self.dir_name else self.name + media_type: str = field(default="application/octet-stream", kw_only=True) + + @classmethod + def value_to_bytes(cls, value: Any) -> bytes: + if isinstance(value, bytes): + return value + else: + return str(value).encode() def to_text(self) -> str: return self.value.decode(encoding=self.encoding, errors=self.encoding_error_handler) diff --git a/griptape/artifacts/boolean_artifact.py b/griptape/artifacts/boolean_artifact.py deleted file mode 100644 index 5bcdfac9ba..0000000000 --- a/griptape/artifacts/boolean_artifact.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from typing import Union - -from attrs import define, field - -from griptape.artifacts import BaseArtifact - - -@define -class BooleanArtifact(BaseArtifact): - value: bool = field(converter=bool, metadata={"serializable": True}) - - @classmethod - def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact: # noqa: FBT001 - """Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing.""" - if value is not None: - if isinstance(value, str): - if value.lower() == "true": - return BooleanArtifact(True) # noqa: FBT003 - elif value.lower() == "false": - return BooleanArtifact(False) # noqa: FBT003 - elif isinstance(value, bool): - return BooleanArtifact(value) - raise ValueError(f"Cannot convert '{value}' to BooleanArtifact") - - def __add__(self, other: BaseArtifact) -> BooleanArtifact: - raise ValueError("Cannot add BooleanArtifact with other artifacts") - - def __eq__(self, value: object) -> bool: - return self.value is value diff --git a/griptape/artifacts/csv_row_artifact.py b/griptape/artifacts/csv_row_artifact.py index c4347099e8..c36158ecdd 100644 --- a/griptape/artifacts/csv_row_artifact.py +++ b/griptape/artifacts/csv_row_artifact.py @@ -2,23 +2,30 @@ import csv import io +import json +from typing import Any -from attrs import define, field +from attrs import Converter, define, field -from griptape.artifacts import BaseArtifact, TextArtifact +from griptape.artifacts import TextArtifact @define class CsvRowArtifact(TextArtifact): - value: dict[str, str] = field(converter=BaseArtifact.value_to_dict, metadata={"serializable": True}) + value: dict[str, str] = field( + converter=Converter(lambda value: CsvRowArtifact.value_to_dict(value)), metadata={"serializable": True} + ) delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True}) - def __add__(self, other: BaseArtifact) -> CsvRowArtifact: - return CsvRowArtifact(self.value | other.value) - def __bool__(self) -> bool: return len(self) > 0 + @classmethod + def value_to_dict(cls, value: Any) -> dict: + dict_value = value if isinstance(value, dict) else json.loads(value) + + return dict(dict_value.items()) + def to_text(self) -> str: with io.StringIO() as csvfile: writer = csv.DictWriter( diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index d065d754b2..13ba4497c6 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -4,13 +4,10 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact @define -class ErrorArtifact(BaseArtifact): +class ErrorArtifact(BaseSystemArtifact): value: str = field(converter=str, metadata={"serializable": True}) exception: Optional[Exception] = field(default=None, kw_only=True, metadata={"serializable": False}) - - def __add__(self, other: BaseArtifact) -> ErrorArtifact: - return ErrorArtifact(self.value + other.value) diff --git a/griptape/artifacts/generic_artifact.py b/griptape/artifacts/generic_artifact.py index 8e0b7e38c2..be7fee7f93 100644 --- a/griptape/artifacts/generic_artifact.py +++ b/griptape/artifacts/generic_artifact.py @@ -11,5 +11,5 @@ class GenericArtifact(BaseArtifact): value: Any = field(metadata={"serializable": True}) - def __add__(self, other: BaseArtifact) -> BaseArtifact: - raise NotImplementedError + def to_text(self) -> str: + return str(self.value) diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index e963b38818..1705f030c3 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -1,23 +1,35 @@ from __future__ import annotations +import base64 + from attrs import define, field -from griptape.artifacts import MediaArtifact +from griptape.artifacts import BaseArtifact @define -class ImageArtifact(MediaArtifact): - """ImageArtifact is a type of MediaArtifact representing an image. +class ImageArtifact(BaseArtifact): + """ImageArtifact is a type of Artifact representing an image. Attributes: value: Raw bytes representing media data. - media_type: The type of media, defaults to "image". - format: The format of the media, like png, jpeg, or gif. - name: Artifact name, generated using creation time and a random string. - model: Optionally specify the model used to generate the media. - prompt: Optionally specify the prompt used to generate the media. + format: The format of the media, like png, jpeg, or gif. Default is png. + width: The width of the image in pixels. + height: The height of the image in pixels. """ - media_type: str = "image" + value: bytes = field(metadata={"serializable": True}) + format: str = field(default="png", kw_only=True, metadata={"serializable": True}) width: int = field(kw_only=True, metadata={"serializable": True}) height: int = field(kw_only=True, metadata={"serializable": True}) + + @property + def base64(self) -> str: + return base64.b64encode(self.value).decode("utf-8") + + @property + def mime_type(self) -> str: + return f"image/{self.format}" + + def to_text(self) -> str: + return self.base64 diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 26fe6366bc..19b67f7f06 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -2,12 +2,9 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact @define -class InfoArtifact(BaseArtifact): +class InfoArtifact(BaseSystemArtifact): value: str = field(converter=str, metadata={"serializable": True}) - - def __add__(self, other: BaseArtifact) -> InfoArtifact: - return InfoArtifact(self.value + other.value) diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index b292879a9c..02bef7611f 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -1,21 +1,24 @@ from __future__ import annotations import json -from typing import Union +from typing import Any, Union -from attrs import define, field +from attrs import Converter, define, field -from griptape.artifacts import BaseArtifact - -Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] +from griptape.artifacts.text_artifact import TextArtifact @define -class JsonArtifact(BaseArtifact): - value: Json = field(converter=lambda v: json.loads(json.dumps(v)), metadata={"serializable": True}) +class JsonArtifact(TextArtifact): + Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] + + value: Json = field( + converter=Converter(lambda value: JsonArtifact.value_to_dict(value)), metadata={"serializable": True} + ) + + @classmethod + def value_to_dict(cls, value: Any) -> dict: + return json.loads(json.dumps(value)) def to_text(self) -> str: return json.dumps(self.value) - - def __add__(self, other: BaseArtifact) -> JsonArtifact: - raise NotImplementedError diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 298f29c6ad..9bf6f0978d 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -4,18 +4,29 @@ from attrs import Attribute, define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact if TYPE_CHECKING: from collections.abc import Sequence + from griptape.artifacts import BaseArtifact + @define -class ListArtifact(BaseArtifact): +class ListArtifact(BaseSystemArtifact): value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True}) item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + def __getitem__(self, key: int) -> BaseArtifact: + return self.value[key] + + def __bool__(self) -> bool: + return len(self) > 0 + + def __add__(self, other: BaseArtifact) -> ListArtifact: + return ListArtifact(self.value + other.value) + @value.validator # pyright: ignore[reportAttributeAccessIssue] def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None: if self.validate_uniform_types and len(value) > 0: @@ -31,18 +42,9 @@ def child_type(self) -> Optional[type]: else: return None - def __getitem__(self, key: int) -> BaseArtifact: - return self.value[key] - - def __bool__(self) -> bool: - return len(self) > 0 - def to_text(self) -> str: return self.item_separator.join([v.to_text() for v in self.value]) - def __add__(self, other: BaseArtifact) -> BaseArtifact: - return ListArtifact(self.value + other.value) - def is_type(self, target_type: type) -> bool: if self.value: return isinstance(self.value[0], target_type) diff --git a/griptape/artifacts/media_artifact.py b/griptape/artifacts/media_artifact.py deleted file mode 100644 index a57217fc74..0000000000 --- a/griptape/artifacts/media_artifact.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import base64 -import random -import string -import time -from typing import Optional - -from attrs import define, field - -from griptape.artifacts import BlobArtifact - - -@define -class MediaArtifact(BlobArtifact): - """MediaArtifact is a type of BlobArtifact that represents media (image, audio, video, etc.) and can be extended to support a specific media type. - - Attributes: - value: Raw bytes representing media data. - media_type: The type of media, like image, audio, or video. - format: The format of the media, like png, wav, or mp4. - name: Artifact name, generated using creation time and a random string. - model: Optionally specify the model used to generate the media. - prompt: Optionally specify the prompt used to generate the media. - """ - - media_type: str = field(default="media", kw_only=True, metadata={"serializable": True}) - format: str = field(kw_only=True, metadata={"serializable": True}) - model: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - prompt: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - - def __attrs_post_init__(self) -> None: - # Generating the name string requires attributes set by child classes. - # This waits until all attributes are available before generating a name. - if self.name == self.id: - self.name = self.make_name() - - @property - def mime_type(self) -> str: - return f"{self.media_type}/{self.format}" - - @property - def base64(self) -> str: - return base64.b64encode(self.value).decode("utf-8") - - def to_text(self) -> str: - return f"Media, type: {self.mime_type}, size: {len(self.value)} bytes" - - def make_name(self) -> str: - entropy = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) - fmt_time = time.strftime("%y%m%d%H%M%S", time.localtime()) - - return f"{self.media_type}_artifact_{fmt_time}_{entropy}.{self.format}" diff --git a/griptape/artifacts/table_artifact.py b/griptape/artifacts/table_artifact.py new file mode 100644 index 0000000000..d3ca003fa6 --- /dev/null +++ b/griptape/artifacts/table_artifact.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import csv +import io +from typing import TYPE_CHECKING, Optional + +from attrs import define, field + +from griptape.artifacts.text_artifact import TextArtifact + +if TYPE_CHECKING: + from collections.abc import Sequence + + +@define +class TableArtifact(TextArtifact): + value: list[dict] = field(factory=list, metadata={"serializable": True}) + delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True}) + fieldnames: Optional[Sequence[str]] = field(factory=list, metadata={"serializable": True}) + quoting: int = field(default=csv.QUOTE_MINIMAL, kw_only=True, metadata={"serializable": True}) + line_terminator: str = field(default="\n", kw_only=True, metadata={"serializable": True}) + + def __bool__(self) -> bool: + return len(self.value) > 0 + + def to_text(self) -> str: + with io.StringIO() as csvfile: + fieldnames = (self.value[0].keys() if self.value else []) if self.fieldnames is None else self.fieldnames + + writer = csv.DictWriter( + csvfile, + fieldnames=fieldnames, + quoting=self.quoting, + delimiter=self.delimiter, + lineterminator=self.line_terminator, + ) + + writer.writeheader() + writer.writerows(self.value) + + return csvfile.getvalue().strip() diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 752f666155..5ccca96688 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -18,16 +18,22 @@ class TextArtifact(BaseArtifact): encoding_error_handler: str = field(default="strict", kw_only=True) _embedding: list[float] = field(factory=list, kw_only=True) - @property - def embedding(self) -> Optional[list[float]]: - return None if len(self._embedding) == 0 else self._embedding - def __add__(self, other: BaseArtifact) -> TextArtifact: return TextArtifact(self.value + other.value) def __bool__(self) -> bool: return bool(self.value.strip()) + @property + def embedding(self) -> Optional[list[float]]: + return None if len(self._embedding) == 0 else self._embedding + + def to_text(self) -> str: + return self.value + + def to_bytes(self) -> bytes: + return str(self.value).encode(encoding=self.encoding, errors=self.encoding_error_handler) + def generate_embedding(self, driver: BaseEmbeddingDriver) -> Optional[list[float]]: self._embedding.clear() self._embedding.extend(driver.embed_string(str(self.value))) @@ -36,6 +42,3 @@ def generate_embedding(self, driver: BaseEmbeddingDriver) -> Optional[list[float def token_count(self, tokenizer: BaseTokenizer) -> int: return tokenizer.count_tokens(str(self.value)) - - def to_bytes(self) -> bytes: - return str(self.value).encode(encoding=self.encoding, errors=self.encoding_error_handler) diff --git a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py index 7106c81926..4db302f6f6 100644 --- a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py +++ b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py @@ -46,12 +46,11 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=self.image_width, height=self.image_height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_variation( @@ -70,12 +69,11 @@ def try_image_variation( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_inpainting( @@ -96,12 +94,11 @@ def try_image_inpainting( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_outpainting( @@ -122,12 +119,11 @@ def try_image_outpainting( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def _make_request(self, request: dict) -> bytes: diff --git a/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py b/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py index 46dbcd331c..b89df1c4b5 100644 --- a/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py +++ b/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py @@ -44,7 +44,7 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ format=self.output_format.lower(), height=output_image.height, width=output_image.width, - prompt=prompt, + meta={"prompt": prompt}, ) def try_image_variation( @@ -76,7 +76,7 @@ def try_image_variation( format=self.output_format.lower(), height=output_image.height, width=output_image.width, - prompt=prompt, + meta={"prompt": prompt}, ) def try_image_inpainting( diff --git a/griptape/drivers/image_generation/leonardo_image_generation_driver.py b/griptape/drivers/image_generation/leonardo_image_generation_driver.py index e32dbb4c72..db89244bf5 100644 --- a/griptape/drivers/image_generation/leonardo_image_generation_driver.py +++ b/griptape/drivers/image_generation/leonardo_image_generation_driver.py @@ -60,8 +60,10 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ format="png", width=self.image_width, height=self.image_height, - model=self.model, - prompt=", ".join(prompts), + meta={ + "model": self.model, + "prompt": ", ".join(prompts), + }, ) def try_image_variation( @@ -87,8 +89,10 @@ def try_image_variation( format="png", width=self.image_width, height=self.image_height, - model=self.model, - prompt=", ".join(prompts), + meta={ + "model": self.model, + "prompt": ", ".join(prompts), + }, ) def try_image_outpainting( diff --git a/griptape/drivers/image_generation/openai_image_generation_driver.py b/griptape/drivers/image_generation/openai_image_generation_driver.py index 0ee50a1e2c..bf77ac300b 100644 --- a/griptape/drivers/image_generation/openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/openai_image_generation_driver.py @@ -151,6 +151,5 @@ def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageA format="png", width=image_dimensions[0], height=image_dimensions[1], - model=self.model, - prompt=prompt, + meta={"model": self.model, "prompt": prompt}, ) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 14dfe3e4a6..b54f0d4be2 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -6,7 +6,7 @@ from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TableArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -19,33 +19,28 @@ class CsvLoader(BaseLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: - artifacts = [] - + def load(self, source: bytes | str, *args, **kwargs) -> TableArtifact: if isinstance(source, bytes): source = source.decode(encoding=self.encoding) elif isinstance(source, (bytearray, memoryview)): raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [CsvRowArtifact(row) for row in reader] - if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) + artifact = TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames) - for chunk in chunks: - artifacts.append(chunk) + if self.embedding_driver: + artifact.generate_embedding(self.embedding_driver) - return artifacts + return artifact def load_collection( self, sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[CsvRowArtifact]]: + ) -> dict[str, TableArtifact]: return cast( - dict[str, list[CsvRowArtifact]], + dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 0b1ae14484..5fbbd51d16 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TableArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -19,22 +19,16 @@ class DataFrameLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: - artifacts = [] - - chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")] + def load(self, source: DataFrame, *args, **kwargs) -> TableArtifact: + artifact = TableArtifact(list(source.to_dict(orient="records"))) if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) + artifact.generate_embedding(self.embedding_driver) - return artifacts + return artifact - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, TableArtifact]: + return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index e4522796ff..d723a5a9fb 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TableArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -16,20 +16,14 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: str, *args, **kwargs) -> TableArtifact: rows = self.sql_driver.execute_query(source) - artifacts = [] - - chunks = [CsvRowArtifact(row.cells) for row in rows] if rows else [] + artifact = TableArtifact([row.cells for row in rows] if rows else []) if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) + artifact.generate_embedding(self.embedding_driver) - return artifacts + return artifact - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, TableArtifact]: + return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/mixins/media_artifact_file_output_mixin.py b/griptape/mixins/media_artifact_file_output_mixin.py index 9b9f349115..8dd4efdb09 100644 --- a/griptape/mixins/media_artifact_file_output_mixin.py +++ b/griptape/mixins/media_artifact_file_output_mixin.py @@ -7,7 +7,7 @@ from attrs import Attribute, define, field if TYPE_CHECKING: - from griptape.artifacts import BlobArtifact + from griptape.artifacts import BaseArtifact @define(slots=False) @@ -31,7 +31,7 @@ def validate_output_file(self, _: Attribute, output_file: str) -> None: if self.output_dir: raise ValueError("Can't have both output_dir and output_file specified.") - def _write_to_file(self, artifact: BlobArtifact) -> None: + def _write_to_file(self, artifact: BaseArtifact) -> None: if self.output_file: outfile = self.output_file elif self.output_dir: @@ -42,4 +42,4 @@ def _write_to_file(self, artifact: BlobArtifact) -> None: if os.path.dirname(outfile): os.makedirs(os.path.dirname(outfile), exist_ok=True) - Path(outfile).write_bytes(artifact.value) + Path(outfile).write_bytes(artifact.to_bytes()) diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index f0c1f0e7e5..3c87c2f13f 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -15,8 +15,7 @@ from griptape.tasks import BaseTask if TYPE_CHECKING: - from griptape.artifacts import MediaArtifact - + from griptape.artifacts import ImageArtifact logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -64,6 +63,6 @@ def all_negative_rulesets(self) -> list[Ruleset]: return task_rulesets - def _read_from_file(self, path: str) -> MediaArtifact: + def _read_from_file(self, path: str) -> ImageArtifact: logger.info("Reading image from %s", os.path.abspath(path)) return ImageLoader().load(Path(path).read_bytes()) diff --git a/griptape/templates/memory/tool.j2 b/griptape/templates/memory/tool.j2 index d58b26f5d7..10214b91d0 100644 --- a/griptape/templates/memory/tool.j2 +++ b/griptape/templates/memory/tool.j2 @@ -1 +1 @@ -Output of "{{ tool_name }}.{{ activity_name }}" was stored in memory with memory_name "{{ memory_name }}" and artifact_namespace "{{ artifact_namespace }}" \ No newline at end of file +Output of "{{ tool_name }}.{{ activity_name }}" was stored in memory with memory_name "{{ memory_name }}" and artifact_namespace "{{ artifact_namespace }}" diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index 0089970e91..0274e7940d 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from schema import Literal, Or, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact from griptape.configs import Defaults from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( @@ -60,7 +60,7 @@ class QueryTool(BaseTool, RuleMixin): ), }, ) - def query(self, params: dict) -> BaseArtifact: + def query(self, params: dict) -> ListArtifact | ErrorArtifact: query = params["values"]["query"] content = params["values"]["content"] diff --git a/griptape/tools/sql/tool.py b/griptape/tools/sql/tool.py index a84bb87bed..aca41b4ac6 100644 --- a/griptape/tools/sql/tool.py +++ b/griptape/tools/sql/tool.py @@ -5,11 +5,12 @@ from attrs import define, field from schema import Schema -from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact +from griptape.artifacts import ErrorArtifact, InfoArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity if TYPE_CHECKING: + from griptape.artifacts import TableArtifact from griptape.loaders import SqlLoader @@ -43,7 +44,7 @@ def table_schema(self) -> Optional[str]: "schema": Schema({"sql_query": str}), }, ) - def execute_query(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: + def execute_query(self, params: dict) -> TableArtifact | InfoArtifact | ErrorArtifact: try: query = params["values"]["sql_query"] rows = self.sql_loader.load(query) @@ -51,6 +52,6 @@ def execute_query(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArti return ErrorArtifact(f"error executing query: {e}") if len(rows) > 0: - return ListArtifact(rows) + return rows else: return InfoArtifact("No results found") diff --git a/tests/unit/artifacts/test_action_artifact.py b/tests/unit/artifacts/test_action_artifact.py index 2530ed8c36..b7180b1c3d 100644 --- a/tests/unit/artifacts/test_action_artifact.py +++ b/tests/unit/artifacts/test_action_artifact.py @@ -11,10 +11,6 @@ class TestActionArtifact: def action(self) -> ToolAction: return ToolAction(tag="TestTag", name="TestName", path="TestPath", input={"foo": "bar"}) - def test___add__(self, action): - with pytest.raises(NotImplementedError): - ActionArtifact(action) + ActionArtifact(action) - def test_to_text(self, action): assert ActionArtifact(action).to_text() == json.dumps(action.to_dict()) diff --git a/tests/unit/artifacts/test_audio_artifact.py b/tests/unit/artifacts/test_audio_artifact.py index 6d44c05b3e..aab6af6305 100644 --- a/tests/unit/artifacts/test_audio_artifact.py +++ b/tests/unit/artifacts/test_audio_artifact.py @@ -6,20 +6,22 @@ class TestAudioArtifact: @pytest.fixture() def audio_artifact(self): - return AudioArtifact(value=b"some binary audio data", format="pcm", model="provider/model", prompt="two words") + return AudioArtifact( + value=b"some binary audio data", format="pcm", meta={"model": "provider/model", "prompt": "two words"} + ) def test_mime_type(self, audio_artifact: AudioArtifact): assert audio_artifact.mime_type == "audio/pcm" def test_to_text(self, audio_artifact: AudioArtifact): - assert audio_artifact.to_text() == "Media, type: audio/pcm, size: 22 bytes" + assert audio_artifact.to_text() == "Audio, format: pcm, size: 22 bytes" def test_to_dict(self, audio_artifact: AudioArtifact): audio_dict = audio_artifact.to_dict() assert audio_dict["format"] == "pcm" - assert audio_dict["model"] == "provider/model" - assert audio_dict["prompt"] == "two words" + assert audio_dict["meta"]["model"] == "provider/model" + assert audio_dict["meta"]["prompt"] == "two words" assert audio_dict["value"] == "c29tZSBiaW5hcnkgYXVkaW8gZGF0YQ==" def test_deserialization(self, audio_artifact): @@ -31,5 +33,5 @@ def test_deserialization(self, audio_artifact): assert deserialized_artifact.value == b"some binary audio data" assert deserialized_artifact.mime_type == "audio/pcm" assert deserialized_artifact.format == "pcm" - assert deserialized_artifact.model == "provider/model" - assert deserialized_artifact.prompt == "two words" + assert deserialized_artifact.meta["model"] == "provider/model" + assert deserialized_artifact.meta["prompt"] == "two words" diff --git a/tests/unit/artifacts/test_base_artifact.py b/tests/unit/artifacts/test_base_artifact.py index 6cf8f4466f..e2060879f1 100644 --- a/tests/unit/artifacts/test_base_artifact.py +++ b/tests/unit/artifacts/test_base_artifact.py @@ -41,7 +41,7 @@ def test_list_artifact_from_dict(self): assert artifact.to_text() == "foobar" def test_blob_artifact_from_dict(self): - dict_value = {"type": "BlobArtifact", "value": b"Zm9vYmFy", "dir_name": "foo", "name": "bar"} + dict_value = {"type": "BlobArtifact", "value": b"Zm9vYmFy", "name": "bar"} artifact = BaseArtifact.from_dict(dict_value) assert isinstance(artifact, BlobArtifact) @@ -51,17 +51,15 @@ def test_image_artifact_from_dict(self): dict_value = { "type": "ImageArtifact", "value": b"aW1hZ2UgZGF0YQ==", - "dir_name": "foo", "format": "png", "width": 256, "height": 256, - "model": "test-model", - "prompt": "some prompt", + "meta": {"model": "test-model", "prompt": "some prompt"}, } artifact = BaseArtifact.from_dict(dict_value) assert isinstance(artifact, ImageArtifact) - assert artifact.to_text() == "Media, type: image/png, size: 10 bytes" + assert artifact.to_text() == "aW1hZ2UgZGF0YQ==" assert artifact.value == b"image data" def test_unsupported_from_dict(self): diff --git a/tests/unit/artifacts/test_base_media_artifact.py b/tests/unit/artifacts/test_base_media_artifact.py deleted file mode 100644 index c85d070fef..0000000000 --- a/tests/unit/artifacts/test_base_media_artifact.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -from attrs import define - -from griptape.artifacts import MediaArtifact - - -class TestMediaArtifact: - @define - class ImaginaryMediaArtifact(MediaArtifact): - media_type: str = "imagination" - - @pytest.fixture() - def media_artifact(self): - return self.ImaginaryMediaArtifact(value=b"some binary dream data", format="dream") - - def test_to_dict(self, media_artifact): - image_dict = media_artifact.to_dict() - - assert image_dict["format"] == "dream" - assert image_dict["value"] == "c29tZSBiaW5hcnkgZHJlYW0gZGF0YQ==" - - def test_name(self, media_artifact): - assert media_artifact.name.startswith("imagination_artifact") - assert media_artifact.name.endswith(".dream") - - def test_mime_type(self, media_artifact): - assert media_artifact.mime_type == "imagination/dream" - - def test_to_text(self, media_artifact): - assert media_artifact.to_text() == "Media, type: imagination/dream, size: 22 bytes" diff --git a/tests/unit/artifacts/test_blob_artifact.py b/tests/unit/artifacts/test_blob_artifact.py index 3d88d57934..d0f04b1e57 100644 --- a/tests/unit/artifacts/test_blob_artifact.py +++ b/tests/unit/artifacts/test_blob_artifact.py @@ -1,5 +1,4 @@ import base64 -import os import pytest @@ -30,32 +29,22 @@ def test_to_text_encoding_error_handler(self): ) def test_to_dict(self): - assert BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo").to_dict()["name"] == "foobar.txt" - - def test_full_path_with_path(self): - assert BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo").full_path == os.path.normpath( - "foo/foobar.txt" - ) - - def test_full_path_without_path(self): - assert BlobArtifact(b"foobar", name="foobar.txt").full_path == "foobar.txt" + assert BlobArtifact(b"foobar", name="foobar.txt").to_dict()["name"] == "foobar.txt" def test_serialization(self): - artifact = BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo") + artifact = BlobArtifact(b"foobar", name="foobar.txt") artifact_dict = artifact.to_dict() assert artifact_dict["name"] == "foobar.txt" - assert artifact_dict["dir_name"] == "foo" assert base64.b64decode(artifact_dict["value"]) == b"foobar" def test_deserialization(self): - artifact = BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo") + artifact = BlobArtifact(b"foobar", name="foobar.txt") artifact_dict = artifact.to_dict() deserialized_artifact = BaseArtifact.from_dict(artifact_dict) assert isinstance(deserialized_artifact, BlobArtifact) assert deserialized_artifact.name == "foobar.txt" - assert deserialized_artifact.dir_name == "foo" assert deserialized_artifact.value == b"foobar" def test_name(self): diff --git a/tests/unit/artifacts/test_boolean_artifact.py b/tests/unit/artifacts/test_boolean_artifact.py deleted file mode 100644 index 57bbf16622..0000000000 --- a/tests/unit/artifacts/test_boolean_artifact.py +++ /dev/null @@ -1,37 +0,0 @@ -# ruff: noqa: FBT003 -import pytest - -from griptape.artifacts import BooleanArtifact - - -class TestBooleanArtifact: - def test_parse_bool(self): - assert BooleanArtifact.parse_bool("true").value is True - assert BooleanArtifact.parse_bool("false").value is False - assert BooleanArtifact.parse_bool("True").value is True - assert BooleanArtifact.parse_bool("False").value is False - - with pytest.raises(ValueError): - BooleanArtifact.parse_bool("foo") - - with pytest.raises(ValueError): - BooleanArtifact.parse_bool(None) # pyright: ignore[reportArgumentType] - - assert BooleanArtifact.parse_bool(True).value is True - assert BooleanArtifact.parse_bool(False).value is False - - def test_add(self): - with pytest.raises(ValueError): - BooleanArtifact(True) + BooleanArtifact(True) # pyright: ignore[reportUnusedExpression] - - def test_value_type_conversion(self): - assert BooleanArtifact(1).value is True - assert BooleanArtifact(0).value is False - assert BooleanArtifact(True).value is True - assert BooleanArtifact(False).value is False - assert BooleanArtifact("true").value is True - assert BooleanArtifact("false").value is True - assert BooleanArtifact([1]).value is True - assert BooleanArtifact([]).value is False - assert BooleanArtifact(False).value is False - assert BooleanArtifact(True).value is True diff --git a/tests/unit/artifacts/test_csv_row_artifact.py b/tests/unit/artifacts/test_csv_row_artifact.py index 986ece4091..bed46064d8 100644 --- a/tests/unit/artifacts/test_csv_row_artifact.py +++ b/tests/unit/artifacts/test_csv_row_artifact.py @@ -7,12 +7,6 @@ def test_value_type_conversion(self): assert CsvRowArtifact({"foo": {"bar": "baz"}}).value == {"foo": {"bar": "baz"}} assert CsvRowArtifact('{"foo": "bar"}').value == {"foo": "bar"} - def test___add__(self): - assert (CsvRowArtifact({"test1": "foo"}) + CsvRowArtifact({"test2": "bar"})).value == { - "test1": "foo", - "test2": "bar", - } - def test_to_text(self): assert CsvRowArtifact({"test1": "foo|bar", "test2": 1}, delimiter="|").to_text() == '"foo|bar"|1' diff --git a/tests/unit/artifacts/test_image_artifact.py b/tests/unit/artifacts/test_image_artifact.py index a722ebd911..b048b3372e 100644 --- a/tests/unit/artifacts/test_image_artifact.py +++ b/tests/unit/artifacts/test_image_artifact.py @@ -11,12 +11,11 @@ def image_artifact(self): format="png", width=512, height=512, - model="openai/dalle2", - prompt="a cute cat", + meta={"model": "openai/dalle2", "prompt": "a cute cat"}, ) def test_to_text(self, image_artifact: ImageArtifact): - assert image_artifact.to_text() == "Media, type: image/png, size: 26 bytes" + assert image_artifact.to_text() == "c29tZSBiaW5hcnkgcG5nIGltYWdlIGRhdGE=" def test_to_dict(self, image_artifact: ImageArtifact): image_dict = image_artifact.to_dict() @@ -24,8 +23,8 @@ def test_to_dict(self, image_artifact: ImageArtifact): assert image_dict["format"] == "png" assert image_dict["width"] == 512 assert image_dict["height"] == 512 - assert image_dict["model"] == "openai/dalle2" - assert image_dict["prompt"] == "a cute cat" + assert image_dict["meta"]["model"] == "openai/dalle2" + assert image_dict["meta"]["prompt"] == "a cute cat" assert image_dict["value"] == "c29tZSBiaW5hcnkgcG5nIGltYWdlIGRhdGE=" def test_deserialization(self, image_artifact): @@ -39,5 +38,5 @@ def test_deserialization(self, image_artifact): assert deserialized_artifact.format == "png" assert deserialized_artifact.width == 512 assert deserialized_artifact.height == 512 - assert deserialized_artifact.model == "openai/dalle2" - assert deserialized_artifact.prompt == "a cute cat" + assert deserialized_artifact.meta["model"] == "openai/dalle2" + assert deserialized_artifact.meta["prompt"] == "a cute cat" diff --git a/tests/unit/artifacts/test_json_artifact.py b/tests/unit/artifacts/test_json_artifact.py index 06f5d6297d..f766635399 100644 --- a/tests/unit/artifacts/test_json_artifact.py +++ b/tests/unit/artifacts/test_json_artifact.py @@ -1,8 +1,6 @@ import json -import pytest - -from griptape.artifacts import JsonArtifact, TextArtifact +from griptape.artifacts import JsonArtifact class TestJsonArtifact: @@ -16,10 +14,6 @@ def test_value_type_conversion(self): assert JsonArtifact(None).value == json.loads(json.dumps(None)) assert JsonArtifact("foo").value == json.loads(json.dumps("foo")) - def test___add__(self): - with pytest.raises(NotImplementedError): - JsonArtifact({"foo": "bar"}) + TextArtifact("invalid json") - def test_to_text(self): assert JsonArtifact({"foo": "bar"}).to_text() == json.dumps({"foo": "bar"}) assert JsonArtifact({"foo": 1}).to_text() == json.dumps({"foo": 1}) diff --git a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py index 9aa4d3f4f2..05e669b661 100644 --- a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py @@ -60,5 +60,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "stability.stable-diffusion-xl-v1" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "stability.stable-diffusion-xl-v1" + assert image_artifact.meta["prompt"] == "test prompt" diff --git a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py index 268708b2b7..a727642111 100644 --- a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py @@ -28,7 +28,10 @@ def test_init(self, driver): def test_init_requires_endpoint(self): with pytest.raises(TypeError): AzureOpenAiImageGenerationDriver( - model="dall-e-3", client=Mock(), azure_deployment="dalle-deployment", image_size="512x512" + model="dall-e-3", + client=Mock(), + azure_deployment="dalle-deployment", + image_size="512x512", ) # pyright: ignore[reportCallIssues] def test_try_text_to_image(self, driver): @@ -40,5 +43,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-3" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-3" + assert image_artifact.meta["prompt"] == "test prompt" diff --git a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py index 48805cde63..ec70e2dd2f 100644 --- a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py @@ -76,5 +76,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "test_model_id" - assert image_artifact.prompt == "test_prompt" + assert image_artifact.meta["model"] == "test_model_id" + assert image_artifact.meta["prompt"] == "test_prompt" diff --git a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py index 16bcd28701..ff5528fb62 100644 --- a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py @@ -22,8 +22,8 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-2" + assert image_artifact.meta["prompt"] == "test prompt" def test_try_image_variation(self, driver): driver.client.images.create_variation.return_value = Mock(data=[Mock(b64_json=b"aW1hZ2UgZGF0YQ==")]) @@ -34,7 +34,7 @@ def test_try_image_variation(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" + assert image_artifact.meta["model"] == "dall-e-2" def test_try_image_variation_invalid_size(self, driver): driver.image_size = "1024x1792" @@ -59,8 +59,8 @@ def test_try_image_inpainting(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-2" + assert image_artifact.meta["prompt"] == "test prompt" def test_try_image_inpainting_invalid_size(self, driver): driver.image_size = "1024x1792" diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index 473fd0d9e2..b7ebdd9120 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -13,14 +13,13 @@ def loader(self): def create_source(self, bytes_from_resource_path): return bytes_from_resource_path - @pytest.mark.parametrize(("resource_path", "suffix", "mime_type"), [("sentences.wav", ".wav", "audio/wav")]) - def test_load(self, resource_path, suffix, mime_type, loader, create_source): + @pytest.mark.parametrize(("resource_path", "mime_type"), [("sentences.wav", "audio/wav")]) + def test_load(self, resource_path, mime_type, loader, create_source): source = create_source(resource_path) artifact = loader.load(source) assert isinstance(artifact, AudioArtifact) - assert artifact.name.endswith(suffix) assert artifact.mime_type == mime_type assert len(artifact.value) > 0 @@ -35,6 +34,5 @@ def test_load_collection(self, create_source, loader): for key in collection: artifact = collection[key] assert isinstance(artifact, AudioArtifact) - assert artifact.name.endswith(".wav") assert artifact.mime_type == "audio/wav" assert len(artifact.value) > 0 diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index a747afff71..a63322290f 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -11,11 +11,11 @@ def loader(self, request): if encoding is None: return CsvLoader(embedding_driver=MockEmbeddingDriver()) else: - return CsvLoader(embedding_driver=MockEmbeddingDriver(), encoding=encoding) + return CsvLoader(encoding=encoding, embedding_driver=MockEmbeddingDriver()) @pytest.fixture() def loader_with_pipe_delimiter(self): - return CsvLoader(embedding_driver=MockEmbeddingDriver(), delimiter="|") + return CsvLoader(delimiter="|", embedding_driver=MockEmbeddingDriver()) @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) def create_source(self, request): @@ -24,24 +24,24 @@ def create_source(self, request): def test_load(self, loader, create_source): source = create_source("test-1.csv") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 10 - first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "foo1" - assert first_artifact.value["Bar"] == "bar1" - assert first_artifact.embedding == [0, 1] + assert len(artifact) == 10 + first_artifact = artifact.value[0] + assert first_artifact["Foo"] == "foo1" + assert first_artifact["Bar"] == "bar1" + assert artifact.embedding == [0, 1] def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): source = create_source("test-pipe.csv") - artifacts = loader_with_pipe_delimiter.load(source) + artifact = loader_with_pipe_delimiter.load(source) - assert len(artifacts) == 10 - first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "bar1" - assert first_artifact.value["Bar"] == "foo1" - assert first_artifact.embedding == [0, 1] + assert len(artifact) == 10 + first_artifact = artifact.value[0] + assert first_artifact["Foo"] == "bar1" + assert first_artifact["Bar"] == "foo1" + assert artifact.embedding == [0, 1] def test_load_collection(self, loader, create_source): resource_paths = ["test-1.csv", "test-2.csv"] @@ -53,9 +53,19 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys for key in keys: - artifacts = collection[key] - assert len(artifacts) == 10 - first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "foo1" - assert first_artifact.value["Bar"] == "bar1" - assert first_artifact.embedding == [0, 1] + artifact = collection[key] + assert len(artifact) == 10 + first_artifact = artifact.value[0] + assert first_artifact["Foo"] == "foo1" + assert first_artifact["Bar"] == "bar1" + assert artifact.embedding == [0, 1] + + def test_to_text(self, loader, create_source): + source = create_source("test-1.csv") + + text = loader.load(source).to_text() + + assert ( + text + == "Foo,Bar\nfoo1,bar1\nfoo2,bar2\nfoo3,bar3\nfoo4,bar4\nfoo5,bar5\nfoo6,bar6\nfoo7,bar7\nfoo8,bar8\nfoo9,bar9\nfoo10,bar10" + ) diff --git a/tests/unit/loaders/test_dataframe_loader.py b/tests/unit/loaders/test_dataframe_loader.py index 5c2a57ed6a..fa9a540844 100644 --- a/tests/unit/loaders/test_dataframe_loader.py +++ b/tests/unit/loaders/test_dataframe_loader.py @@ -16,14 +16,14 @@ def test_load_with_path(self, loader): # test loading a file delimited by comma path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-1.csv") - artifacts = loader.load(pd.read_csv(path)) + artifact = loader.load(pd.read_csv(path)) - assert len(artifacts) == 10 - first_artifact = artifacts[0].value + assert len(artifact) == 10 + first_artifact = artifact.value[0] assert first_artifact["Foo"] == "foo1" assert first_artifact["Bar"] == "bar1" - assert artifacts[0].embedding == [0, 1] + assert artifact.embedding == [0, 1] def test_load_collection_with_path(self, loader): path1 = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-1.csv") @@ -37,16 +37,16 @@ def test_load_collection_with_path(self, loader): assert list(collection.keys()) == [key1, key2] - artifacts = collection[key1] - assert len(artifacts) == 10 - first_artifact = artifacts[0].value + artifact = collection[key1] + assert len(artifact) == 10 + first_artifact = artifact.value[0] assert first_artifact["Foo"] == "foo1" assert first_artifact["Bar"] == "bar1" - artifacts = collection[key2] - assert len(artifacts) == 10 - first_artifact = artifacts[0].value + artifact = collection[key2] + assert len(artifact) == 10 + first_artifact = artifact.value[0] assert first_artifact["Bar"] == "bar1" assert first_artifact["Foo"] == "foo1" - assert artifacts[0].embedding == [0, 1] + assert artifact.embedding == [0, 1] diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index eca4cbccc5..7093894b00 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -18,23 +18,22 @@ def create_source(self, bytes_from_resource_path): return bytes_from_resource_path @pytest.mark.parametrize( - ("resource_path", "suffix", "mime_type"), + ("resource_path", "mime_type"), [ - ("small.png", ".png", "image/png"), - ("small.jpg", ".jpeg", "image/jpeg"), - ("small.webp", ".webp", "image/webp"), - ("small.bmp", ".bmp", "image/bmp"), - ("small.gif", ".gif", "image/gif"), - ("small.tiff", ".tiff", "image/tiff"), + ("small.png", "image/png"), + ("small.jpg", "image/jpeg"), + ("small.webp", "image/webp"), + ("small.bmp", "image/bmp"), + ("small.gif", "image/gif"), + ("small.tiff", "image/tiff"), ], ) - def test_load(self, resource_path, suffix, mime_type, loader, create_source): + def test_load(self, resource_path, mime_type, loader, create_source): source = create_source(resource_path) artifact = loader.load(source) assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(suffix) assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == mime_type @@ -49,7 +48,6 @@ def test_load_normalize(self, resource_path, png_loader, create_source): artifact = png_loader.load(source) assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(".png") assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == "image/png" @@ -68,7 +66,6 @@ def test_load_collection(self, create_source, png_loader): for key in keys: artifact = collection[key] assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(".png") assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == "image/png" diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index fbfa6d4fa9..e977d3c5f5 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -35,14 +35,14 @@ def loader(self): return sql_loader def test_load(self, loader): - artifacts = loader.load("SELECT * FROM test_table;") + artifact = loader.load("SELECT * FROM test_table;") - assert len(artifacts) == 3 - assert artifacts[0].value == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} - assert artifacts[1].value == {"id": 2, "name": "Bob", "age": 30, "city": "Los Angeles"} - assert artifacts[2].value == {"id": 3, "name": "Charlie", "age": 22, "city": "Chicago"} + assert len(artifact) == 3 + assert artifact.value[0] == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} + assert artifact.value[1] == {"id": 2, "name": "Bob", "age": 30, "city": "Los Angeles"} + assert artifact.value[2] == {"id": 3, "name": "Charlie", "age": 22, "city": "Chicago"} - assert artifacts[0].embedding == [0, 1] + assert artifact.embedding == [0, 1] def test_load_collection(self, loader): artifacts = loader.load_collection(["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"]) @@ -52,10 +52,10 @@ def test_load_collection(self, loader): loader.to_key("SELECT * FROM test_table LIMIT 2;"), ] - assert [a.value for artifact_list in artifacts.values() for a in artifact_list] == [ + assert [a for artifact_table in artifacts.values() for a in artifact_table.value] == [ {"age": 25, "city": "New York", "id": 1, "name": "Alice"}, {"age": 25, "city": "New York", "id": 1, "name": "Alice"}, {"age": 30, "city": "Los Angeles", "id": 2, "name": "Bob"}, ] - assert list(artifacts.values())[0][0].embedding == [0, 1] + assert list(artifacts.values())[0].embedding == [0, 1] diff --git a/tests/unit/tools/test_inpainting_image_generation_tool.py b/tests/unit/tools/test_inpainting_image_generation_tool.py index 45afcbc63a..a558921a94 100644 --- a/tests/unit/tools/test_inpainting_image_generation_tool.py +++ b/tests/unit/tools/test_inpainting_image_generation_tool.py @@ -59,8 +59,8 @@ def test_image_inpainting_with_outfile( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_inpainting_from_file( @@ -83,8 +83,8 @@ def test_image_inpainting_from_memory(self, image_generation_engine, image_artif memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_inpainting_from_memory( diff --git a/tests/unit/tools/test_outpainting_image_variation_tool.py b/tests/unit/tools/test_outpainting_image_variation_tool.py index 4fbcbe8d49..e3f0de847b 100644 --- a/tests/unit/tools/test_outpainting_image_variation_tool.py +++ b/tests/unit/tools/test_outpainting_image_variation_tool.py @@ -34,8 +34,8 @@ def test_validate_output_configs(self, image_generation_engine) -> None: OutpaintingImageGenerationTool(engine=image_generation_engine, output_dir="test", output_file="test") def test_image_outpainting(self, image_generator, path_from_resource_path) -> None: - image_generator.engine.run.return_value = Mock( - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_outpainting_from_file( @@ -59,8 +59,8 @@ def test_image_outpainting_with_outfile( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_outpainting_from_file( diff --git a/tests/unit/tools/test_prompt_image_generation_tool.py b/tests/unit/tools/test_prompt_image_generation_tool.py index a0c5c7037e..4252d887e9 100644 --- a/tests/unit/tools/test_prompt_image_generation_tool.py +++ b/tests/unit/tools/test_prompt_image_generation_tool.py @@ -5,6 +5,7 @@ import pytest +from griptape.artifacts.image_artifact import ImageArtifact from griptape.tools import PromptImageGenerationTool @@ -36,8 +37,8 @@ def test_generate_image_with_outfile(self, image_generation_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" image_generator = PromptImageGenerationTool(engine=image_generation_engine, output_file=outfile) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.generate_image( diff --git a/tests/unit/tools/test_sql_tool.py b/tests/unit/tools/test_sql_tool.py index 2ef50ff549..e800d6fe46 100644 --- a/tests/unit/tools/test_sql_tool.py +++ b/tests/unit/tools/test_sql_tool.py @@ -26,7 +26,7 @@ def test_execute_query(self, driver): result = client.execute_query({"values": {"sql_query": "SELECT * from test_table;"}}) assert len(result.value) == 1 - assert result.value[0].value == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} + assert result.value[0] == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} def test_execute_query_description(self, driver): client = SqlTool( diff --git a/tests/unit/tools/test_text_to_speech_tool.py b/tests/unit/tools/test_text_to_speech_tool.py index 8821d48fc3..6f2c43bd39 100644 --- a/tests/unit/tools/test_text_to_speech_tool.py +++ b/tests/unit/tools/test_text_to_speech_tool.py @@ -5,6 +5,7 @@ import pytest +from griptape.artifacts.audio_artifact import AudioArtifact from griptape.tools.text_to_speech.tool import TextToSpeechTool @@ -32,7 +33,7 @@ def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" text_to_speech_client = TextToSpeechTool(engine=text_to_speech_engine, output_file=outfile) - text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] + text_to_speech_client.engine.run.return_value = AudioArtifact(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) diff --git a/tests/unit/tools/test_variation_image_generation_tool.py b/tests/unit/tools/test_variation_image_generation_tool.py index c4528a044e..5fd3513c1c 100644 --- a/tests/unit/tools/test_variation_image_generation_tool.py +++ b/tests/unit/tools/test_variation_image_generation_tool.py @@ -58,8 +58,8 @@ def test_image_variation_with_outfile(self, image_generation_engine, image_loade engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_variation_from_file(