diff --git a/CHANGELOG.md b/CHANGELOG.md index db8d3f6cd6..23e5366aa5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,11 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added -- `BaseArtifact.to_bytes()` method to convert an Artifact value to bytes. +- `BaseArtifact.to_bytes()` method to convert an Artifact's value to bytes. - `BlobArtifact.base64` property for converting a `BlobArtifact`'s value to a base64 strings. +- `CsvLoader`/`SqlLoader`/`DataframeLoader` `formatter_fn` field for customizing how SQL results are formatted into `TextArtifact`s. ### Changed -- **BREAKING**: Changed `CsvRowArtifact.value` from `dict` to `str`. +- **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. - **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. - **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. - **BREAKING**: Removed `ImageArtifact.media_type`. @@ -22,7 +23,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. -- Passing a dictionary as the value to `CsvRowArtifact` will convert to a key-value formatted string. - Removed `__add__` method from `BaseArtifact`, implemented it where necessary. - Generic type support to `ListArtifact`. - Iteration support to `ListArtifact`. diff --git a/MIGRATION.md b/MIGRATION.md index 08b4d9a7f0..016a93f03d 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -56,9 +56,9 @@ image_artifact = ImageArtifact( ) ``` -### Changed `CsvRowArtifact.value` from `dict` to `str`. +### Removed `CsvRowArtifact` -`CsvRowArtifact`'s `value` is now a `str` instead of a `dict`. Update any logic that expects `dict` to handle `str` instead. +`CsvRowArtifact` has been removed. Use `TextArtifact` instead. #### Before @@ -70,11 +70,45 @@ print(type(artifact.value)) # #### After ```python -artifact = CsvRowArtifact({"name": "John", "age": 30}) -print(artifact.value) # name: John\nAge: 30 +artifact = TextArtifact("name: John\nage: 30") +print(artifact.value) # name: John\nage: 30 print(type(artifact.value)) # ``` +If you require storing a dictionary as an Artifact, you can use `GenericArtifact` instead. + +### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types + +`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a `list[TextArtifact]` instead of `list[CsvRowArtifact]`. + +If you require a dictionary, set a custom `formatter_fn` and then parse the text to a dictionary. + +#### Before + +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # {"name": "John", "age": 30} +print(type(results[0].value)) # +``` + +#### After +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # name: John\nAge: 30 +print(type(results[0].value)) # + +# Customize formatter_fn +results = CsvLoader(formatter_fn=lambda x: json.dumps(x)).load(Path("people.csv").read_text()) +print(results[0].value) # {"name": "John", "age": 30} +print(type(results[0].value)) # + +dict_results = [json.loads(result.value) for result in results] +print(dict_results[0]) # {"name": "John", "age": 30} +print(type(dict_results[0])) # +``` + ### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta` `ImageArtifact.prompt` and `ImageArtifact.model` have been moved to `ImageArtifact.meta`. diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index c0e5f767e2..0e58a8a764 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -5,7 +5,6 @@ from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact -from .csv_row_artifact import CsvRowArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact @@ -21,7 +20,6 @@ "JsonArtifact", "BlobArtifact", "BooleanArtifact", - "CsvRowArtifact", "ListArtifact", "ImageArtifact", "AudioArtifact", diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 9f6085da6f..cf8f164cca 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -2,11 +2,11 @@ import csv import io -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, cast from attrs import Factory, define, field -from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 @@ -20,6 +20,9 @@ class CsvExtractionEngine(BaseExtractionEngine): column_names: list[str] = field(default=Factory(list), kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) + formatter_fn: Callable[[Any], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) def extract( self, @@ -37,22 +40,22 @@ def extract( item_separator="\n", ) - def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]: + def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]: rows = [] with io.StringIO(text) as f: for row in csv.reader(f): - rows.append(CsvRowArtifact(dict(zip(column_names, [x.strip() for x in row])))) + rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row]))))) return rows def _extract_rec( self, artifacts: list[TextArtifact], - rows: list[CsvRowArtifact], + rows: list[TextArtifact], *, rulesets: Optional[list[Ruleset]] = None, - ) -> list[CsvRowArtifact]: + ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) system_prompt = self.system_template_generator.render( column_names=self.column_names, diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 6f6fae5f37..bcf7029d4a 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -2,11 +2,11 @@ import csv from io import StringIO -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -18,8 +18,11 @@ class CsvLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: artifacts = [] if isinstance(source, bytes): @@ -28,7 +31,7 @@ def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [CsvRowArtifact(row, meta={"row_num": row_num}) for row_num, row in enumerate(reader)] + chunks = [TextArtifact(self.formatter_fn(row)) for row in reader] if self.embedding_driver: for chunk in chunks: @@ -44,8 +47,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[CsvRowArtifact]]: + ) -> dict[str, list[TextArtifact]]: return cast( - dict[str, list[CsvRowArtifact]], + dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 0b1ae14484..30d705676f 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -18,11 +18,14 @@ @define class DataFrameLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: artifacts = [] - chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")] + chunks = [TextArtifact(self.formatter_fn(row)) for row in source.to_dict(orient="records")] if self.embedding_driver: for chunk in chunks: @@ -33,8 +36,8 @@ def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: return artifacts - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: + return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 418ff0baf3..105f585cb3 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -15,14 +15,15 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: rows = self.sql_driver.execute_query(source) artifacts = [] - chunks = ( - [CsvRowArtifact(row.cells, meta={"row_num": row_num}) for row_num, row in enumerate(rows)] if rows else [] - ) + chunks = [TextArtifact(self.formatter_fn(row.cells)) for row in rows] if rows else [] if self.embedding_driver: for chunk in chunks: @@ -33,5 +34,5 @@ def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: return artifacts - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: + return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) diff --git a/tests/unit/artifacts/test_csv_row_artifact.py b/tests/unit/artifacts/test_csv_row_artifact.py deleted file mode 100644 index 4591285259..0000000000 --- a/tests/unit/artifacts/test_csv_row_artifact.py +++ /dev/null @@ -1,27 +0,0 @@ -from griptape.artifacts import CsvRowArtifact - - -class TestCsvRowArtifact: - def test_value_type_conversion(self): - assert CsvRowArtifact({"foo": "bar"}).value == "foo: bar" - assert CsvRowArtifact({"foo": {"bar": "baz"}}).value == "foo: {'bar': 'baz'}" - assert CsvRowArtifact('{"foo": "bar"}').value == '{"foo": "bar"}' - - def test___add__(self): - assert (CsvRowArtifact({"test1": "foo"}) + CsvRowArtifact({"test2": "bar"})).value == "test1: foo\ntest2: bar" - - def test_to_text(self): - assert CsvRowArtifact({"test1": "foo|bar", "test2": 1}).to_text() == "test1: foo|bar\ntest2: 1" - - def test_to_dict(self): - assert CsvRowArtifact({"test1": "foo"}).to_dict()["value"] == "test1: foo" - - def test_name(self): - artifact = CsvRowArtifact({}) - - assert artifact.name == artifact.id - assert CsvRowArtifact({}, name="bar").name == "bar" - - def test___bool__(self): - assert not bool(CsvRowArtifact({})) - assert bool(CsvRowArtifact({"foo": "bar"})) diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index d08b939197..7af409152e 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,3 +1,5 @@ +import json + import pytest from griptape.loaders.csv_loader import CsvLoader @@ -55,3 +57,12 @@ def test_load_collection(self, loader, create_source): assert collection[loader.to_key(sources[1])][0].value == "Bar: bar1\nFoo: foo1" assert collection[loader.to_key(sources[1])][0].embedding == [0, 1] + + def test_formatter_fn(self, loader, create_source): + loader.formatter_fn = lambda value: json.dumps(value) + source = create_source("test-1.csv") + + artifacts = loader.load(source) + + assert len(artifacts) == 10 + assert artifacts[0].value == '{"Foo": "foo1", "Bar": "bar1"}' diff --git a/tests/unit/tools/test_sql_tool.py b/tests/unit/tools/test_sql_tool.py index f60401b664..061b31f4dc 100644 --- a/tests/unit/tools/test_sql_tool.py +++ b/tests/unit/tools/test_sql_tool.py @@ -27,7 +27,6 @@ def test_execute_query(self, driver): assert len(result.value) == 1 assert result.value[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert result.value[0].meta["row_num"] == 0 def test_execute_query_description(self, driver): client = SqlTool(