diff --git a/CHANGELOG.md b/CHANGELOG.md index 0528b1a465..0aedd19704 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added - `BaseConversationMemory.prompt_driver` for use with autopruning. +- Generic type support to `ListArtifact`. +- Iteration support to `ListArtifact`. ### Fixed - Parsing streaming response with some OpenAi compatible services. diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 298f29c6ad..7c66c98735 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from collections.abc import Iterator +from typing import TYPE_CHECKING, Generic, Optional, TypeVar from attrs import Attribute, define, field @@ -9,15 +10,17 @@ if TYPE_CHECKING: from collections.abc import Sequence +T = TypeVar("T", bound=BaseArtifact) + @define -class ListArtifact(BaseArtifact): - value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True}) +class ListArtifact(BaseArtifact, Generic[T]): + value: Sequence[T] = field(factory=list, metadata={"serializable": True}) item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) @value.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None: + def validate_value(self, _: Attribute, value: list[T]) -> None: if self.validate_uniform_types and len(value) > 0: first_type = type(value[0]) @@ -31,18 +34,21 @@ def child_type(self) -> Optional[type]: else: return None - def __getitem__(self, key: int) -> BaseArtifact: + def __getitem__(self, key: int) -> T: return self.value[key] def __bool__(self) -> bool: return len(self) > 0 + def __add__(self, other: BaseArtifact) -> ListArtifact[T]: + return ListArtifact(self.value + other.value) + + def __iter__(self) -> Iterator[T]: + return iter(self.value) + def to_text(self) -> str: return self.item_separator.join([v.to_text() for v in self.value]) - def __add__(self, other: BaseArtifact) -> BaseArtifact: - return ListArtifact(self.value + other.value) - def is_type(self, target_type: type) -> bool: if self.value: return isinstance(self.value[0], target_type) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index f25e8870b7..b285d14762 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -2,7 +2,7 @@ from abc import ABC from collections.abc import Sequence -from typing import Any, Literal, Union, _SpecialForm, get_args, get_origin +from typing import Any, Literal, TypeVar, Union, _SpecialForm, get_args, get_origin import attrs from marshmallow import INCLUDE, Schema, fields @@ -56,6 +56,10 @@ def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested: field_class, args, optional = cls._get_field_type_info(field_type) + # Resolve TypeVars to their bound type + if isinstance(field_class, TypeVar): + field_class = field_class.__bound__ + if attrs.has(field_class): if ABC in field_class.__bases__: return fields.Nested(PolymorphicSchema(inner_class=field_class), allow_none=optional) diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index 06d2346458..37769e1b83 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -23,6 +23,12 @@ def test___add__(self): assert artifact.value[0].value == "foo" assert artifact.value[1].value == "bar" + def test___iter__(self): + assert [a.value for a in ListArtifact([TextArtifact("foo"), TextArtifact("bar")])] == ["foo", "bar"] + + def test_type_var(self): + assert ListArtifact[TextArtifact]([TextArtifact("foo")]).value[0].value == "foo" + def test_validate_value(self): with pytest.raises(ValueError): ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")], validate_uniform_types=True)