Skip to content

Commit

Permalink
Improve ListArtifact
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 21, 2024
1 parent 7f2ea97 commit 1d189b8
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- `BaseConversationMemory.prompt_driver` for use with autopruning.
- Generic type support to `ListArtifact`.
- Iteration support to `ListArtifact`.

### Fixed
- Parsing streaming response with some OpenAi compatible services.
Expand Down
22 changes: 14 additions & 8 deletions griptape/artifacts/list_artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from collections.abc import Iterator
from typing import TYPE_CHECKING, Generic, Optional, TypeVar

from attrs import Attribute, define, field

Expand All @@ -9,15 +10,17 @@
if TYPE_CHECKING:
from collections.abc import Sequence

T = TypeVar("T", bound=BaseArtifact)


@define
class ListArtifact(BaseArtifact):
value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True})
class ListArtifact(BaseArtifact, Generic[T]):
value: Sequence[T] = field(factory=list, metadata={"serializable": True})
item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True})
validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True})

@value.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None:
def validate_value(self, _: Attribute, value: list[T]) -> None:
if self.validate_uniform_types and len(value) > 0:
first_type = type(value[0])

Expand All @@ -31,18 +34,21 @@ def child_type(self) -> Optional[type]:
else:
return None

def __getitem__(self, key: int) -> BaseArtifact:
def __getitem__(self, key: int) -> T:
return self.value[key]

def __bool__(self) -> bool:
return len(self) > 0

def __add__(self, other: BaseArtifact) -> ListArtifact[T]:
return ListArtifact(self.value + other.value)

def __iter__(self) -> Iterator[T]:
return iter(self.value)

def to_text(self) -> str:
return self.item_separator.join([v.to_text() for v in self.value])

def __add__(self, other: BaseArtifact) -> BaseArtifact:
return ListArtifact(self.value + other.value)

def is_type(self, target_type: type) -> bool:
if self.value:
return isinstance(self.value[0], target_type)
Expand Down
6 changes: 5 additions & 1 deletion griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC
from collections.abc import Sequence
from typing import Any, Literal, Union, _SpecialForm, get_args, get_origin
from typing import Any, Literal, TypeVar, Union, _SpecialForm, get_args, get_origin

import attrs
from marshmallow import INCLUDE, Schema, fields
Expand Down Expand Up @@ -56,6 +56,10 @@ def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested:

field_class, args, optional = cls._get_field_type_info(field_type)

# Resolve TypeVars to their bound type
if isinstance(field_class, TypeVar):
field_class = field_class.__bound__

if attrs.has(field_class):
if ABC in field_class.__bases__:
return fields.Nested(PolymorphicSchema(inner_class=field_class), allow_none=optional)
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/artifacts/test_list_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def test___add__(self):
assert artifact.value[0].value == "foo"
assert artifact.value[1].value == "bar"

def test___iter__(self):
assert [a.value for a in ListArtifact([TextArtifact("foo"), TextArtifact("bar")])] == ["foo", "bar"]

def test_type_var(self):
assert ListArtifact[TextArtifact]([TextArtifact("foo")]).value[0].value == "foo"

def test_validate_value(self):
with pytest.raises(ValueError):
ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")], validate_uniform_types=True)
Expand Down

0 comments on commit 1d189b8

Please sign in to comment.