From 86100db3c300df865042681328753d5894fb366e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 13 Sep 2024 11:35:13 -0700 Subject: [PATCH 1/2] Show mixing and matching Drivers in custom example (#1168) --- docs/griptape-framework/structures/src/drivers_config_7.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/griptape-framework/structures/src/drivers_config_7.py b/docs/griptape-framework/structures/src/drivers_config_7.py index 3b1d396ce2..efec4f3cff 100644 --- a/docs/griptape-framework/structures/src/drivers_config_7.py +++ b/docs/griptape-framework/structures/src/drivers_config_7.py @@ -2,14 +2,15 @@ from griptape.configs import Defaults from griptape.configs.drivers import DriversConfig -from griptape.drivers import AnthropicPromptDriver +from griptape.drivers import AnthropicPromptDriver, OpenAiEmbeddingDriver from griptape.structures import Agent Defaults.drivers_config = DriversConfig( prompt_driver=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], - ) + ), + embedding_driver=OpenAiEmbeddingDriver(), ) From 37d558240c0774862974a8acb5bc2b1af1579c08 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 13 Sep 2024 15:31:59 -0700 Subject: [PATCH 2/2] Refactor Artifacts (#1114) --- CHANGELOG.md | 21 +++ MIGRATION.md | 136 ++++++++++++++++++ docs/griptape-framework/data/artifacts.md | 56 +++----- docs/griptape-framework/data/loaders.md | 6 +- griptape/artifacts/__init__.py | 4 - griptape/artifacts/action_artifact.py | 13 +- griptape/artifacts/audio_artifact.py | 24 +++- griptape/artifacts/base_artifact.py | 38 ++--- griptape/artifacts/blob_artifact.py | 29 ++-- griptape/artifacts/boolean_artifact.py | 19 ++- griptape/artifacts/csv_row_artifact.py | 34 ----- griptape/artifacts/error_artifact.py | 11 +- griptape/artifacts/generic_artifact.py | 10 +- griptape/artifacts/image_artifact.py | 27 ++-- griptape/artifacts/info_artifact.py | 12 +- griptape/artifacts/json_artifact.py | 20 ++- griptape/artifacts/list_artifact.py | 33 +++-- griptape/artifacts/media_artifact.py | 53 ------- griptape/artifacts/text_artifact.py | 24 ++-- griptape/common/prompt_stack/prompt_stack.py | 6 +- .../amazon_bedrock_image_generation_driver.py | 12 +- ...ngface_pipeline_image_generation_driver.py | 4 +- .../leonardo_image_generation_driver.py | 12 +- .../openai_image_generation_driver.py | 3 +- .../extraction/csv_extraction_engine.py | 16 ++- griptape/loaders/csv_loader.py | 15 +- griptape/loaders/dataframe_loader.py | 15 +- griptape/loaders/sql_loader.py | 15 +- ...mixin.py => artifact_file_output_mixin.py} | 8 +- griptape/schemas/base_schema.py | 6 +- griptape/tasks/base_audio_generation_task.py | 4 +- griptape/tasks/base_image_generation_task.py | 9 +- griptape/tasks/tool_task.py | 6 +- griptape/tools/base_image_generation_tool.py | 4 +- griptape/tools/query/tool.py | 4 +- griptape/tools/text_to_speech/tool.py | 4 +- tests/unit/artifacts/test_action_artifact.py | 4 - tests/unit/artifacts/test_audio_artifact.py | 14 +- tests/unit/artifacts/test_base_artifact.py | 8 +- .../artifacts/test_base_media_artifact.py | 30 ---- tests/unit/artifacts/test_blob_artifact.py | 27 ++-- tests/unit/artifacts/test_boolean_artifact.py | 11 ++ tests/unit/artifacts/test_csv_row_artifact.py | 30 ---- tests/unit/artifacts/test_image_artifact.py | 13 +- tests/unit/artifacts/test_json_artifact.py | 14 +- tests/unit/artifacts/test_list_artifact.py | 12 +- tests/unit/artifacts/test_text_artifact.py | 5 + ...table_diffusion_image_generation_driver.py | 4 +- ...st_azure_openai_image_generation_driver.py | 9 +- .../test_leonardo_image_generation_driver.py | 4 +- .../test_openai_image_generation_driver.py | 10 +- .../test_base_local_vector_store_driver.py | 15 -- .../extraction/test_csv_extraction_engine.py | 6 +- tests/unit/loaders/test_audio_loader.py | 6 +- tests/unit/loaders/test_csv_loader.py | 29 ++-- tests/unit/loaders/test_dataframe_loader.py | 52 ------- tests/unit/loaders/test_image_loader.py | 19 ++- tests/unit/loaders/test_sql_loader.py | 17 +-- tests/unit/memory/tool/test_task_memory.py | 6 +- .../test_image_artifact_file_output_mixin.py | 12 +- tests/unit/tasks/test_extraction_task.py | 2 +- tests/unit/tools/test_extraction_tool.py | 4 +- tests/unit/tools/test_file_manager.py | 25 +--- .../test_inpainting_image_generation_tool.py | 8 +- .../test_outpainting_image_variation_tool.py | 8 +- .../test_prompt_image_generation_tool.py | 5 +- tests/unit/tools/test_sql_tool.py | 2 +- tests/unit/tools/test_text_to_speech_tool.py | 3 +- .../test_variation_image_generation_tool.py | 4 +- 69 files changed, 574 insertions(+), 557 deletions(-) delete mode 100644 griptape/artifacts/csv_row_artifact.py delete mode 100644 griptape/artifacts/media_artifact.py rename griptape/mixins/{media_artifact_file_output_mixin.py => artifact_file_output_mixin.py} (87%) delete mode 100644 tests/unit/artifacts/test_base_media_artifact.py delete mode 100644 tests/unit/artifacts/test_csv_row_artifact.py delete mode 100644 tests/unit/loaders/test_dataframe_loader.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fed1b2009..4516c59b62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added +- `BaseArtifact.to_bytes()` method to convert an Artifact's value to bytes. +- `BlobArtifact.base64` property for converting a `BlobArtifact`'s value to a base64 string. +- `CsvLoader`/`SqlLoader`/`DataframeLoader` `formatter_fn` field for customizing how SQL results are formatted into `TextArtifact`s. + +### Changed +- **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. +- **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. +- **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. +- **BREAKING**: Removed `ImageArtifact.media_type`. +- **BREAKING**: Removed `AudioArtifact.media_type`. +- **BREAKING**: Removed `BlobArtifact.dir_name`. +- **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. +- **BREAKING**: `ImageArtifact.format` is now required. +- Updated `JsonArtifact` value converter to properly handle more types. +- `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. +- `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. +- Removed `__add__` method from `BaseArtifact`, implemented it where necessary. +- Generic type support to `ListArtifact`. +- Iteration support to `ListArtifact`. + ## [0.31.0] - 2024-09-03 **Note**: This release includes breaking changes. Please refer to the [Migration Guide](./MIGRATION.md#030x-to-031x) for details. diff --git a/MIGRATION.md b/MIGRATION.md index af8835e5b6..016a93f03d 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,6 +1,142 @@ # Migration Guide This document provides instructions for migrating your codebase to accommodate breaking changes introduced in new versions of Griptape. +## 0.31.X to 0.32.X + +### Removed `MediaArtifact` + +`MediaArtifact` has been removed. Use `ImageArtifact` or `AudioArtifact` instead. + +#### Before + +```python +image_media = MediaArtifact( + b"image_data", + media_type="image", + format="jpeg" +) + +audio_media = MediaArtifact( + b"audio_data", + media_type="audio", + format="wav" +) +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg" +) + +audio_artifact = AudioArtifact( + b"audio_data", + format="wav" +) +``` + +### `ImageArtifact.format` is now required + +`ImageArtifact.format` is now a required parameter. Update any code that does not provide a `format` parameter. + +#### Before + +```python +image_artifact = ImageArtifact( + b"image_data" +) +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg" +) +``` + +### Removed `CsvRowArtifact` + +`CsvRowArtifact` has been removed. Use `TextArtifact` instead. + +#### Before + +```python +artifact = CsvRowArtifact({"name": "John", "age": 30}) +print(artifact.value) # {"name": "John", "age": 30} +print(type(artifact.value)) # +``` + +#### After +```python +artifact = TextArtifact("name: John\nage: 30") +print(artifact.value) # name: John\nage: 30 +print(type(artifact.value)) # +``` + +If you require storing a dictionary as an Artifact, you can use `GenericArtifact` instead. + +### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types + +`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a `list[TextArtifact]` instead of `list[CsvRowArtifact]`. + +If you require a dictionary, set a custom `formatter_fn` and then parse the text to a dictionary. + +#### Before + +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # {"name": "John", "age": 30} +print(type(results[0].value)) # +``` + +#### After +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # name: John\nAge: 30 +print(type(results[0].value)) # + +# Customize formatter_fn +results = CsvLoader(formatter_fn=lambda x: json.dumps(x)).load(Path("people.csv").read_text()) +print(results[0].value) # {"name": "John", "age": 30} +print(type(results[0].value)) # + +dict_results = [json.loads(result.value) for result in results] +print(dict_results[0]) # {"name": "John", "age": 30} +print(type(dict_results[0])) # +``` + +### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta` + +`ImageArtifact.prompt` and `ImageArtifact.model` have been moved to `ImageArtifact.meta`. + +#### Before + +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg", + prompt="Generate an image of a cat", + model="DALL-E" +) + +print(image_artifact.prompt, image_artifact.model) # Generate an image of a cat, DALL-E +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg", + meta={"prompt": "Generate an image of a cat", "model": "DALL-E"} +) + +print(image_artifact.meta["prompt"], image_artifact.meta["model"]) # Generate an image of a cat, DALL-E +``` + ## 0.30.X to 0.31.X diff --git a/docs/griptape-framework/data/artifacts.md b/docs/griptape-framework/data/artifacts.md index 8c4da02b32..2edd1ebec2 100644 --- a/docs/griptape-framework/data/artifacts.md +++ b/docs/griptape-framework/data/artifacts.md @@ -5,60 +5,50 @@ search: ## Overview -**[Artifacts](../../reference/griptape/artifacts/base_artifact.md)** are used for passing different types of data between Griptape components. All tools return artifacts that are later consumed by tasks and task memory. -Artifacts make sure framework components enforce contracts when passing and consuming data. +**[Artifacts](../../reference/griptape/artifacts/base_artifact.md)** are the core data structure in Griptape. They are used to encapsulate data and enhance it with metadata. ## Text -A [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) for passing text data of arbitrary size around the framework. It can be used to count tokens with [token_count()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.token_count) with a tokenizer. -It can also be used to generate a text embedding with [generate_embedding()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.generate_embedding) -and access it with [embedding](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.embedding). +[TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s store textual data. They offer methods such as [token_count()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.token_count) for counting tokens with a tokenizer, and [generate_embedding()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.generate_embedding) for creating text embeddings. You can also access the embedding via the [embedding](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.embedding) property. -[TaskMemory](../../reference/griptape/memory/task/task_memory.md) automatically stores [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s returned by tool activities and returns artifact IDs back to the LLM. +When `TextArtifact`s are returned from Tools, they will be stored in [Task Memory](../../griptape-framework/structures/task-memory.md) if the Tool has set `off_prompt=True`. -## Csv Row +## Blob -A [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md) for passing structured row data around the framework. It inherits from [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) and overrides the -[to_text()](../../reference/griptape/artifacts/csv_row_artifact.md#griptape.artifacts.csv_row_artifact.CsvRowArtifact.to_text) method, which always returns a valid CSV row. +[BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md)s store binary large objects (blobs). -## Info +When `BlobArtifact`s are returned from Tools, they will be stored in [Task Memory](../../griptape-framework/structures/task-memory.md) if the Tool has set `off_prompt=True`. -An [InfoArtifact](../../reference/griptape/artifacts/info_artifact.md) for passing short notifications back to the LLM without task memory storing them. +### Image -## Error +[ImageArtifact](../../reference/griptape/artifacts/image_artifact.md)s store image data. This includes binary image data along with metadata such as MIME type and dimensions. They are a subclass of [BlobArtifacts](#blob). -An [ErrorArtifact](../../reference/griptape/artifacts/error_artifact.md) is used for passing errors back to the LLM without task memory storing them. +### Audio -## Blob +[AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md)s store audio content. This includes binary audio data and metadata such as format, and duration. They are a subclass of [BlobArtifacts](#blob). -A [BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md) for passing binary large objects (blobs) back to the LLM. -Treat it as a way to return unstructured data, such as images, videos, audio, and other files back from tools. -Each blob has a [name](../../reference/griptape/artifacts/base_artifact.md#griptape.artifacts.base_artifact.BaseArtifact.name) and -[dir](../../reference/griptape/artifacts/blob_artifact.md#griptape.artifacts.blob_artifact.BlobArtifact.dir_name) to uniquely identify stored objects. +## List -[TaskMemory](../../reference/griptape/memory/task/task_memory.md) automatically stores [BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md)s returned by tool activities that can be reused by other tools. +[ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s store lists of Artifacts. -## Image +When `ListArtifact`s are returned from Tools, their elements will be stored in [Task Memory](../../griptape-framework/structures/task-memory.md) if the element is either a `TextArtifact` or a `BlobArtifact` and the Tool has set `off_prompt=True`. -An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used for passing images back to the LLM. In addition to binary image data, an Image Artifact includes image metadata like MIME type, dimensions, and prompt and model information for images returned by [image generation Drivers](../drivers/image-generation-drivers.md). It inherits from [BlobArtifact](#blob). +## Info -## Audio +[InfoArtifact](../../reference/griptape/artifacts/info_artifact.md)s store small pieces of textual information. These are useful for conveying messages about the execution or results of an operation, such as "No results found" or "Operation completed successfully." -An [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md) allows the Framework to interact with audio content. An Audio Artifact includes binary audio content as well as metadata like format, duration, and prompt and model information for audio returned generative models. It inherits from [BlobArtifact](#blob). +## JSON -## Boolean +[JsonArtifact](../../reference/griptape/artifacts/json_artifact.md)s store JSON-serializable data. Any data assigned to the `value` property is processed using `json.dumps(json.loads(value))`. -A [BooleanArtifact](../../reference/griptape/artifacts/boolean_artifact.md) is used for passing boolean values around the framework. +## Error -!!! info - Any object passed on init to `BooleanArtifact` will be coerced into a `bool` type. This might lead to unintended behavior: `BooleanArtifact("False").value is True`. Use [BooleanArtifact.parse_bool](../../reference/griptape/artifacts/boolean_artifact.md#griptape.artifacts.boolean_artifact.BooleanArtifact.parse_bool) to convert case-insensitive string literal values `"True"` and `"False"` into a `BooleanArtifact`: `BooleanArtifact.parse_bool("False").value is False`. +[ErrorArtifact](../../reference/griptape/artifacts/error_artifact.md)s store exception information, providing a structured way to convey errors. -## Generic +## Action -A [GenericArtifact](../../reference/griptape/artifacts/generic_artifact.md) can be used as an escape hatch for passing any type of data around the framework. -It is generally not recommended to use this Artifact type, but it can be used in a handful of situations where no other Artifact type fits the data being passed. -See [talking to a video](../../examples/talk-to-a-video.md) for an example of using a `GenericArtifact` to pass a Gemini-specific video file. +[ActionArtifact](../../reference/griptape/artifacts/action_artifact.md)s represent actions taken by an LLM. Currently, the only supported action type is [ToolAction](../../reference/griptape/common/actions/tool_action.md), which is used to execute a [Tool](../../griptape-framework/tools/index.md). -## Json +## Generic -A [JsonArtifact](../../reference/griptape/artifacts/json_artifact.md) is used for passing JSON-serliazable data around the framework. Anything passed to `value` will be converted using `json.dumps(json.loads(value))`. +[GenericArtifact](../../reference/griptape/artifacts/generic_artifact.md)s provide a flexible way to pass data that does not fit into any other artifact category. While not generally recommended, they can be useful for specific use cases. For instance, see [talking to a video](../../examples/talk-to-a-video.md), which demonstrates using a `GenericArtifact` to pass a Gemini-specific video file. diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 914fdee2a9..0c0fc3eadc 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -22,7 +22,7 @@ Inherits from the [TextLoader](../../reference/griptape/loaders/text_loader.md) ## SQL -Can be used to load data from a SQL database into [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md)s: +Can be used to load data from a SQL database into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: ```python --8<-- "docs/griptape-framework/data/src/loaders_2.py" @@ -30,7 +30,7 @@ Can be used to load data from a SQL database into [CsvRowArtifact](../../referen ## CSV -Can be used to load CSV files into [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md)s: +Can be used to load CSV files into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: ```python --8<-- "docs/griptape-framework/data/src/loaders_3.py" @@ -42,7 +42,7 @@ Can be used to load CSV files into [CsvRowArtifact](../../reference/griptape/art !!! info This driver requires the `loaders-dataframe` [extra](../index.md#extras). -Can be used to load [pandas](https://pandas.pydata.org/) [DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)s into [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md)s: +Can be used to load [pandas](https://pandas.pydata.org/) [DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)s into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: ```python --8<-- "docs/griptape-framework/data/src/loaders_4.py" diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index f39bfea8d0..0e58a8a764 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -5,9 +5,7 @@ from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact -from .csv_row_artifact import CsvRowArtifact from .list_artifact import ListArtifact -from .media_artifact import MediaArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact from .action_artifact import ActionArtifact @@ -22,9 +20,7 @@ "JsonArtifact", "BlobArtifact", "BooleanArtifact", - "CsvRowArtifact", "ListArtifact", - "MediaArtifact", "ImageArtifact", "AudioArtifact", "ActionArtifact", diff --git a/griptape/artifacts/action_artifact.py b/griptape/artifacts/action_artifact.py index 9772bbbabb..d882d06384 100644 --- a/griptape/artifacts/action_artifact.py +++ b/griptape/artifacts/action_artifact.py @@ -5,15 +5,20 @@ from attrs import define, field from griptape.artifacts import BaseArtifact -from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: from griptape.common import ToolAction @define() -class ActionArtifact(BaseArtifact, SerializableMixin): +class ActionArtifact(BaseArtifact): + """Represents the LLM taking an action to use a Tool. + + Attributes: + value: The Action to take. Currently only supports ToolAction. + """ + value: ToolAction = field(metadata={"serializable": True}) - def __add__(self, other: BaseArtifact) -> ActionArtifact: - raise NotImplementedError + def to_text(self) -> str: + return str(self.value) diff --git a/griptape/artifacts/audio_artifact.py b/griptape/artifacts/audio_artifact.py index 3dc67fa366..e9e38858a4 100644 --- a/griptape/artifacts/audio_artifact.py +++ b/griptape/artifacts/audio_artifact.py @@ -1,12 +1,26 @@ from __future__ import annotations -from attrs import define +from attrs import define, field -from griptape.artifacts import MediaArtifact +from griptape.artifacts import BlobArtifact @define -class AudioArtifact(MediaArtifact): - """AudioArtifact is a type of MediaArtifact representing audio.""" +class AudioArtifact(BlobArtifact): + """Stores audio data. - media_type: str = "audio" + Attributes: + format: The audio format, e.g. "wav" or "mp3". + """ + + format: str = field(kw_only=True, metadata={"serializable": True}) + + @property + def mime_type(self) -> str: + return f"audio/{self.format}" + + def to_bytes(self) -> bytes: + return self.value + + def to_text(self) -> str: + return f"Audio, format: {self.format}, size: {len(self.value)} bytes" diff --git a/griptape/artifacts/base_artifact.py b/griptape/artifacts/base_artifact.py index 82a0bbd23b..61989ab54a 100644 --- a/griptape/artifacts/base_artifact.py +++ b/griptape/artifacts/base_artifact.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import uuid from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional @@ -15,6 +14,20 @@ @define class BaseArtifact(SerializableMixin, ABC): + """Serves as the base class for all Artifacts. + + Artifacts are used to encapsulate data and enhance it with metadata. + + Attributes: + id: The unique identifier of the Artifact. Defaults to a random UUID. + reference: The optional Reference to the Artifact. + meta: The metadata associated with the Artifact. Defaults to an empty dictionary. + name: The name of the Artifact. Defaults to the id. + value: The value of the Artifact. + encoding: The encoding to use when encoding/decoding the value. + encoding_error_handler: The error handler to use when encoding/decoding the value. + """ + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) reference: Optional[Reference] = field(default=None, kw_only=True, metadata={"serializable": True}) meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) @@ -24,22 +37,8 @@ class BaseArtifact(SerializableMixin, ABC): metadata={"serializable": True}, ) value: Any = field() - - @classmethod - def value_to_bytes(cls, value: Any) -> bytes: - if isinstance(value, bytes): - return value - else: - return str(value).encode() - - @classmethod - def value_to_dict(cls, value: Any) -> dict: - dict_value = value if isinstance(value, dict) else json.loads(value) - - return dict(dict_value.items()) - - def to_text(self) -> str: - return str(self.value) + encoding_error_handler: str = field(default="strict", kw_only=True) + encoding: str = field(default="utf-8", kw_only=True) def __str__(self) -> str: return self.to_text() @@ -50,5 +49,8 @@ def __bool__(self) -> bool: def __len__(self) -> int: return len(self.value) + def to_bytes(self) -> bytes: + return self.to_text().encode(encoding=self.encoding, errors=self.encoding_error_handler) + @abstractmethod - def __add__(self, other: BaseArtifact) -> BaseArtifact: ... + def to_text(self) -> str: ... diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 0c0dcc1223..7c814a0526 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -1,7 +1,6 @@ from __future__ import annotations -import os.path -from typing import Optional +import base64 from attrs import define, field @@ -10,17 +9,27 @@ @define class BlobArtifact(BaseArtifact): - value: bytes = field(converter=BaseArtifact.value_to_bytes, metadata={"serializable": True}) - dir_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - encoding: str = field(default="utf-8", kw_only=True) - encoding_error_handler: str = field(default="strict", kw_only=True) + """Stores arbitrary binary data. - def __add__(self, other: BaseArtifact) -> BlobArtifact: - return BlobArtifact(self.value + other.value, name=self.name) + Attributes: + value: The binary data. + """ + + value: bytes = field( + converter=lambda value: value if isinstance(value, bytes) else str(value).encode(), + metadata={"serializable": True}, + ) + + @property + def base64(self) -> str: + return base64.b64encode(self.value).decode(self.encoding) @property - def full_path(self) -> str: - return os.path.join(self.dir_name, self.name) if self.dir_name else self.name + def mime_type(self) -> str: + return "application/octet-stream" + + def to_bytes(self) -> bytes: + return self.value def to_text(self) -> str: return self.value.decode(encoding=self.encoding, errors=self.encoding_error_handler) diff --git a/griptape/artifacts/boolean_artifact.py b/griptape/artifacts/boolean_artifact.py index 5bcdfac9ba..eb135824d4 100644 --- a/griptape/artifacts/boolean_artifact.py +++ b/griptape/artifacts/boolean_artifact.py @@ -9,17 +9,23 @@ @define class BooleanArtifact(BaseArtifact): + """Stores a boolean value. + + Attributes: + value: The boolean value. + """ + value: bool = field(converter=bool, metadata={"serializable": True}) @classmethod - def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact: # noqa: FBT001 - """Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing.""" + def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact: + """Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false".""" if value is not None: if isinstance(value, str): if value.lower() == "true": - return BooleanArtifact(True) # noqa: FBT003 + return BooleanArtifact(value=True) elif value.lower() == "false": - return BooleanArtifact(False) # noqa: FBT003 + return BooleanArtifact(value=False) elif isinstance(value, bool): return BooleanArtifact(value) raise ValueError(f"Cannot convert '{value}' to BooleanArtifact") @@ -28,4 +34,7 @@ def __add__(self, other: BaseArtifact) -> BooleanArtifact: raise ValueError("Cannot add BooleanArtifact with other artifacts") def __eq__(self, value: object) -> bool: - return self.value is value + return self.value == value + + def to_text(self) -> str: + return str(self.value).lower() diff --git a/griptape/artifacts/csv_row_artifact.py b/griptape/artifacts/csv_row_artifact.py deleted file mode 100644 index 00f1047fcc..0000000000 --- a/griptape/artifacts/csv_row_artifact.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -import csv -import io - -from attrs import define, field - -from griptape.artifacts import BaseArtifact, TextArtifact - - -@define -class CsvRowArtifact(TextArtifact): - value: dict[str, str] = field(converter=BaseArtifact.value_to_dict, metadata={"serializable": True}) - delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True}) - - def __add__(self, other: BaseArtifact) -> CsvRowArtifact: - return CsvRowArtifact(self.value | other.value) - - def __bool__(self) -> bool: - return len(self) > 0 - - def to_text(self) -> str: - with io.StringIO() as csvfile: - writer = csv.DictWriter( - csvfile, - fieldnames=self.value.keys(), - quoting=csv.QUOTE_MINIMAL, - delimiter=self.delimiter, - ) - - writer.writeheader() - writer.writerow(self.value) - - return csvfile.getvalue().strip() diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index d065d754b2..27e6a37ab0 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -9,8 +9,15 @@ @define class ErrorArtifact(BaseArtifact): + """Represents an error that may want to be conveyed to the LLM. + + Attributes: + value: The error message. + exception: The exception that caused the error. Defaults to None. + """ + value: str = field(converter=str, metadata={"serializable": True}) exception: Optional[Exception] = field(default=None, kw_only=True, metadata={"serializable": False}) - def __add__(self, other: BaseArtifact) -> ErrorArtifact: - return ErrorArtifact(self.value + other.value) + def to_text(self) -> str: + return self.value diff --git a/griptape/artifacts/generic_artifact.py b/griptape/artifacts/generic_artifact.py index 8e0b7e38c2..e90f40ef0c 100644 --- a/griptape/artifacts/generic_artifact.py +++ b/griptape/artifacts/generic_artifact.py @@ -9,7 +9,13 @@ @define class GenericArtifact(BaseArtifact): + """Serves as an escape hatch for artifacts that don't fit into any other category. + + Attributes: + value: The value of the Artifact. + """ + value: Any = field(metadata={"serializable": True}) - def __add__(self, other: BaseArtifact) -> BaseArtifact: - raise NotImplementedError + def to_text(self) -> str: + return str(self.value) diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index e963b38818..36170ee0d8 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -2,22 +2,29 @@ from attrs import define, field -from griptape.artifacts import MediaArtifact +from griptape.artifacts import BlobArtifact @define -class ImageArtifact(MediaArtifact): - """ImageArtifact is a type of MediaArtifact representing an image. +class ImageArtifact(BlobArtifact): + """Stores image data. Attributes: - value: Raw bytes representing media data. - media_type: The type of media, defaults to "image". - format: The format of the media, like png, jpeg, or gif. - name: Artifact name, generated using creation time and a random string. - model: Optionally specify the model used to generate the media. - prompt: Optionally specify the prompt used to generate the media. + format: The format of the image data. Used when building the MIME type. + width: The width of the image. + height: The height of the image """ - media_type: str = "image" + format: str = field(kw_only=True, metadata={"serializable": True}) width: int = field(kw_only=True, metadata={"serializable": True}) height: int = field(kw_only=True, metadata={"serializable": True}) + + @property + def mime_type(self) -> str: + return f"image/{self.format}" + + def to_bytes(self) -> bytes: + return self.value + + def to_text(self) -> str: + return f"Image, format: {self.format}, size: {len(self.value)} bytes" diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 26fe6366bc..3391554e92 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -7,7 +7,15 @@ @define class InfoArtifact(BaseArtifact): + """Represents helpful info that can be conveyed to the LLM. + + For example, "No results found" or "Please try again.". + + Attributes: + value: The info to convey. + """ + value: str = field(converter=str, metadata={"serializable": True}) - def __add__(self, other: BaseArtifact) -> InfoArtifact: - return InfoArtifact(self.value + other.value) + def to_text(self) -> str: + return self.value diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index b292879a9c..57700afd5f 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Union +from typing import Any, Union from attrs import define, field @@ -12,10 +12,20 @@ @define class JsonArtifact(BaseArtifact): - value: Json = field(converter=lambda v: json.loads(json.dumps(v)), metadata={"serializable": True}) + """Stores JSON data. + + Attributes: + value: The JSON data. Values will automatically be converted to a JSON-compatible format. + """ + + value: Json = field(converter=lambda value: JsonArtifact.value_to_json(value), metadata={"serializable": True}) + + @classmethod + def value_to_json(cls, value: Any) -> Json: + if isinstance(value, str): + return json.loads(value) + else: + return json.loads(json.dumps(value)) def to_text(self) -> str: return json.dumps(self.value) - - def __add__(self, other: BaseArtifact) -> JsonArtifact: - raise NotImplementedError diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 298f29c6ad..0e6f81ca5d 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -1,23 +1,37 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Generic, Optional, TypeVar from attrs import Attribute, define, field from griptape.artifacts import BaseArtifact if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterator, Sequence + +T = TypeVar("T", bound=BaseArtifact, covariant=True) @define -class ListArtifact(BaseArtifact): - value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True}) +class ListArtifact(BaseArtifact, Generic[T]): + value: Sequence[T] = field(factory=list, metadata={"serializable": True}) item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + def __getitem__(self, key: int) -> T: + return self.value[key] + + def __bool__(self) -> bool: + return len(self) > 0 + + def __add__(self, other: BaseArtifact) -> ListArtifact[T]: + return ListArtifact(self.value + other.value) + + def __iter__(self) -> Iterator[T]: + return iter(self.value) + @value.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None: + def validate_value(self, _: Attribute, value: list[T]) -> None: if self.validate_uniform_types and len(value) > 0: first_type = type(value[0]) @@ -31,18 +45,9 @@ def child_type(self) -> Optional[type]: else: return None - def __getitem__(self, key: int) -> BaseArtifact: - return self.value[key] - - def __bool__(self) -> bool: - return len(self) > 0 - def to_text(self) -> str: return self.item_separator.join([v.to_text() for v in self.value]) - def __add__(self, other: BaseArtifact) -> BaseArtifact: - return ListArtifact(self.value + other.value) - def is_type(self, target_type: type) -> bool: if self.value: return isinstance(self.value[0], target_type) diff --git a/griptape/artifacts/media_artifact.py b/griptape/artifacts/media_artifact.py deleted file mode 100644 index a57217fc74..0000000000 --- a/griptape/artifacts/media_artifact.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import base64 -import random -import string -import time -from typing import Optional - -from attrs import define, field - -from griptape.artifacts import BlobArtifact - - -@define -class MediaArtifact(BlobArtifact): - """MediaArtifact is a type of BlobArtifact that represents media (image, audio, video, etc.) and can be extended to support a specific media type. - - Attributes: - value: Raw bytes representing media data. - media_type: The type of media, like image, audio, or video. - format: The format of the media, like png, wav, or mp4. - name: Artifact name, generated using creation time and a random string. - model: Optionally specify the model used to generate the media. - prompt: Optionally specify the prompt used to generate the media. - """ - - media_type: str = field(default="media", kw_only=True, metadata={"serializable": True}) - format: str = field(kw_only=True, metadata={"serializable": True}) - model: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - prompt: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - - def __attrs_post_init__(self) -> None: - # Generating the name string requires attributes set by child classes. - # This waits until all attributes are available before generating a name. - if self.name == self.id: - self.name = self.make_name() - - @property - def mime_type(self) -> str: - return f"{self.media_type}/{self.format}" - - @property - def base64(self) -> str: - return base64.b64encode(self.value).decode("utf-8") - - def to_text(self) -> str: - return f"Media, type: {self.mime_type}, size: {len(self.value)} bytes" - - def make_name(self) -> str: - entropy = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) - fmt_time = time.strftime("%y%m%d%H%M%S", time.localtime()) - - return f"{self.media_type}_artifact_{fmt_time}_{entropy}.{self.format}" diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 752f666155..9623c80969 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -14,13 +14,7 @@ @define class TextArtifact(BaseArtifact): value: str = field(converter=str, metadata={"serializable": True}) - encoding: str = field(default="utf-8", kw_only=True) - encoding_error_handler: str = field(default="strict", kw_only=True) - _embedding: list[float] = field(factory=list, kw_only=True) - - @property - def embedding(self) -> Optional[list[float]]: - return None if len(self._embedding) == 0 else self._embedding + embedding: Optional[list[float]] = field(default=None, kw_only=True) def __add__(self, other: BaseArtifact) -> TextArtifact: return TextArtifact(self.value + other.value) @@ -28,14 +22,18 @@ def __add__(self, other: BaseArtifact) -> TextArtifact: def __bool__(self) -> bool: return bool(self.value.strip()) - def generate_embedding(self, driver: BaseEmbeddingDriver) -> Optional[list[float]]: - self._embedding.clear() - self._embedding.extend(driver.embed_string(str(self.value))) + def to_text(self) -> str: + return self.value + + def generate_embedding(self, driver: BaseEmbeddingDriver) -> list[float]: + embedding = driver.embed_string(str(self.value)) + + if self.embedding is None: + self.embedding = [] + self.embedding.clear() + self.embedding.extend(embedding) return self.embedding def token_count(self, tokenizer: BaseTokenizer) -> int: return tokenizer.count_tokens(str(self.value)) - - def to_bytes(self) -> bytes: - return str(self.value).encode(encoding=self.encoding, errors=self.encoding_error_handler) diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 6d8dfde75f..77ce4ba9b4 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -7,7 +7,6 @@ from griptape.artifacts import ( ActionArtifact, BaseArtifact, - ErrorArtifact, GenericArtifact, ImageArtifact, ListArtifact, @@ -70,8 +69,6 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage return [ImageMessageContent(artifact)] elif isinstance(artifact, GenericArtifact): return [GenericMessageContent(artifact)] - elif isinstance(artifact, ErrorArtifact): - return [TextMessageContent(TextArtifact(artifact.to_text()))] elif isinstance(artifact, ActionArtifact): action = artifact.value output = action.output @@ -81,6 +78,7 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage return [ActionResultMessageContent(output, action=action)] elif isinstance(artifact, ListArtifact): processed_contents = [self.__to_message_content(artifact) for artifact in artifact.value] + return [sub_content for processed_content in processed_contents for sub_content in processed_content] else: - raise ValueError(f"Unsupported artifact type: {type(artifact)}") + return [TextMessageContent(TextArtifact(artifact.to_text()))] diff --git a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py index 7106c81926..4db302f6f6 100644 --- a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py +++ b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py @@ -46,12 +46,11 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=self.image_width, height=self.image_height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_variation( @@ -70,12 +69,11 @@ def try_image_variation( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_inpainting( @@ -96,12 +94,11 @@ def try_image_inpainting( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_outpainting( @@ -122,12 +119,11 @@ def try_image_outpainting( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def _make_request(self, request: dict) -> bytes: diff --git a/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py b/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py index 46dbcd331c..b89df1c4b5 100644 --- a/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py +++ b/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py @@ -44,7 +44,7 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ format=self.output_format.lower(), height=output_image.height, width=output_image.width, - prompt=prompt, + meta={"prompt": prompt}, ) def try_image_variation( @@ -76,7 +76,7 @@ def try_image_variation( format=self.output_format.lower(), height=output_image.height, width=output_image.width, - prompt=prompt, + meta={"prompt": prompt}, ) def try_image_inpainting( diff --git a/griptape/drivers/image_generation/leonardo_image_generation_driver.py b/griptape/drivers/image_generation/leonardo_image_generation_driver.py index e32dbb4c72..db89244bf5 100644 --- a/griptape/drivers/image_generation/leonardo_image_generation_driver.py +++ b/griptape/drivers/image_generation/leonardo_image_generation_driver.py @@ -60,8 +60,10 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ format="png", width=self.image_width, height=self.image_height, - model=self.model, - prompt=", ".join(prompts), + meta={ + "model": self.model, + "prompt": ", ".join(prompts), + }, ) def try_image_variation( @@ -87,8 +89,10 @@ def try_image_variation( format="png", width=self.image_width, height=self.image_height, - model=self.model, - prompt=", ".join(prompts), + meta={ + "model": self.model, + "prompt": ", ".join(prompts), + }, ) def try_image_outpainting( diff --git a/griptape/drivers/image_generation/openai_image_generation_driver.py b/griptape/drivers/image_generation/openai_image_generation_driver.py index 0ee50a1e2c..bf77ac300b 100644 --- a/griptape/drivers/image_generation/openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/openai_image_generation_driver.py @@ -151,6 +151,5 @@ def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageA format="png", width=image_dimensions[0], height=image_dimensions[1], - model=self.model, - prompt=prompt, + meta={"model": self.model, "prompt": prompt}, ) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index b45bdf7f5e..e7be73c73b 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -2,11 +2,11 @@ import csv import io -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import Factory, define, field -from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 @@ -20,6 +20,9 @@ class CsvExtractionEngine(BaseExtractionEngine): column_names: list[str] = field(default=Factory(list), kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) def extract( self, @@ -32,26 +35,27 @@ def extract( self._extract_rec( cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], [], + rulesets=rulesets, ), item_separator="\n", ) - def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]: + def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]: rows = [] with io.StringIO(text) as f: for row in csv.reader(f): - rows.append(CsvRowArtifact(dict(zip(column_names, [x.strip() for x in row])))) + rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row]))))) return rows def _extract_rec( self, artifacts: list[TextArtifact], - rows: list[CsvRowArtifact], + rows: list[TextArtifact], *, rulesets: Optional[list[Ruleset]] = None, - ) -> list[CsvRowArtifact]: + ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) system_prompt = self.system_template_generator.render( column_names=self.column_names, diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 14dfe3e4a6..bcf7029d4a 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -2,11 +2,11 @@ import csv from io import StringIO -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -18,8 +18,11 @@ class CsvLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: artifacts = [] if isinstance(source, bytes): @@ -28,7 +31,7 @@ def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [CsvRowArtifact(row) for row in reader] + chunks = [TextArtifact(self.formatter_fn(row)) for row in reader] if self.embedding_driver: for chunk in chunks: @@ -44,8 +47,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[CsvRowArtifact]]: + ) -> dict[str, list[TextArtifact]]: return cast( - dict[str, list[CsvRowArtifact]], + dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 0b1ae14484..30d705676f 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -18,11 +18,14 @@ @define class DataFrameLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: artifacts = [] - chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")] + chunks = [TextArtifact(self.formatter_fn(row)) for row in source.to_dict(orient="records")] if self.embedding_driver: for chunk in chunks: @@ -33,8 +36,8 @@ def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: return artifacts - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: + return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index e4522796ff..105f585cb3 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -15,12 +15,15 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: rows = self.sql_driver.execute_query(source) artifacts = [] - chunks = [CsvRowArtifact(row.cells) for row in rows] if rows else [] + chunks = [TextArtifact(self.formatter_fn(row.cells)) for row in rows] if rows else [] if self.embedding_driver: for chunk in chunks: @@ -31,5 +34,5 @@ def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: return artifacts - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: + return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/mixins/media_artifact_file_output_mixin.py b/griptape/mixins/artifact_file_output_mixin.py similarity index 87% rename from griptape/mixins/media_artifact_file_output_mixin.py rename to griptape/mixins/artifact_file_output_mixin.py index 9b9f349115..25ed8718d7 100644 --- a/griptape/mixins/media_artifact_file_output_mixin.py +++ b/griptape/mixins/artifact_file_output_mixin.py @@ -7,11 +7,11 @@ from attrs import Attribute, define, field if TYPE_CHECKING: - from griptape.artifacts import BlobArtifact + from griptape.artifacts import BaseArtifact @define(slots=False) -class BlobArtifactFileOutputMixin: +class ArtifactFileOutputMixin: output_dir: Optional[str] = field(default=None, kw_only=True) output_file: Optional[str] = field(default=None, kw_only=True) @@ -31,7 +31,7 @@ def validate_output_file(self, _: Attribute, output_file: str) -> None: if self.output_dir: raise ValueError("Can't have both output_dir and output_file specified.") - def _write_to_file(self, artifact: BlobArtifact) -> None: + def _write_to_file(self, artifact: BaseArtifact) -> None: if self.output_file: outfile = self.output_file elif self.output_dir: @@ -42,4 +42,4 @@ def _write_to_file(self, artifact: BlobArtifact) -> None: if os.path.dirname(outfile): os.makedirs(os.path.dirname(outfile), exist_ok=True) - Path(outfile).write_bytes(artifact.value) + Path(outfile).write_bytes(artifact.to_bytes()) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 9290c60981..dde3ae49a8 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -2,7 +2,7 @@ from abc import ABC from collections.abc import Sequence -from typing import Any, Literal, Union, _SpecialForm, get_args, get_origin +from typing import Any, Literal, TypeVar, Union, _SpecialForm, get_args, get_origin import attrs from marshmallow import INCLUDE, Schema, fields @@ -56,6 +56,10 @@ def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested: field_class, args, optional = cls._get_field_type_info(field_type) + # Resolve TypeVars to their bound type + if isinstance(field_class, TypeVar): + field_class = field_class.__bound__ + if attrs.has(field_class): if ABC in field_class.__bases__: return fields.Nested(PolymorphicSchema(inner_class=field_class), allow_none=optional) diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index fae217d54e..91f7b7501a 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -6,7 +6,7 @@ from attrs import define from griptape.configs import Defaults -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask @@ -14,7 +14,7 @@ @define -class BaseAudioGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): +class BaseAudioGenerationTask(ArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): def before_run(self) -> None: super().before_run() diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index bd36d00803..326b2a5514 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -10,20 +10,19 @@ from griptape.configs import Defaults from griptape.loaders import ImageLoader -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.mixins.rule_mixin import RuleMixin from griptape.rules import Rule, Ruleset from griptape.tasks import BaseTask if TYPE_CHECKING: - from griptape.artifacts import MediaArtifact - + from griptape.artifacts import ImageArtifact logger = logging.getLogger(Defaults.logging_config.logger_name) @define -class BaseImageGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): +class BaseImageGenerationTask(ArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): """Provides a base class for image generation-related tasks. Attributes: @@ -65,6 +64,6 @@ def all_negative_rulesets(self) -> list[Ruleset]: return task_rulesets - def _read_from_file(self, path: str) -> MediaArtifact: + def _read_from_file(self, path: str) -> ImageArtifact: logger.info("Reading image from %s", os.path.abspath(path)) return ImageLoader().load(Path(path).read_bytes()) diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 2dcb796d83..7ae63b9027 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -84,7 +84,11 @@ def run(self) -> BaseArtifact: subtask.after_run() if isinstance(subtask.output, ListArtifact): - self.output = subtask.output[0] + first_artifact = subtask.output[0] + if isinstance(first_artifact, BaseArtifact): + self.output = first_artifact + else: + raise ValueError(f"Output is not an Artifact: {type(first_artifact)}") else: self.output = InfoArtifact("No tool output") except Exception as e: diff --git a/griptape/tools/base_image_generation_tool.py b/griptape/tools/base_image_generation_tool.py index ee1c37b6df..2df5d97479 100644 --- a/griptape/tools/base_image_generation_tool.py +++ b/griptape/tools/base_image_generation_tool.py @@ -1,11 +1,11 @@ from attrs import define -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.tools import BaseTool @define -class BaseImageGenerationTool(BlobArtifactFileOutputMixin, BaseTool): +class BaseImageGenerationTool(ArtifactFileOutputMixin, BaseTool): """A base class for tools that generate images from text prompts.""" PROMPT_DESCRIPTION = "Features and qualities to include in the generated image, descriptive and succinct." diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index 0089970e91..0274e7940d 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from schema import Literal, Or, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact from griptape.configs import Defaults from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( @@ -60,7 +60,7 @@ class QueryTool(BaseTool, RuleMixin): ), }, ) - def query(self, params: dict) -> BaseArtifact: + def query(self, params: dict) -> ListArtifact | ErrorArtifact: query = params["values"]["query"] content = params["values"]["content"] diff --git a/griptape/tools/text_to_speech/tool.py b/griptape/tools/text_to_speech/tool.py index ea49820296..aca259698f 100644 --- a/griptape/tools/text_to_speech/tool.py +++ b/griptape/tools/text_to_speech/tool.py @@ -5,7 +5,7 @@ from attrs import define, field from schema import Literal, Schema -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -15,7 +15,7 @@ @define -class TextToSpeechTool(BlobArtifactFileOutputMixin, BaseTool): +class TextToSpeechTool(ArtifactFileOutputMixin, BaseTool): """A tool that can be used to generate speech from input text. Attributes: diff --git a/tests/unit/artifacts/test_action_artifact.py b/tests/unit/artifacts/test_action_artifact.py index 2530ed8c36..b7180b1c3d 100644 --- a/tests/unit/artifacts/test_action_artifact.py +++ b/tests/unit/artifacts/test_action_artifact.py @@ -11,10 +11,6 @@ class TestActionArtifact: def action(self) -> ToolAction: return ToolAction(tag="TestTag", name="TestName", path="TestPath", input={"foo": "bar"}) - def test___add__(self, action): - with pytest.raises(NotImplementedError): - ActionArtifact(action) + ActionArtifact(action) - def test_to_text(self, action): assert ActionArtifact(action).to_text() == json.dumps(action.to_dict()) diff --git a/tests/unit/artifacts/test_audio_artifact.py b/tests/unit/artifacts/test_audio_artifact.py index 6d44c05b3e..aab6af6305 100644 --- a/tests/unit/artifacts/test_audio_artifact.py +++ b/tests/unit/artifacts/test_audio_artifact.py @@ -6,20 +6,22 @@ class TestAudioArtifact: @pytest.fixture() def audio_artifact(self): - return AudioArtifact(value=b"some binary audio data", format="pcm", model="provider/model", prompt="two words") + return AudioArtifact( + value=b"some binary audio data", format="pcm", meta={"model": "provider/model", "prompt": "two words"} + ) def test_mime_type(self, audio_artifact: AudioArtifact): assert audio_artifact.mime_type == "audio/pcm" def test_to_text(self, audio_artifact: AudioArtifact): - assert audio_artifact.to_text() == "Media, type: audio/pcm, size: 22 bytes" + assert audio_artifact.to_text() == "Audio, format: pcm, size: 22 bytes" def test_to_dict(self, audio_artifact: AudioArtifact): audio_dict = audio_artifact.to_dict() assert audio_dict["format"] == "pcm" - assert audio_dict["model"] == "provider/model" - assert audio_dict["prompt"] == "two words" + assert audio_dict["meta"]["model"] == "provider/model" + assert audio_dict["meta"]["prompt"] == "two words" assert audio_dict["value"] == "c29tZSBiaW5hcnkgYXVkaW8gZGF0YQ==" def test_deserialization(self, audio_artifact): @@ -31,5 +33,5 @@ def test_deserialization(self, audio_artifact): assert deserialized_artifact.value == b"some binary audio data" assert deserialized_artifact.mime_type == "audio/pcm" assert deserialized_artifact.format == "pcm" - assert deserialized_artifact.model == "provider/model" - assert deserialized_artifact.prompt == "two words" + assert deserialized_artifact.meta["model"] == "provider/model" + assert deserialized_artifact.meta["prompt"] == "two words" diff --git a/tests/unit/artifacts/test_base_artifact.py b/tests/unit/artifacts/test_base_artifact.py index 6cf8f4466f..28e2761a82 100644 --- a/tests/unit/artifacts/test_base_artifact.py +++ b/tests/unit/artifacts/test_base_artifact.py @@ -41,7 +41,7 @@ def test_list_artifact_from_dict(self): assert artifact.to_text() == "foobar" def test_blob_artifact_from_dict(self): - dict_value = {"type": "BlobArtifact", "value": b"Zm9vYmFy", "dir_name": "foo", "name": "bar"} + dict_value = {"type": "BlobArtifact", "value": b"Zm9vYmFy", "name": "bar"} artifact = BaseArtifact.from_dict(dict_value) assert isinstance(artifact, BlobArtifact) @@ -51,17 +51,15 @@ def test_image_artifact_from_dict(self): dict_value = { "type": "ImageArtifact", "value": b"aW1hZ2UgZGF0YQ==", - "dir_name": "foo", "format": "png", "width": 256, "height": 256, - "model": "test-model", - "prompt": "some prompt", + "meta": {"model": "test-model", "prompt": "some prompt"}, } artifact = BaseArtifact.from_dict(dict_value) assert isinstance(artifact, ImageArtifact) - assert artifact.to_text() == "Media, type: image/png, size: 10 bytes" + assert artifact.to_text() == "Image, format: png, size: 10 bytes" assert artifact.value == b"image data" def test_unsupported_from_dict(self): diff --git a/tests/unit/artifacts/test_base_media_artifact.py b/tests/unit/artifacts/test_base_media_artifact.py deleted file mode 100644 index c85d070fef..0000000000 --- a/tests/unit/artifacts/test_base_media_artifact.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -from attrs import define - -from griptape.artifacts import MediaArtifact - - -class TestMediaArtifact: - @define - class ImaginaryMediaArtifact(MediaArtifact): - media_type: str = "imagination" - - @pytest.fixture() - def media_artifact(self): - return self.ImaginaryMediaArtifact(value=b"some binary dream data", format="dream") - - def test_to_dict(self, media_artifact): - image_dict = media_artifact.to_dict() - - assert image_dict["format"] == "dream" - assert image_dict["value"] == "c29tZSBiaW5hcnkgZHJlYW0gZGF0YQ==" - - def test_name(self, media_artifact): - assert media_artifact.name.startswith("imagination_artifact") - assert media_artifact.name.endswith(".dream") - - def test_mime_type(self, media_artifact): - assert media_artifact.mime_type == "imagination/dream" - - def test_to_text(self, media_artifact): - assert media_artifact.to_text() == "Media, type: imagination/dream, size: 22 bytes" diff --git a/tests/unit/artifacts/test_blob_artifact.py b/tests/unit/artifacts/test_blob_artifact.py index 3d88d57934..9db5e21f5d 100644 --- a/tests/unit/artifacts/test_blob_artifact.py +++ b/tests/unit/artifacts/test_blob_artifact.py @@ -1,5 +1,4 @@ import base64 -import os import pytest @@ -13,6 +12,9 @@ def test_value_type_conversion(self): def test_to_text(self): assert BlobArtifact(b"foobar", name="foobar.txt").to_text() == "foobar" + def test_to_bytes(self): + assert BlobArtifact(b"foo").to_bytes() == b"foo" + def test_to_text_encoding(self): assert ( BlobArtifact("ß".encode("ascii", errors="backslashreplace"), name="foobar.txt", encoding="ascii").to_text() @@ -30,37 +32,34 @@ def test_to_text_encoding_error_handler(self): ) def test_to_dict(self): - assert BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo").to_dict()["name"] == "foobar.txt" - - def test_full_path_with_path(self): - assert BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo").full_path == os.path.normpath( - "foo/foobar.txt" - ) - - def test_full_path_without_path(self): - assert BlobArtifact(b"foobar", name="foobar.txt").full_path == "foobar.txt" + assert BlobArtifact(b"foobar", name="foobar.txt").to_dict()["name"] == "foobar.txt" def test_serialization(self): - artifact = BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo") + artifact = BlobArtifact(b"foobar", name="foobar.txt") artifact_dict = artifact.to_dict() assert artifact_dict["name"] == "foobar.txt" - assert artifact_dict["dir_name"] == "foo" assert base64.b64decode(artifact_dict["value"]) == b"foobar" def test_deserialization(self): - artifact = BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo") + artifact = BlobArtifact(b"foobar", name="foobar.txt") artifact_dict = artifact.to_dict() deserialized_artifact = BaseArtifact.from_dict(artifact_dict) assert isinstance(deserialized_artifact, BlobArtifact) assert deserialized_artifact.name == "foobar.txt" - assert deserialized_artifact.dir_name == "foo" assert deserialized_artifact.value == b"foobar" def test_name(self): assert BlobArtifact(b"foo", name="bar").name == "bar" + def test_mime_type(self): + assert BlobArtifact(b"foo").mime_type == "application/octet-stream" + + def test___add__(self): + with pytest.raises(TypeError): + BlobArtifact(b"foo") + BlobArtifact(b"bar") + def test___bool__(self): assert not bool(BlobArtifact(b"")) assert bool(BlobArtifact(b"foo")) diff --git a/tests/unit/artifacts/test_boolean_artifact.py b/tests/unit/artifacts/test_boolean_artifact.py index 57bbf16622..6ed21608d4 100644 --- a/tests/unit/artifacts/test_boolean_artifact.py +++ b/tests/unit/artifacts/test_boolean_artifact.py @@ -10,6 +10,7 @@ def test_parse_bool(self): assert BooleanArtifact.parse_bool("false").value is False assert BooleanArtifact.parse_bool("True").value is True assert BooleanArtifact.parse_bool("False").value is False + assert BooleanArtifact.parse_bool(True).value is True with pytest.raises(ValueError): BooleanArtifact.parse_bool("foo") @@ -35,3 +36,13 @@ def test_value_type_conversion(self): assert BooleanArtifact([]).value is False assert BooleanArtifact(False).value is False assert BooleanArtifact(True).value is True + + def test_to_text(self): + assert BooleanArtifact(True).to_text() == "true" + assert BooleanArtifact(False).to_text() == "false" + + def test__eq__(self): + assert BooleanArtifact(True) == BooleanArtifact(True) + assert BooleanArtifact(False) == BooleanArtifact(False) + assert BooleanArtifact(True) != BooleanArtifact(False) + assert BooleanArtifact(False) != BooleanArtifact(True) diff --git a/tests/unit/artifacts/test_csv_row_artifact.py b/tests/unit/artifacts/test_csv_row_artifact.py deleted file mode 100644 index fe0b8cd64d..0000000000 --- a/tests/unit/artifacts/test_csv_row_artifact.py +++ /dev/null @@ -1,30 +0,0 @@ -from griptape.artifacts import CsvRowArtifact - - -class TestCsvRowArtifact: - def test_value_type_conversion(self): - assert CsvRowArtifact({"foo": "bar"}).value == {"foo": "bar"} - assert CsvRowArtifact({"foo": {"bar": "baz"}}).value == {"foo": {"bar": "baz"}} - assert CsvRowArtifact('{"foo": "bar"}').value == {"foo": "bar"} - - def test___add__(self): - assert (CsvRowArtifact({"test1": "foo"}) + CsvRowArtifact({"test2": "bar"})).value == { - "test1": "foo", - "test2": "bar", - } - - def test_to_text(self): - assert CsvRowArtifact({"test1": "foo|bar", "test2": 1}, delimiter="|").to_text() == 'test1|test2\r\n"foo|bar"|1' - - def test_to_dict(self): - assert CsvRowArtifact({"test1": "foo"}).to_dict()["value"] == {"test1": "foo"} - - def test_name(self): - artifact = CsvRowArtifact({}) - - assert artifact.name == artifact.id - assert CsvRowArtifact({}, name="bar").name == "bar" - - def test___bool__(self): - assert not bool(CsvRowArtifact({})) - assert bool(CsvRowArtifact({"foo": "bar"})) diff --git a/tests/unit/artifacts/test_image_artifact.py b/tests/unit/artifacts/test_image_artifact.py index a722ebd911..a632953ae8 100644 --- a/tests/unit/artifacts/test_image_artifact.py +++ b/tests/unit/artifacts/test_image_artifact.py @@ -11,12 +11,11 @@ def image_artifact(self): format="png", width=512, height=512, - model="openai/dalle2", - prompt="a cute cat", + meta={"model": "openai/dalle2", "prompt": "a cute cat"}, ) def test_to_text(self, image_artifact: ImageArtifact): - assert image_artifact.to_text() == "Media, type: image/png, size: 26 bytes" + assert image_artifact.to_text() == "Image, format: png, size: 26 bytes" def test_to_dict(self, image_artifact: ImageArtifact): image_dict = image_artifact.to_dict() @@ -24,8 +23,8 @@ def test_to_dict(self, image_artifact: ImageArtifact): assert image_dict["format"] == "png" assert image_dict["width"] == 512 assert image_dict["height"] == 512 - assert image_dict["model"] == "openai/dalle2" - assert image_dict["prompt"] == "a cute cat" + assert image_dict["meta"]["model"] == "openai/dalle2" + assert image_dict["meta"]["prompt"] == "a cute cat" assert image_dict["value"] == "c29tZSBiaW5hcnkgcG5nIGltYWdlIGRhdGE=" def test_deserialization(self, image_artifact): @@ -39,5 +38,5 @@ def test_deserialization(self, image_artifact): assert deserialized_artifact.format == "png" assert deserialized_artifact.width == 512 assert deserialized_artifact.height == 512 - assert deserialized_artifact.model == "openai/dalle2" - assert deserialized_artifact.prompt == "a cute cat" + assert deserialized_artifact.meta["model"] == "openai/dalle2" + assert deserialized_artifact.meta["prompt"] == "a cute cat" diff --git a/tests/unit/artifacts/test_json_artifact.py b/tests/unit/artifacts/test_json_artifact.py index 06f5d6297d..be61e3edfe 100644 --- a/tests/unit/artifacts/test_json_artifact.py +++ b/tests/unit/artifacts/test_json_artifact.py @@ -1,8 +1,6 @@ import json -import pytest - -from griptape.artifacts import JsonArtifact, TextArtifact +from griptape.artifacts import JsonArtifact class TestJsonArtifact: @@ -14,11 +12,11 @@ def test_value_type_conversion(self): assert JsonArtifact({"foo": None}).value == json.loads(json.dumps({"foo": None})) assert JsonArtifact([{"foo": {"bar": "baz"}}]).value == json.loads(json.dumps([{"foo": {"bar": "baz"}}])) assert JsonArtifact(None).value == json.loads(json.dumps(None)) - assert JsonArtifact("foo").value == json.loads(json.dumps("foo")) - - def test___add__(self): - with pytest.raises(NotImplementedError): - JsonArtifact({"foo": "bar"}) + TextArtifact("invalid json") + assert JsonArtifact('"foo"').value == "foo" + assert JsonArtifact("true").value is True + assert JsonArtifact("false").value is False + assert JsonArtifact("123").value == 123 + assert JsonArtifact("123.4").value == 123.4 def test_to_text(self): assert JsonArtifact({"foo": "bar"}).to_text() == json.dumps({"foo": "bar"}) diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index 06d2346458..0d6faaa7b4 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -1,6 +1,7 @@ import pytest -from griptape.artifacts import BlobArtifact, CsvRowArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, ListArtifact, TextArtifact +from griptape.artifacts.image_artifact import ImageArtifact class TestListArtifact: @@ -23,6 +24,12 @@ def test___add__(self): assert artifact.value[0].value == "foo" assert artifact.value[1].value == "bar" + def test___iter__(self): + assert [a.value for a in ListArtifact([TextArtifact("foo"), TextArtifact("bar")])] == ["foo", "bar"] + + def test_type_var(self): + assert ListArtifact[TextArtifact]([TextArtifact("foo")]).value[0].value == "foo" + def test_validate_value(self): with pytest.raises(ValueError): ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")], validate_uniform_types=True) @@ -32,8 +39,7 @@ def test_child_type(self): def test_is_type(self): assert ListArtifact([TextArtifact("foo")]).is_type(TextArtifact) - assert ListArtifact([CsvRowArtifact({"foo": "bar"})]).is_type(TextArtifact) - assert ListArtifact([CsvRowArtifact({"foo": "bar"})]).is_type(CsvRowArtifact) + assert ListArtifact([ImageArtifact(b"", width=1234, height=1234, format="png")]).is_type(ImageArtifact) def test_has_items(self): assert not ListArtifact().has_items() diff --git a/tests/unit/artifacts/test_text_artifact.py b/tests/unit/artifacts/test_text_artifact.py index 9e00e2d2ed..dda256d27a 100644 --- a/tests/unit/artifacts/test_text_artifact.py +++ b/tests/unit/artifacts/test_text_artifact.py @@ -18,6 +18,11 @@ def test___add__(self): def test_to_dict(self): assert TextArtifact("foobar").to_dict()["value"] == "foobar" + def test_to_bytes(self): + artifact = TextArtifact("foobar") + + assert artifact.to_bytes() == b"foobar" + def test_from_dict(self): assert BaseArtifact.from_dict(TextArtifact("foobar").to_dict()).value == "foobar" diff --git a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py index 9aa4d3f4f2..05e669b661 100644 --- a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py @@ -60,5 +60,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "stability.stable-diffusion-xl-v1" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "stability.stable-diffusion-xl-v1" + assert image_artifact.meta["prompt"] == "test prompt" diff --git a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py index 268708b2b7..a727642111 100644 --- a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py @@ -28,7 +28,10 @@ def test_init(self, driver): def test_init_requires_endpoint(self): with pytest.raises(TypeError): AzureOpenAiImageGenerationDriver( - model="dall-e-3", client=Mock(), azure_deployment="dalle-deployment", image_size="512x512" + model="dall-e-3", + client=Mock(), + azure_deployment="dalle-deployment", + image_size="512x512", ) # pyright: ignore[reportCallIssues] def test_try_text_to_image(self, driver): @@ -40,5 +43,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-3" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-3" + assert image_artifact.meta["prompt"] == "test prompt" diff --git a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py index 48805cde63..ec70e2dd2f 100644 --- a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py @@ -76,5 +76,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "test_model_id" - assert image_artifact.prompt == "test_prompt" + assert image_artifact.meta["model"] == "test_model_id" + assert image_artifact.meta["prompt"] == "test_prompt" diff --git a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py index 16bcd28701..ff5528fb62 100644 --- a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py @@ -22,8 +22,8 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-2" + assert image_artifact.meta["prompt"] == "test prompt" def test_try_image_variation(self, driver): driver.client.images.create_variation.return_value = Mock(data=[Mock(b64_json=b"aW1hZ2UgZGF0YQ==")]) @@ -34,7 +34,7 @@ def test_try_image_variation(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" + assert image_artifact.meta["model"] == "dall-e-2" def test_try_image_variation_invalid_size(self, driver): driver.image_size = "1024x1792" @@ -59,8 +59,8 @@ def test_try_image_inpainting(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-2" + assert image_artifact.meta["prompt"] == "test prompt" def test_try_image_inpainting_invalid_size(self, driver): driver.image_size = "1024x1792" diff --git a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py index ac4ff80437..20a3e2b50c 100644 --- a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py @@ -4,7 +4,6 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.artifacts.csv_row_artifact import CsvRowArtifact class BaseLocalVectorStoreDriver(ABC): @@ -26,20 +25,6 @@ def test_upsert(self, driver): assert len(driver.entries) == 2 - def test_upsert_csv_row(self, driver): - namespace = driver.upsert_text_artifact(CsvRowArtifact(id="foo1", value={"col": "value"})) - - assert len(driver.entries) == 1 - assert list(driver.entries.keys())[0] == namespace - - driver.upsert_text_artifact(CsvRowArtifact(id="foo1", value={"col": "value"})) - - assert len(driver.entries) == 1 - - driver.upsert_text_artifact(CsvRowArtifact(id="foo2", value={"col": "value2"})) - - assert len(driver.entries) == 2 - def test_upsert_multiple(self, driver): driver.upsert_text_artifacts({"foo": [TextArtifact("foo")], "bar": [TextArtifact("bar")]}) diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index 893c21d60f..056df2d5aa 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -12,11 +12,11 @@ def test_extract(self, engine): result = engine.extract("foo") assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} + assert result.value[0].value == "test1: mock output" def test_text_to_csv_rows(self, engine): result = engine.text_to_csv_rows("foo,bar\nbaz,maz", ["test1", "test2"]) assert len(result) == 2 - assert result[0].value == {"test1": "foo", "test2": "bar"} - assert result[1].value == {"test1": "baz", "test2": "maz"} + assert result[0].value == "test1: foo\ntest2: bar" + assert result[1].value == "test1: baz\ntest2: maz" diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index 473fd0d9e2..b7ebdd9120 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -13,14 +13,13 @@ def loader(self): def create_source(self, bytes_from_resource_path): return bytes_from_resource_path - @pytest.mark.parametrize(("resource_path", "suffix", "mime_type"), [("sentences.wav", ".wav", "audio/wav")]) - def test_load(self, resource_path, suffix, mime_type, loader, create_source): + @pytest.mark.parametrize(("resource_path", "mime_type"), [("sentences.wav", "audio/wav")]) + def test_load(self, resource_path, mime_type, loader, create_source): source = create_source(resource_path) artifact = loader.load(source) assert isinstance(artifact, AudioArtifact) - assert artifact.name.endswith(suffix) assert artifact.mime_type == mime_type assert len(artifact.value) > 0 @@ -35,6 +34,5 @@ def test_load_collection(self, create_source, loader): for key in collection: artifact = collection[key] assert isinstance(artifact, AudioArtifact) - assert artifact.name.endswith(".wav") assert artifact.mime_type == "audio/wav" assert len(artifact.value) > 0 diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index a747afff71..7af409152e 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,3 +1,5 @@ +import json + import pytest from griptape.loaders.csv_loader import CsvLoader @@ -28,8 +30,7 @@ def test_load(self, loader, create_source): assert len(artifacts) == 10 first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "foo1" - assert first_artifact.value["Bar"] == "bar1" + assert first_artifact.value == "Foo: foo1\nBar: bar1" assert first_artifact.embedding == [0, 1] def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): @@ -39,8 +40,7 @@ def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): assert len(artifacts) == 10 first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "bar1" - assert first_artifact.value["Bar"] == "foo1" + assert first_artifact.value == "Bar: foo1\nFoo: bar1" assert first_artifact.embedding == [0, 1] def test_load_collection(self, loader, create_source): @@ -52,10 +52,17 @@ def test_load_collection(self, loader, create_source): keys = {loader.to_key(source) for source in sources} assert collection.keys() == keys - for key in keys: - artifacts = collection[key] - assert len(artifacts) == 10 - first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "foo1" - assert first_artifact.value["Bar"] == "bar1" - assert first_artifact.embedding == [0, 1] + assert collection[loader.to_key(sources[0])][0].value == "Foo: foo1\nBar: bar1" + assert collection[loader.to_key(sources[0])][0].embedding == [0, 1] + + assert collection[loader.to_key(sources[1])][0].value == "Bar: bar1\nFoo: foo1" + assert collection[loader.to_key(sources[1])][0].embedding == [0, 1] + + def test_formatter_fn(self, loader, create_source): + loader.formatter_fn = lambda value: json.dumps(value) + source = create_source("test-1.csv") + + artifacts = loader.load(source) + + assert len(artifacts) == 10 + assert artifacts[0].value == '{"Foo": "foo1", "Bar": "bar1"}' diff --git a/tests/unit/loaders/test_dataframe_loader.py b/tests/unit/loaders/test_dataframe_loader.py deleted file mode 100644 index 5c2a57ed6a..0000000000 --- a/tests/unit/loaders/test_dataframe_loader.py +++ /dev/null @@ -1,52 +0,0 @@ -import os - -import pandas as pd -import pytest - -from griptape.loaders.dataframe_loader import DataFrameLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - - -class TestDataFrameLoader: - @pytest.fixture() - def loader(self): - return DataFrameLoader(embedding_driver=MockEmbeddingDriver()) - - def test_load_with_path(self, loader): - # test loading a file delimited by comma - path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-1.csv") - - artifacts = loader.load(pd.read_csv(path)) - - assert len(artifacts) == 10 - first_artifact = artifacts[0].value - assert first_artifact["Foo"] == "foo1" - assert first_artifact["Bar"] == "bar1" - - assert artifacts[0].embedding == [0, 1] - - def test_load_collection_with_path(self, loader): - path1 = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-1.csv") - path2 = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-2.csv") - df1 = pd.read_csv(path1) - df2 = pd.read_csv(path2) - collection = loader.load_collection([df1, df2]) - - key1 = loader.to_key(df1) - key2 = loader.to_key(df2) - - assert list(collection.keys()) == [key1, key2] - - artifacts = collection[key1] - assert len(artifacts) == 10 - first_artifact = artifacts[0].value - assert first_artifact["Foo"] == "foo1" - assert first_artifact["Bar"] == "bar1" - - artifacts = collection[key2] - assert len(artifacts) == 10 - first_artifact = artifacts[0].value - assert first_artifact["Bar"] == "bar1" - assert first_artifact["Foo"] == "foo1" - - assert artifacts[0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index eca4cbccc5..7093894b00 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -18,23 +18,22 @@ def create_source(self, bytes_from_resource_path): return bytes_from_resource_path @pytest.mark.parametrize( - ("resource_path", "suffix", "mime_type"), + ("resource_path", "mime_type"), [ - ("small.png", ".png", "image/png"), - ("small.jpg", ".jpeg", "image/jpeg"), - ("small.webp", ".webp", "image/webp"), - ("small.bmp", ".bmp", "image/bmp"), - ("small.gif", ".gif", "image/gif"), - ("small.tiff", ".tiff", "image/tiff"), + ("small.png", "image/png"), + ("small.jpg", "image/jpeg"), + ("small.webp", "image/webp"), + ("small.bmp", "image/bmp"), + ("small.gif", "image/gif"), + ("small.tiff", "image/tiff"), ], ) - def test_load(self, resource_path, suffix, mime_type, loader, create_source): + def test_load(self, resource_path, mime_type, loader, create_source): source = create_source(resource_path) artifact = loader.load(source) assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(suffix) assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == mime_type @@ -49,7 +48,6 @@ def test_load_normalize(self, resource_path, png_loader, create_source): artifact = png_loader.load(source) assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(".png") assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == "image/png" @@ -68,7 +66,6 @@ def test_load_collection(self, create_source, png_loader): for key in keys: artifact = collection[key] assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(".png") assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == "image/png" diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index fbfa6d4fa9..2ff6c7fafd 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -38,24 +38,21 @@ def test_load(self, loader): artifacts = loader.load("SELECT * FROM test_table;") assert len(artifacts) == 3 - assert artifacts[0].value == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} - assert artifacts[1].value == {"id": 2, "name": "Bob", "age": 30, "city": "Los Angeles"} - assert artifacts[2].value == {"id": 3, "name": "Charlie", "age": 22, "city": "Chicago"} + assert artifacts[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" + assert artifacts[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" + assert artifacts[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" assert artifacts[0].embedding == [0, 1] def test_load_collection(self, loader): - artifacts = loader.load_collection(["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"]) + sources = ["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"] + artifacts = loader.load_collection(sources) assert list(artifacts.keys()) == [ loader.to_key("SELECT * FROM test_table LIMIT 1;"), loader.to_key("SELECT * FROM test_table LIMIT 2;"), ] - assert [a.value for artifact_list in artifacts.values() for a in artifact_list] == [ - {"age": 25, "city": "New York", "id": 1, "name": "Alice"}, - {"age": 25, "city": "New York", "id": 1, "name": "Alice"}, - {"age": 30, "city": "Los Angeles", "id": 2, "name": "Bob"}, - ] - + assert artifacts[loader.to_key(sources[0])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" + assert artifacts[loader.to_key(sources[1])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" assert list(artifacts.values())[0][0].embedding == [0, 1] diff --git a/tests/unit/memory/tool/test_task_memory.py b/tests/unit/memory/tool/test_task_memory.py index 2f6ffe1c91..d2575959a2 100644 --- a/tests/unit/memory/tool/test_task_memory.py +++ b/tests/unit/memory/tool/test_task_memory.py @@ -1,6 +1,6 @@ import pytest -from griptape.artifacts import BlobArtifact, CsvRowArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.memory import TaskMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from griptape.structures import Agent @@ -10,10 +10,6 @@ class TestTaskMemory: - @pytest.fixture(autouse=True) - def _mock_griptape(self, mocker): - mocker.patch("griptape.engines.CsvExtractionEngine.extract", return_value=[CsvRowArtifact({"foo": "bar"})]) - @pytest.fixture() def memory(self): return defaults.text_task_memory("MyMemory") diff --git a/tests/unit/mixins/test_image_artifact_file_output_mixin.py b/tests/unit/mixins/test_image_artifact_file_output_mixin.py index cf124da395..7e2926e093 100644 --- a/tests/unit/mixins/test_image_artifact_file_output_mixin.py +++ b/tests/unit/mixins/test_image_artifact_file_output_mixin.py @@ -4,12 +4,12 @@ import pytest from griptape.artifacts import ImageArtifact -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin -class TestMediaArtifactFileOutputMixin: +class TestArtifactFileOutputMixin: def test_no_output(self): - class Test(BlobArtifactFileOutputMixin): + class Test(ArtifactFileOutputMixin): pass assert Test().output_file is None @@ -18,7 +18,7 @@ class Test(BlobArtifactFileOutputMixin): def test_output_file(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") - class Test(BlobArtifactFileOutputMixin): + class Test(ArtifactFileOutputMixin): def run(self) -> None: self._write_to_file(artifact) @@ -33,7 +33,7 @@ def run(self) -> None: def test_output_dir(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") - class Test(BlobArtifactFileOutputMixin): + class Test(ArtifactFileOutputMixin): def run(self) -> None: self._write_to_file(artifact) @@ -46,7 +46,7 @@ def run(self) -> None: assert os.path.exists(os.path.join(outdir, artifact.name)) def test_output_file_and_dir(self): - class Test(BlobArtifactFileOutputMixin): + class Test(ArtifactFileOutputMixin): pass outfile = "test.txt" diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index 2d7ab442c8..06d444f9b4 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -18,4 +18,4 @@ def test_run(self, task): result = task.run() assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} + assert result.value[0].value == "test1: mock output" diff --git a/tests/unit/tools/test_extraction_tool.py b/tests/unit/tools/test_extraction_tool.py index 1219da373b..3f783e8a49 100644 --- a/tests/unit/tools/test_extraction_tool.py +++ b/tests/unit/tools/test_extraction_tool.py @@ -58,10 +58,10 @@ def test_csv_extract_artifacts(self, csv_tool): ) assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} + assert result.value[0].value == "test1: mock output" def test_csv_extract_content(self, csv_tool): result = csv_tool.extract({"values": {"data": "foo"}}) assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} + assert result.value[0].value == "test1: mock output" diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 569c0a2800..469918a025 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -5,7 +5,7 @@ import pytest -from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader from griptape.tools import FileManagerTool @@ -106,29 +106,6 @@ def test_save_memory_artifacts_to_disk_for_multiple_artifacts(self, temp_dir): assert Path(os.path.join(temp_dir, "test", f"{artifacts[1].name}-{file_name}")).read_text() == "baz" assert result.value == "Successfully saved memory artifacts to disk" - def test_save_memory_artifacts_to_disk_for_non_string_artifact(self, temp_dir): - memory = defaults.text_task_memory("Memory1") - artifact = CsvRowArtifact({"foo": "bar"}) - - memory.store_artifact("foobar", artifact) - - file_manager = FileManagerTool( - input_memory=[memory], file_manager_driver=LocalFileManagerDriver(workdir=temp_dir) - ) - result = file_manager.save_memory_artifacts_to_disk( - { - "values": { - "dir_name": "test", - "file_name": "foobar.txt", - "memory_name": memory.name, - "artifact_namespace": "foobar", - } - } - ) - - assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foo\nbar" - assert result.value == "Successfully saved memory artifacts to disk" - def test_save_content_to_file(self, temp_dir): file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(workdir=temp_dir)) result = file_manager.save_content_to_file( diff --git a/tests/unit/tools/test_inpainting_image_generation_tool.py b/tests/unit/tools/test_inpainting_image_generation_tool.py index 45afcbc63a..a558921a94 100644 --- a/tests/unit/tools/test_inpainting_image_generation_tool.py +++ b/tests/unit/tools/test_inpainting_image_generation_tool.py @@ -59,8 +59,8 @@ def test_image_inpainting_with_outfile( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_inpainting_from_file( @@ -83,8 +83,8 @@ def test_image_inpainting_from_memory(self, image_generation_engine, image_artif memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_inpainting_from_memory( diff --git a/tests/unit/tools/test_outpainting_image_variation_tool.py b/tests/unit/tools/test_outpainting_image_variation_tool.py index 4fbcbe8d49..e3f0de847b 100644 --- a/tests/unit/tools/test_outpainting_image_variation_tool.py +++ b/tests/unit/tools/test_outpainting_image_variation_tool.py @@ -34,8 +34,8 @@ def test_validate_output_configs(self, image_generation_engine) -> None: OutpaintingImageGenerationTool(engine=image_generation_engine, output_dir="test", output_file="test") def test_image_outpainting(self, image_generator, path_from_resource_path) -> None: - image_generator.engine.run.return_value = Mock( - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_outpainting_from_file( @@ -59,8 +59,8 @@ def test_image_outpainting_with_outfile( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_outpainting_from_file( diff --git a/tests/unit/tools/test_prompt_image_generation_tool.py b/tests/unit/tools/test_prompt_image_generation_tool.py index a0c5c7037e..4252d887e9 100644 --- a/tests/unit/tools/test_prompt_image_generation_tool.py +++ b/tests/unit/tools/test_prompt_image_generation_tool.py @@ -5,6 +5,7 @@ import pytest +from griptape.artifacts.image_artifact import ImageArtifact from griptape.tools import PromptImageGenerationTool @@ -36,8 +37,8 @@ def test_generate_image_with_outfile(self, image_generation_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" image_generator = PromptImageGenerationTool(engine=image_generation_engine, output_file=outfile) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.generate_image( diff --git a/tests/unit/tools/test_sql_tool.py b/tests/unit/tools/test_sql_tool.py index 2ef50ff549..061b31f4dc 100644 --- a/tests/unit/tools/test_sql_tool.py +++ b/tests/unit/tools/test_sql_tool.py @@ -26,7 +26,7 @@ def test_execute_query(self, driver): result = client.execute_query({"values": {"sql_query": "SELECT * from test_table;"}}) assert len(result.value) == 1 - assert result.value[0].value == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} + assert result.value[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" def test_execute_query_description(self, driver): client = SqlTool( diff --git a/tests/unit/tools/test_text_to_speech_tool.py b/tests/unit/tools/test_text_to_speech_tool.py index 8821d48fc3..6f2c43bd39 100644 --- a/tests/unit/tools/test_text_to_speech_tool.py +++ b/tests/unit/tools/test_text_to_speech_tool.py @@ -5,6 +5,7 @@ import pytest +from griptape.artifacts.audio_artifact import AudioArtifact from griptape.tools.text_to_speech.tool import TextToSpeechTool @@ -32,7 +33,7 @@ def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" text_to_speech_client = TextToSpeechTool(engine=text_to_speech_engine, output_file=outfile) - text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] + text_to_speech_client.engine.run.return_value = AudioArtifact(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) diff --git a/tests/unit/tools/test_variation_image_generation_tool.py b/tests/unit/tools/test_variation_image_generation_tool.py index c4528a044e..5fd3513c1c 100644 --- a/tests/unit/tools/test_variation_image_generation_tool.py +++ b/tests/unit/tools/test_variation_image_generation_tool.py @@ -58,8 +58,8 @@ def test_image_variation_with_outfile(self, image_generation_engine, image_loade engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_variation_from_file(