From 58a5a6317dfa87e017a22d7de3403a32d3703d25 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 21 Aug 2024 09:23:51 -0700 Subject: [PATCH] Improve ListArtifact Make type covariant --- CHANGELOG.md | 3 +++ griptape/artifacts/list_artifact.py | 19 ++++++++++++------- griptape/schemas/base_schema.py | 6 +++++- griptape/tasks/tool_task.py | 6 +++++- tests/unit/artifacts/test_list_artifact.py | 6 ++++++ 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ccecb0768e..685c3341db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - Passing a dictionary as the value to `TextArtifact` will convert to a key-value formatted string. - Removed `__add__` method from `BaseArtifact`, implemented it where necessary. +- Generic type support to `ListArtifact`. +- Iteration support to `ListArtifact`. ## [0.31.0] - 2024-09-03 @@ -40,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. - `CsvRowArtifact.to_text()` now includes the header. +- `BaseConversationMemory.prompt_driver` for use with autopruning. ### Fixed - Parsing streaming response with some OpenAI compatible services. diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index ac11d95ccb..0e6f81ca5d 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -1,32 +1,37 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Generic, Optional, TypeVar from attrs import Attribute, define, field from griptape.artifacts import BaseArtifact if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterator, Sequence + +T = TypeVar("T", bound=BaseArtifact, covariant=True) @define -class ListArtifact(BaseArtifact): - value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True}) +class ListArtifact(BaseArtifact, Generic[T]): + value: Sequence[T] = field(factory=list, metadata={"serializable": True}) item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - def __getitem__(self, key: int) -> BaseArtifact: + def __getitem__(self, key: int) -> T: return self.value[key] def __bool__(self) -> bool: return len(self) > 0 - def __add__(self, other: BaseArtifact) -> BaseArtifact: + def __add__(self, other: BaseArtifact) -> ListArtifact[T]: return ListArtifact(self.value + other.value) + def __iter__(self) -> Iterator[T]: + return iter(self.value) + @value.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None: + def validate_value(self, _: Attribute, value: list[T]) -> None: if self.validate_uniform_types and len(value) > 0: first_type = type(value[0]) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index f25e8870b7..b285d14762 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -2,7 +2,7 @@ from abc import ABC from collections.abc import Sequence -from typing import Any, Literal, Union, _SpecialForm, get_args, get_origin +from typing import Any, Literal, TypeVar, Union, _SpecialForm, get_args, get_origin import attrs from marshmallow import INCLUDE, Schema, fields @@ -56,6 +56,10 @@ def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested: field_class, args, optional = cls._get_field_type_info(field_type) + # Resolve TypeVars to their bound type + if isinstance(field_class, TypeVar): + field_class = field_class.__bound__ + if attrs.has(field_class): if ABC in field_class.__bases__: return fields.Nested(PolymorphicSchema(inner_class=field_class), allow_none=optional) diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 6dd5000b31..68260ea91d 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -84,7 +84,11 @@ def run(self) -> BaseArtifact: subtask.after_run() if isinstance(subtask.output, ListArtifact): - self.output = subtask.output[0] + first_artifact = subtask.output[0] + if isinstance(first_artifact, BaseArtifact): + self.output = first_artifact + else: + self.output = ErrorArtifact(f"Output is not an Artifact: {type(subtask.output[0])}") else: self.output = InfoArtifact("No tool output") except Exception as e: diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index bf7004ad0b..c56682d2cc 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -24,6 +24,12 @@ def test___add__(self): assert artifact.value[0].value == "foo" assert artifact.value[1].value == "bar" + def test___iter__(self): + assert [a.value for a in ListArtifact([TextArtifact("foo"), TextArtifact("bar")])] == ["foo", "bar"] + + def test_type_var(self): + assert ListArtifact[TextArtifact]([TextArtifact("foo")]).value[0].value == "foo" + def test_validate_value(self): with pytest.raises(ValueError): ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")], validate_uniform_types=True)