Skip to content

Commit

Permalink
Bring back CsvRowArtifact
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 6, 2024
1 parent 58a5a63 commit 00a418a
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 45 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseArtifact.to_bytes()` method to convert an Artifact to bytes.

### Changed
- **BREAKING**: Changed `CsvRowArtifact.value` from `dict` to `str`.
- **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead.
- **BREAKING**: Removed `CsvRowArtifact`.
- **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`.
- **BREAKING**: Removed `ImageArtifact.media_type`.
- **BREAKING**: Removed `AudioArtifact.media_type`.
- **BREAKING**: Removed `BlobArtifact.dir_name`.
- **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`.
- **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image.
- **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image.
- Updated `JsonArtifact` value converter to properly handle more types.
- `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
- `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
Expand All @@ -42,7 +43,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`.
- `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`.
- `CsvRowArtifact.to_text()` now includes the header.
- `BaseConversationMemory.prompt_driver` for use with autopruning.

### Fixed
- Parsing streaming response with some OpenAI compatible services.
Expand Down
32 changes: 8 additions & 24 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,39 +36,23 @@ audio_artifact = AudioArtifact(
)
```

### Removed `CsvRowArtifact`
### Changed `CsvRowArtifact.value` from `dict` to `str`.

`CsvRowArtifact` has been removed. Use `TextArtifact` instead.
`CsvRowArtifact`'s `value` is now a `str` instead of a `dict`. Update any logic that expects `dict` to handle `str` instead.

#### Before

```python
CsvRowArtifact({"name": "John", "age": 30})
artifact = CsvRowArtifact({"name": "John", "age": 30})
print(artifact.value) # {"name": "John", "age": 30}
print(type(artifact.value)) # <class 'dict'>
```

#### After
```python
TextArtifact("name: John\nAge: 30")
```

### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types

`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a tuple of `list[TextArtifact]` instead of `list[CsvRowArtifact]`.

#### Before

```python
results = CsvLoader().load(Path("people.csv").read_text())

print(results[0].value) # {"name": "John", "age": 30}
```

#### After
```python
results = CsvLoader().load(Path("people.csv").read_text())

print(results[0].value) # name: John\nAge: 30
print(results[0].meta["row"]) # 0
artifact = CsvRowArtifact({"name": "John", "age": 30})
print(artifact.value) # name: John\nAge: 30
print(type(artifact.value)) # <class 'str'>
```

### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta`
Expand Down
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .json_artifact import JsonArtifact
from .blob_artifact import BlobArtifact
from .boolean_artifact import BooleanArtifact
from .csv_row_artifact import CsvRowArtifact
from .list_artifact import ListArtifact
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact
Expand All @@ -20,6 +21,7 @@
"JsonArtifact",
"BlobArtifact",
"BooleanArtifact",
"CsvRowArtifact",
"ListArtifact",
"ImageArtifact",
"AudioArtifact",
Expand Down
22 changes: 22 additions & 0 deletions griptape/artifacts/csv_row_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

from typing import Any

from attrs import define, field

from griptape.artifacts import BaseArtifact, TextArtifact


def value_to_str(value: Any) -> str:
if isinstance(value, dict):
return "\n".join(f"{key}: {val}" for key, val in value.items())
else:
return str(value)


