Skip to content

Commit

Permalink
Improve ListArtifact
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 21, 2024
1 parent 7f2ea97 commit 9c33e91
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- `BaseConversationMemory.prompt_driver` for use with autopruning.
- Generic type support to `ListArtifact`.
- Iteration support to `ListArtifact`.

### Fixed
- Parsing streaming response with some OpenAi compatible services.
Expand Down
22 changes: 14 additions & 8 deletions griptape/artifacts/list_artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from collections.abc import Iterator
from typing import TYPE_CHECKING, Generic, Optional, TypeVar

from attrs import Attribute, define, field

Expand All @@ -9,15 +10,17 @@
if TYPE_CHECKING:
from collections.abc import Sequence

T = TypeVar("T", bound=BaseArtifact)


@define
class ListArtifact(BaseArtifact):
value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True})
class ListArtifact(BaseArtifact, Generic[T]):
value: Sequence[T] = field(factory=list, metadata={"serializable": True})
item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True})
validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True})

@value.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None:
def validate_value(self, _: Attribute, value: list[T]) -> None:
if self.validate_uniform_types and len(value) > 0:
first_type = type(value[0])

Expand All @@ -31,18 +34,21 @@ def child_type(self) -> Optional[type]:
else:
return None

def __getitem__(self, key: int) -> BaseArtifact:
def __getitem__(self, key: int) -> T:
return self.value[key]

def __bool__(self) -> bool:
return len(self) > 0

def __add__(self, other: BaseArtifact) -> ListArtifact[T]:
return ListArtifact(self.value + other.value)

def __iter__(self) -> Iterator[T]:
return iter(self.value)

def to_text(self) -> str:
return self.item_separator.join([v.to_text() for v in self.value])

def __add__(self, other: BaseArtifact) -> BaseArtifact:
return ListArtifact(self.value + other.value)

def is_type(self, target_type: type) -> bool:
if self.value:
return isinstance(self.value[0], target_type)
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/artifacts/test_list_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def test___add__(self):
assert artifact.value[0].value == "foo"
assert artifact.value[1].value == "bar"

def test___iter__(self):
assert [a.value for a in ListArtifact([TextArtifact("foo"), TextArtifact("bar")])] == ["foo", "bar"]

def test_type_var(self):
assert ListArtifact[TextArtifact]([TextArtifact("foo")]).value[0].value == "foo"

def test_validate_value(self):
with pytest.raises(ValueError):
ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")], validate_uniform_types=True)
Expand Down

0 comments on commit 9c33e91

Please sign in to comment.