Skip to content

Commit

Permalink
Remove CsvRowArtifact for final time
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 11, 2024
1 parent eb05b6b commit ca352c9
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 63 deletions.
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added
- `BaseArtifact.to_bytes()` method to convert an Artifact value to bytes.
- `BaseArtifact.to_bytes()` method to convert an Artifact's value to bytes.
- `BlobArtifact.base64` property for converting a `BlobArtifact`'s value to a base64 strings.
- `CsvLoader`/`SqlLoader`/`DataframeLoader` `formatter_fn` field for customizing how SQL results are formatted into `TextArtifact`s.

### Changed
- **BREAKING**: Changed `CsvRowArtifact.value` from `dict` to `str`.
- **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead.
- **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead.
- **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`.
- **BREAKING**: Removed `ImageArtifact.media_type`.
Expand All @@ -22,7 +23,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated `JsonArtifact` value converter to properly handle more types.
- `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
- `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
- Passing a dictionary as the value to `CsvRowArtifact` will convert to a key-value formatted string.
- Removed `__add__` method from `BaseArtifact`, implemented it where necessary.
- Generic type support to `ListArtifact`.
- Iteration support to `ListArtifact`.
Expand Down
42 changes: 38 additions & 4 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ image_artifact = ImageArtifact(
)
```

### Changed `CsvRowArtifact.value` from `dict` to `str`.
### Removed `CsvRowArtifact`

`CsvRowArtifact`'s `value` is now a `str` instead of a `dict`. Update any logic that expects `dict` to handle `str` instead.
`CsvRowArtifact` has been removed. Use `TextArtifact` instead.

#### Before

Expand All @@ -70,11 +70,45 @@ print(type(artifact.value)) # <class 'dict'>

#### After
```python
artifact = CsvRowArtifact({"name": "John", "age": 30})
print(artifact.value) # name: John\nAge: 30
artifact = TextArtifact("name: John\nage: 30")
print(artifact.value) # name: John\nage: 30
print(type(artifact.value)) # <class 'str'>
```

If you require storing a dictionary as an Artifact, you can use `GenericArtifact` instead.

### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types

`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a `list[TextArtifact]` instead of `list[CsvRowArtifact]`.

If you require a dictionary, set a custom `formatter_fn` and then parse the text to a dictionary.

#### Before

```python
results = CsvLoader().load(Path("people.csv").read_text())

print(results[0].value) # {"name": "John", "age": 30}
print(type(results[0].value)) # <class 'dict'>
```

#### After
```python
results = CsvLoader().load(Path("people.csv").read_text())

print(results[0].value) # name: John\nAge: 30
print(type(results[0].value)) # <class 'str'>

# Customize formatter_fn
results = CsvLoader(formatter_fn=lambda x: json.dumps(x)).load(Path("people.csv").read_text())
print(results[0].value) # {"name": "John", "age": 30}
print(type(results[0].value)) # <class 'str'>

dict_results = [json.loads(result.value) for result in results]
print(dict_results[0]) # {"name": "John", "age": 30}
print(type(dict_results[0])) # <class 'dict'>
```

### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta`

`ImageArtifact.prompt` and `ImageArtifact.model` have been moved to `ImageArtifact.meta`.
Expand Down
2 changes: 0 additions & 2 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .json_artifact import JsonArtifact
from .blob_artifact import BlobArtifact
from .boolean_artifact import BooleanArtifact
from .csv_row_artifact import CsvRowArtifact
from .list_artifact import ListArtifact
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact
Expand All @@ -21,7 +20,6 @@
"JsonArtifact",
"BlobArtifact",
"BooleanArtifact",
"CsvRowArtifact",
"ListArtifact",
"ImageArtifact",
"AudioArtifact",
Expand Down
15 changes: 9 additions & 6 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import csv
import io
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Any, Callable, Optional, cast

from attrs import Factory, define, field

from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact
from griptape.artifacts import ListArtifact, TextArtifact
from griptape.common import Message, PromptStack
from griptape.engines import BaseExtractionEngine
from griptape.utils import J2
Expand All @@ -20,6 +20,9 @@ class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(default=Factory(list), kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
formatter_fn: Callable[[Any], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

def extract(
self,
Expand All @@ -37,22 +40,22 @@ def extract(
item_separator="\n",
)

def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]:
def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]:
rows = []

with io.StringIO(text) as f:
for row in csv.reader(f):
rows.append(CsvRowArtifact(dict(zip(column_names, [x.strip() for x in row]))))
rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row])))))

return rows