@define
class CsvRowArtifact(TextArtifact):
value: str = field(converter=value_to_str, metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> TextArtifact:
return TextArtifact(self.value + "\n" + other.value)
11 changes: 2 additions & 9 deletions griptape/artifacts/text_artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional

from attrs import define, field

Expand All @@ -11,16 +11,9 @@
from griptape.tokenizers import BaseTokenizer


def value_to_str(value: Any) -> str:
if isinstance(value, dict):
return "\n".join(f"{key}: {val}" for key, val in value.items())
else:
return str(value)


@define
class TextArtifact(BaseArtifact):
value: str = field(converter=value_to_str, metadata={"serializable": True})
value: str = field(converter=str, metadata={"serializable": True})
encoding: str = field(default="utf-8", kw_only=True)
encoding_error_handler: str = field(default="strict", kw_only=True)
embedding: Optional[list[float]] = field(default=None, kw_only=True)
Expand Down
4 changes: 2 additions & 2 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from attrs import Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact
from griptape.common import Message, PromptStack
from griptape.engines import BaseExtractionEngine
from griptape.utils import J2
Expand Down Expand Up @@ -41,7 +41,7 @@ def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtif

with io.StringIO(text) as f:
for row in csv.reader(f):
rows.append(TextArtifact(dict(zip(column_names, [x.strip() for x in row]))))
rows.append(CsvRowArtifact(dict(zip(column_names, [x.strip() for x in row]))))

return rows

Expand Down
4 changes: 2 additions & 2 deletions griptape/loaders/csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from attrs import define, field

from griptape.artifacts import TextArtifact
from griptape.artifacts import CsvRowArtifact, TextArtifact
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
Expand All @@ -28,7 +28,7 @@ def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]:
raise ValueError(f"Unsupported source type: {type(source)}")

reader = csv.DictReader(StringIO(source), delimiter=self.delimiter)
chunks = [TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]
chunks = [CsvRowArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]

if self.embedding_driver:
for chunk in chunks:
Expand Down
4 changes: 2 additions & 2 deletions griptape/loaders/dataframe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from attrs import define, field

from griptape.artifacts import TextArtifact
from griptape.artifacts import CsvRowArtifact, TextArtifact
from griptape.loaders import BaseLoader
from griptape.utils import import_optional_dependency
from griptape.utils.hash import str_to_hash
Expand All @@ -22,7 +22,7 @@ class DataFrameLoader(BaseLoader):
def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]:
artifacts = []

chunks = [TextArtifact(row) for row in source.to_dict(orient="records")]
chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")]

if self.embedding_driver:
for chunk in chunks:
Expand Down
4 changes: 2 additions & 2 deletions griptape/loaders/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from attrs import define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.artifacts import CsvRowArtifact, TextArtifact
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
Expand All @@ -20,7 +20,7 @@ def load(self, source: str, *args, **kwargs) -> list[TextArtifact]:
rows = self.sql_driver.execute_query(source)
artifacts = []

chunks = [TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(rows)] if rows else []
chunks = [CsvRowArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(rows)] if rows else []

if self.embedding_driver:
for chunk in chunks:
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/artifacts/test_csv_row_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from griptape.artifacts import CsvRowArtifact


class TestCsvRowArtifact:
def test_value_type_conversion(self):
assert CsvRowArtifact({"foo": "bar"}).value == "foo: bar"
assert CsvRowArtifact({"foo": {"bar": "baz"}}).value == "foo: {'bar': 'baz'}"
assert CsvRowArtifact('{"foo": "bar"}').value == '{"foo": "bar"}'

def test___add__(self):
assert (CsvRowArtifact({"test1": "foo"}) + CsvRowArtifact({"test2": "bar"})).value == "test1: foo\ntest2: bar"

def test_to_text(self):
assert CsvRowArtifact({"test1": "foo|bar", "test2": 1}).to_text() == "test1: foo|bar\ntest2: 1"

def test_to_dict(self):
assert CsvRowArtifact({"test1": "foo"}).to_dict()["value"] == "test1: foo"

def test_name(self):
artifact = CsvRowArtifact({})

assert artifact.name == artifact.id
assert CsvRowArtifact({}, name="bar").name == "bar"

def test___bool__(self):
assert not bool(CsvRowArtifact({}))
assert bool(CsvRowArtifact({"foo": "bar"}))
4 changes: 2 additions & 2 deletions tests/unit/tasks/test_prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def test_input(self):
assert task.input.value[1].width == 100

# default case
task = PromptTask({"default": "test"}) # pyright: ignore[reportArgumentType]
task = PromptTask({"default": "test"})

assert task.input.value == "default: test"
assert task.input.value == str({"default": "test"})

def test_prompt_stack(self):
task = PromptTask("{{ test }}", context={"test": "test value"}, rules=[Rule("test rule")])
Expand Down

0 comments on commit 00a418a

Please sign in to comment.