Skip to content

Commit

Permalink
Improve ListArtifact
Browse files Browse the repository at this point in the history
Make type covariant
  • Loading branch information
collindutter committed Sep 6, 2024
1 parent 0b57c4a commit 58a5a63
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
- Passing a dictionary as the value to `TextArtifact` will convert to a key-value formatted string.
- Removed `__add__` method from `BaseArtifact`, implemented it where necessary.
- Generic type support to `ListArtifact`.
- Iteration support to `ListArtifact`.

## [0.31.0] - 2024-09-03

Expand All @@ -40,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`.
- `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`.
- `CsvRowArtifact.to_text()` now includes the header.
- `BaseConversationMemory.prompt_driver` for use with autopruning.

### Fixed
- Parsing streaming response with some OpenAI compatible services.
Expand Down
19 changes: 12 additions & 7 deletions griptape/artifacts/list_artifact.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Generic, Optional, TypeVar

from attrs import Attribute, define, field

from griptape.artifacts import BaseArtifact

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterator, Sequence

T = TypeVar("T", bound=BaseArtifact, covariant=True)


@define
class ListArtifact(BaseArtifact):
value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True})
class ListArtifact(BaseArtifact, Generic[T]):
value: Sequence[T] = field(factory=list, metadata={"serializable": True})
item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True})
validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True})

def __getitem__(self, key: int) -> BaseArtifact:
def __getitem__(self, key: int) -> T:
return self.value[key]

def __bool__(self) -> bool:
return len(self) > 0

def __add__(self, other: BaseArtifact) -> BaseArtifact:
def __add__(self, other: BaseArtifact) -> ListArtifact[T]:
return ListArtifact(self.value + other.value)

def __iter__(self) -> Iterator[T]:
return iter(self.value)

@value.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None:
def validate_value(self, _: Attribute, value: list[T]) -> None:
if self.validate_uniform_types and len(value) > 0:
first_type = type(value[0])

Expand Down
6 changes: 5 additions & 1 deletion griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC
from collections.abc import Sequence
from typing import Any, Literal, Union, _SpecialForm, get_args, get_origin
from typing import Any, Literal, TypeVar, Union, _SpecialForm, get_args, get_origin

import attrs
from marshmallow import INCLUDE, Schema, fields
Expand Down Expand Up @@ -56,6 +56,10 @@ def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested:

field_class, args, optional = cls._get_field_type_info(field_type)

# Resolve TypeVars to their bound type
if isinstance(field_class, TypeVar):
field_class = field_class.__bound__

if attrs.has(field_class):
if ABC in field_class.__bases__:
return fields.Nested(PolymorphicSchema(inner_class=field_class), allow_none=optional)
Expand Down
6 changes: 5 additions & 1 deletion griptape/tasks/tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ def run(self) -> BaseArtifact:
subtask.after_run()

if isinstance(subtask.output, ListArtifact):
self.output = subtask.output[0]
first_artifact = subtask.output[0]
if isinstance(first_artifact, BaseArtifact):
self.output = first_artifact
else:
self.output = ErrorArtifact(f"Output is not an Artifact: {type(subtask.output[0])}")
else:
self.output = InfoArtifact("No tool output")
except Exception as e:
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/artifacts/test_list_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def test___add__(self):
assert artifact.value[0].value == "foo"
assert artifact.value[1].value == "bar"

def test___iter__(self):
assert [a.value for a in ListArtifact([TextArtifact("foo"), TextArtifact("bar")])] == ["foo", "bar"]

def test_type_var(self):
assert ListArtifact[TextArtifact]([TextArtifact("foo")]).value[0].value == "foo"

def test_validate_value(self):
with pytest.raises(ValueError):
ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")], validate_uniform_types=True)
Expand Down

0 comments on commit 58a5a63

Please sign in to comment.