def _extract_rec(
self,
artifacts: list[TextArtifact],
rows: list[CsvRowArtifact],
rows: list[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
) -> list[CsvRowArtifact]:
) -> list[TextArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
column_names=self.column_names,
Expand Down
15 changes: 9 additions & 6 deletions griptape/loaders/csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import csv
from io import StringIO
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Callable, Optional, cast

from attrs import define, field

from griptape.artifacts import CsvRowArtifact
from griptape.artifacts import TextArtifact
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
Expand All @@ -18,8 +18,11 @@ class CsvLoader(BaseLoader):
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
delimiter: str = field(default=",", kw_only=True)
encoding: str = field(default="utf-8", kw_only=True)
formatter_fn: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]:
def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]:
artifacts = []

if isinstance(source, bytes):
Expand All @@ -28,7 +31,7 @@ def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]:
raise ValueError(f"Unsupported source type: {type(source)}")

reader = csv.DictReader(StringIO(source), delimiter=self.delimiter)
chunks = [CsvRowArtifact(row, meta={"row_num": row_num}) for row_num, row in enumerate(reader)]
chunks = [TextArtifact(self.formatter_fn(row)) for row in reader]

if self.embedding_driver:
for chunk in chunks:
Expand All @@ -44,8 +47,8 @@ def load_collection(
sources: list[bytes | str],
*args,
**kwargs,
) -> dict[str, list[CsvRowArtifact]]:
) -> dict[str, list[TextArtifact]]:
return cast(
dict[str, list[CsvRowArtifact]],
dict[str, list[TextArtifact]],
super().load_collection(sources, *args, **kwargs),
)
15 changes: 9 additions & 6 deletions griptape/loaders/dataframe_loader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Callable, Optional, cast

from attrs import define, field

from griptape.artifacts import CsvRowArtifact
from griptape.artifacts import TextArtifact
from griptape.loaders import BaseLoader
from griptape.utils import import_optional_dependency
from griptape.utils.hash import str_to_hash
Expand All @@ -18,11 +18,14 @@
@define
class DataFrameLoader(BaseLoader):
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
formatter_fn: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]:
def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]:
artifacts = []

chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")]
chunks = [TextArtifact(self.formatter_fn(row)) for row in source.to_dict(orient="records")]

if self.embedding_driver:
for chunk in chunks:
Expand All @@ -33,8 +36,8 @@ def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]:

return artifacts

def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]:
return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs))
def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]:
return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs))

Check warning on line 40 in griptape/loaders/dataframe_loader.py

View check run for this annotation

Codecov / codecov/patch

griptape/loaders/dataframe_loader.py#L40

Added line #L40 was not covered by tests

def to_key(self, source: DataFrame, *args, **kwargs) -> str:
hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object
Expand Down
17 changes: 9 additions & 8 deletions griptape/loaders/sql_loader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Callable, Optional, cast

from attrs import define, field

from griptape.artifacts import CsvRowArtifact
from griptape.artifacts import TextArtifact
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
Expand All @@ -15,14 +15,15 @@
class SqlLoader(BaseLoader):
sql_driver: BaseSqlDriver = field(kw_only=True)
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
formatter_fn: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]:
def load(self, source: str, *args, **kwargs) -> list[TextArtifact]:
rows = self.sql_driver.execute_query(source)
artifacts = []

chunks = (
[CsvRowArtifact(row.cells, meta={"row_num": row_num}) for row_num, row in enumerate(rows)] if rows else []
)
chunks = [TextArtifact(self.formatter_fn(row.cells)) for row in rows] if rows else []

if self.embedding_driver:
for chunk in chunks:
Expand All @@ -33,5 +34,5 @@ def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]:

return artifacts

def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]:
return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs))
def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]:
return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs))
27 changes: 0 additions & 27 deletions tests/unit/artifacts/test_csv_row_artifact.py

This file was deleted.

11 changes: 11 additions & 0 deletions tests/unit/loaders/test_csv_loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

import pytest

from griptape.loaders.csv_loader import CsvLoader
Expand Down Expand Up @@ -55,3 +57,12 @@ def test_load_collection(self, loader, create_source):

assert collection[loader.to_key(sources[1])][0].value == "Bar: bar1\nFoo: foo1"
assert collection[loader.to_key(sources[1])][0].embedding == [0, 1]

def test_formatter_fn(self, loader, create_source):
loader.formatter_fn = lambda value: json.dumps(value)
source = create_source("test-1.csv")

artifacts = loader.load(source)

assert len(artifacts) == 10
assert artifacts[0].value == '{"Foo": "foo1", "Bar": "bar1"}'
1 change: 0 additions & 1 deletion tests/unit/tools/test_sql_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def test_execute_query(self, driver):

assert len(result.value) == 1
assert result.value[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York"
assert result.value[0].meta["row_num"] == 0

def test_execute_query_description(self, driver):
client = SqlTool(
Expand Down

0 comments on commit ca352c9

Please sign in to comment.