From 6093f8e451fe2d1fbe25b7dbfe089a27c2a3af2f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 6 Sep 2024 10:02:52 -0700 Subject: [PATCH] Bring back CsvRowArtifact --- CHANGELOG.md | 4 +-- MIGRATION.md | 32 +++++-------------- griptape/artifacts/__init__.py | 2 ++ griptape/artifacts/csv_row_artifact.py | 28 ++++++++++++++++ griptape/artifacts/text_artifact.py | 11 ++----- .../extraction/csv_extraction_engine.py | 4 +-- griptape/loaders/csv_loader.py | 10 +++--- griptape/loaders/dataframe_loader.py | 10 +++--- griptape/loaders/sql_loader.py | 10 +++--- tests/unit/artifacts/test_csv_row_artifact.py | 27 ++++++++++++++++ tests/unit/tasks/test_prompt_task.py | 4 +-- 11 files changed, 88 insertions(+), 54 deletions(-) create mode 100644 griptape/artifacts/csv_row_artifact.py create mode 100644 tests/unit/artifacts/test_csv_row_artifact.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 685c3341db..f648de1245 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,14 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseArtifact.to_bytes()` method to convert an Artifact to bytes. ### Changed +- **BREAKING**: Changed `CsvRowArtifact.value` from `dict` to `str`. - **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. -- **BREAKING**: Removed `CsvRowArtifact`. - **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. - **BREAKING**: Removed `ImageArtifact.media_type`. - **BREAKING**: Removed `AudioArtifact.media_type`. - **BREAKING**: Removed `BlobArtifact.dir_name`. - **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. - **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. +- **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. - Updated `JsonArtifact` value converter to properly handle more types. 
- `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. @@ -42,7 +43,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. - `CsvRowArtifact.to_text()` now includes the header. -- `BaseConversationMemory.prompt_driver` for use with autopruning. ### Fixed - Parsing streaming response with some OpenAI compatible services. diff --git a/MIGRATION.md b/MIGRATION.md index ddf9b9c003..65584297bf 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -36,39 +36,23 @@ audio_artifact = AudioArtifact( ) ``` -### Removed `CsvRowArtifact` +### Changed `CsvRowArtifact.value` from `dict` to `str`. -`CsvRowArtifact` has been removed. Use `TextArtifact` instead. +`CsvRowArtifact`'s `value` is now a `str` instead of a `dict`. Update any logic that expects `dict` to handle `str` instead. #### Before ```python -CsvRowArtifact({"name": "John", "age": 30}) +artifact = CsvRowArtifact({"name": "John", "age": 30}) +print(artifact.value) # {"name": "John", "age": 30} +print(type(artifact.value)) # <class 'dict'> ``` #### After ```python -TextArtifact("name: John\nAge: 30") -``` - -### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types - -`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a tuple of `list[TextArtifact]` instead of `list[CsvRowArtifact]`. 
- -#### Before - -```python -results = CsvLoader().load(Path("people.csv").read_text()) - -print(results[0].value) # {"name": "John", "age": 30} -``` - -#### After -```python -results = CsvLoader().load(Path("people.csv").read_text()) - -print(results[0].value) # name: John\nAge: 30 -print(results[0].meta["row"]) # 0 +artifact = CsvRowArtifact({"name": "John", "age": 30}) +print(artifact.value) # name: John\nage: 30 +print(type(artifact.value)) # <class 'str'> ``` ### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta` diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 0e58a8a764..c0e5f767e2 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -5,6 +5,7 @@ from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact +from .csv_row_artifact import CsvRowArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact @@ -20,6 +21,7 @@ "JsonArtifact", "BlobArtifact", "BooleanArtifact", + "CsvRowArtifact", "ListArtifact", "ImageArtifact", "AudioArtifact", diff --git a/griptape/artifacts/csv_row_artifact.py b/griptape/artifacts/csv_row_artifact.py new file mode 100644 index 0000000000..b1e2e7dfe8 --- /dev/null +++ b/griptape/artifacts/csv_row_artifact.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any + +from attrs import define, field + +from griptape.artifacts import BaseArtifact, TextArtifact + + +def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return "\n".join(f"{key}: {val}" for key, val in value.items()) + else: + return str(value) + + +@define +class CsvRowArtifact(TextArtifact): + """Stores a row of a CSV file. + + Attributes: + value: The row of the CSV file. If a dictionary is passed, the keys and values converted to a string. 
+ """ + + value: str = field(converter=value_to_str, metadata={"serializable": True}) + + def __add__(self, other: BaseArtifact) -> TextArtifact: + return TextArtifact(self.value + "\n" + other.value) diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 31d20e5320..7a3b62e3b8 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from attrs import define, field @@ -11,16 +11,9 @@ from griptape.tokenizers import BaseTokenizer -def value_to_str(value: Any) -> str: - if isinstance(value, dict): - return "\n".join(f"{key}: {val}" for key, val in value.items()) - else: - return str(value) - - @define class TextArtifact(BaseArtifact): - value: str = field(converter=value_to_str, metadata={"serializable": True}) + value: str = field(converter=str, metadata={"serializable": True}) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) embedding: Optional[list[float]] = field(default=None, kw_only=True) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 21d7820ac1..f5f2a3e59f 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field -from griptape.artifacts import ListArtifact, TextArtifact +from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 @@ -41,7 +41,7 @@ def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtif with io.StringIO(text) as f: for row in csv.reader(f): - rows.append(TextArtifact(dict(zip(column_names, 
[x.strip() for x in row])))) + rows.append(CsvRowArtifact(dict(zip(column_names, [x.strip() for x in row])))) return rows diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 20ac237c56..52fc982dcd 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -6,7 +6,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -19,7 +19,7 @@ class CsvLoader(BaseLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: artifacts = [] if isinstance(source, bytes): @@ -28,7 +28,7 @@ def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)] + chunks = [CsvRowArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)] if self.embedding_driver: for chunk in chunks: @@ -44,8 +44,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[TextArtifact]]: + ) -> dict[str, list[CsvRowArtifact]]: return cast( - dict[str, list[TextArtifact]], + dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 5ecb35ecd5..0b1ae14484 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import 
BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -19,10 +19,10 @@ class DataFrameLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: artifacts = [] - chunks = [TextArtifact(row) for row in source.to_dict(orient="records")] + chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")] if self.embedding_driver: for chunk in chunks: @@ -33,8 +33,8 @@ def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: return artifacts - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: + return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 14320911eb..542fd5d5b6 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts.text_artifact import TextArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -16,11 +16,11 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: str, *args, **kwargs) -> 
list[CsvRowArtifact]: rows = self.sql_driver.execute_query(source) artifacts = [] - chunks = [TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(rows)] if rows else [] + chunks = [CsvRowArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(rows)] if rows else [] if self.embedding_driver: for chunk in chunks: @@ -31,5 +31,5 @@ def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: return artifacts - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: + return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) diff --git a/tests/unit/artifacts/test_csv_row_artifact.py b/tests/unit/artifacts/test_csv_row_artifact.py new file mode 100644 index 0000000000..4591285259 --- /dev/null +++ b/tests/unit/artifacts/test_csv_row_artifact.py @@ -0,0 +1,27 @@ +from griptape.artifacts import CsvRowArtifact + + +class TestCsvRowArtifact: + def test_value_type_conversion(self): + assert CsvRowArtifact({"foo": "bar"}).value == "foo: bar" + assert CsvRowArtifact({"foo": {"bar": "baz"}}).value == "foo: {'bar': 'baz'}" + assert CsvRowArtifact('{"foo": "bar"}').value == '{"foo": "bar"}' + + def test___add__(self): + assert (CsvRowArtifact({"test1": "foo"}) + CsvRowArtifact({"test2": "bar"})).value == "test1: foo\ntest2: bar" + + def test_to_text(self): + assert CsvRowArtifact({"test1": "foo|bar", "test2": 1}).to_text() == "test1: foo|bar\ntest2: 1" + + def test_to_dict(self): + assert CsvRowArtifact({"test1": "foo"}).to_dict()["value"] == "test1: foo" + + def test_name(self): + artifact = CsvRowArtifact({}) + + assert artifact.name == artifact.id + assert CsvRowArtifact({}, name="bar").name == "bar" + + def test___bool__(self): + assert not 
bool(CsvRowArtifact({})) + assert bool(CsvRowArtifact({"foo": "bar"})) diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 2d4456ce27..cfe8532260 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -114,9 +114,9 @@ def test_input(self): assert task.input.value[1].width == 100 # default case - task = PromptTask({"default": "test"}) # pyright: ignore[reportArgumentType] + task = PromptTask({"default": "test"}) - assert task.input.value == "default: test" + assert task.input.value == str({"default": "test"}) def test_prompt_stack(self): task = PromptTask("{{ test }}", context={"test": "test value"}, rules=[Rule("test rule")])