Skip to content

Commit

Permalink
Refactor Artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 3, 2024
1 parent a7bfc14 commit 15bb112
Show file tree
Hide file tree
Showing 56 changed files with 335 additions and 453 deletions.
Empty file added .ignore
Empty file.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Parameter `meta: dict` on `BaseEvent`.
- `TableArtifact` for storing CSV data.

### Changed
- **BREAKING**: Parameter `driver` on `BaseConversationMemory` renamed to `conversation_memory_driver`.
- **BREAKING**: `BaseConversationMemory.add_to_prompt_stack` now takes a `prompt_driver` parameter.
- **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], Optional[dict]]`.
- **BREAKING**: `BaseConversationMemoryDriver.store` now takes `runs: list[Run]` and `metadata: Optional[dict]` as input.
- **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`.
- **BREAKING**: `CsvLoader` now returns a `TableArtifact` instead of a `list[CsvRowArtifact]`.
- `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`.
- `CsvRowArtifact.to_text()` now includes the header.

Expand Down
7 changes: 0 additions & 7 deletions docs/griptape-framework/data/artifacts.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used

An [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md) allows the Framework to interact with audio content. An Audio Artifact includes binary audio content as well as metadata like format, duration, and prompt and model information for audio returned generative models. It inherits from [BlobArtifact](#blob).

## Boolean

A [BooleanArtifact](../../reference/griptape/artifacts/boolean_artifact.md) is used for passing boolean values around the framework.

!!! info
Any object passed on init to `BooleanArtifact` will be coerced into a `bool` type. This might lead to unintended behavior: `BooleanArtifact("False").value is True`. Use [BooleanArtifact.parse_bool](../../reference/griptape/artifacts/boolean_artifact.md#griptape.artifacts.boolean_artifact.BooleanArtifact.parse_bool) to convert case-insensitive string literal values `"True"` and `"False"` into a `BooleanArtifact`: `BooleanArtifact.parse_bool("False").value is False`.

## Generic

A [GenericArtifact](../../reference/griptape/artifacts/generic_artifact.md) can be used as an escape hatch for passing any type of data around the framework.
Expand Down
17 changes: 12 additions & 5 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,39 @@
from .base_artifact import BaseArtifact

from .base_system_artifact import BaseSystemArtifact
from .error_artifact import ErrorArtifact
from .info_artifact import InfoArtifact

from .text_artifact import TextArtifact
from .json_artifact import JsonArtifact
from .blob_artifact import BlobArtifact
from .boolean_artifact import BooleanArtifact
from .csv_row_artifact import CsvRowArtifact
from .table_artifact import TableArtifact

from .list_artifact import ListArtifact
from .media_artifact import MediaArtifact

from .blob_artifact import BlobArtifact

from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact

from .action_artifact import ActionArtifact

from .generic_artifact import GenericArtifact


__all__ = [
"BaseArtifact",
"BaseSystemArtifact",
"ErrorArtifact",
"InfoArtifact",
"TextArtifact",
"JsonArtifact",
"BlobArtifact",
"BooleanArtifact",
"CsvRowArtifact",
"ListArtifact",
"MediaArtifact",
"ImageArtifact",
"AudioArtifact",
"ActionArtifact",
"GenericArtifact",
"TableArtifact",
]
7 changes: 3 additions & 4 deletions griptape/artifacts/action_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.common import ToolAction


@define()
class ActionArtifact(BaseArtifact, SerializableMixin):
class ActionArtifact(BaseArtifact):
value: ToolAction = field(metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> ActionArtifact:
raise NotImplementedError
def to_text(self) -> str:
return str(self.value)
18 changes: 13 additions & 5 deletions griptape/artifacts/audio_artifact.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from __future__ import annotations

from attrs import define
from attrs import define, field

from griptape.artifacts import MediaArtifact
from griptape.artifacts import BaseArtifact


@define
class AudioArtifact(MediaArtifact):
"""AudioArtifact is a type of MediaArtifact representing audio."""
class AudioArtifact(BaseArtifact):
"""AudioArtifact is a type of Artifact representing audio."""

media_type: str = "audio"
value: bytes = field(metadata={"serializable": True})
format: str = field(kw_only=True, metadata={"serializable": True})

@property
def mime_type(self) -> str:
return f"audio/{self.format}"

def to_text(self) -> str:
return f"Audio, format: {self.format}, size: {len(self.value)} bytes"
22 changes: 4 additions & 18 deletions griptape/artifacts/base_artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import json
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
Expand All @@ -25,22 +24,6 @@ class BaseArtifact(SerializableMixin, ABC):
)
value: Any = field()

@classmethod
def value_to_bytes(cls, value: Any) -> bytes:
if isinstance(value, bytes):
return value
else:
return str(value).encode()

@classmethod
def value_to_dict(cls, value: Any) -> dict:
dict_value = value if isinstance(value, dict) else json.loads(value)

return dict(dict_value.items())

def to_text(self) -> str:
return str(self.value)

def __str__(self) -> str:
return self.to_text()

Expand All @@ -50,5 +33,8 @@ def __bool__(self) -> bool:
def __len__(self) -> int:
return len(self.value)

def to_bytes(self) -> bytes:
return self.to_text().encode("utf-8")

@abstractmethod
def __add__(self, other: BaseArtifact) -> BaseArtifact: ...
def to_text(self) -> str: ...
10 changes: 10 additions & 0 deletions griptape/artifacts/base_system_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABC

from griptape.artifacts import BaseArtifact


class BaseSystemArtifact(BaseArtifact, ABC):
"""Base class for Artifacts specific to Griptape."""

def to_text(self) -> str:
return self.value
26 changes: 14 additions & 12 deletions griptape/artifacts/blob_artifact.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
from __future__ import annotations

import os.path
from typing import Optional
from typing import Any

from attrs import define, field
from attrs import Converter, define, field

from griptape.artifacts import BaseArtifact


@define
class BlobArtifact(BaseArtifact):
value: bytes = field(converter=BaseArtifact.value_to_bytes, metadata={"serializable": True})
dir_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
value: bytes = field(
converter=Converter(lambda value: BlobArtifact.value_to_bytes(value)),
metadata={"serializable": True},
)
encoding: str = field(default="utf-8", kw_only=True)
encoding_error_handler: str = field(default="strict", kw_only=True)

def __add__(self, other: BaseArtifact) -> BlobArtifact:
return BlobArtifact(self.value + other.value, name=self.name)

@property
def full_path(self) -> str:
return os.path.join(self.dir_name, self.name) if self.dir_name else self.name
media_type: str = field(default="application/octet-stream", kw_only=True)

@classmethod
def value_to_bytes(cls, value: Any) -> bytes:
if isinstance(value, bytes):
return value
else:
return str(value).encode()

def to_text(self) -> str:
return self.value.decode(encoding=self.encoding, errors=self.encoding_error_handler)
31 changes: 0 additions & 31 deletions griptape/artifacts/boolean_artifact.py

This file was deleted.

19 changes: 13 additions & 6 deletions griptape/artifacts/csv_row_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,30 @@

import csv
import io
import json
from typing import Any

from attrs import define, field
from attrs import Converter, define, field

from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.artifacts import TextArtifact


@define
class CsvRowArtifact(TextArtifact):
value: dict[str, str] = field(converter=BaseArtifact.value_to_dict, metadata={"serializable": True})
value: dict[str, str] = field(
converter=Converter(lambda value: CsvRowArtifact.value_to_dict(value)), metadata={"serializable": True}
)
delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> CsvRowArtifact:
return CsvRowArtifact(self.value | other.value)

def __bool__(self) -> bool:
return len(self) > 0

@classmethod
def value_to_dict(cls, value: Any) -> dict:
dict_value = value if isinstance(value, dict) else json.loads(value)

return dict(dict_value.items())

def to_text(self) -> str:
with io.StringIO() as csvfile:
writer = csv.DictWriter(
Expand Down
7 changes: 2 additions & 5 deletions griptape/artifacts/error_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.artifacts import BaseSystemArtifact


@define
class ErrorArtifact(BaseArtifact):
class ErrorArtifact(BaseSystemArtifact):
value: str = field(converter=str, metadata={"serializable": True})
exception: Optional[Exception] = field(default=None, kw_only=True, metadata={"serializable": False})

def __add__(self, other: BaseArtifact) -> ErrorArtifact:
return ErrorArtifact(self.value + other.value)
4 changes: 2 additions & 2 deletions griptape/artifacts/generic_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
class GenericArtifact(BaseArtifact):
value: Any = field(metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> BaseArtifact:
raise NotImplementedError
def to_text(self) -> str:
return str(self.value)
30 changes: 21 additions & 9 deletions griptape/artifacts/image_artifact.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
from __future__ import annotations

import base64

from attrs import define, field

from griptape.artifacts import MediaArtifact
from griptape.artifacts import BaseArtifact


@define
class ImageArtifact(MediaArtifact):
"""ImageArtifact is a type of MediaArtifact representing an image.
class ImageArtifact(BaseArtifact):
"""ImageArtifact is a type of Artifact representing an image.
Attributes:
value: Raw bytes representing media data.
media_type: The type of media, defaults to "image".
format: The format of the media, like png, jpeg, or gif.
name: Artifact name, generated using creation time and a random string.
model: Optionally specify the model used to generate the media.
prompt: Optionally specify the prompt used to generate the media.
format: The format of the media, like png, jpeg, or gif. Default is png.
width: The width of the image in pixels.
height: The height of the image in pixels.
"""

media_type: str = "image"
value: bytes = field(metadata={"serializable": True})
format: str = field(default="png", kw_only=True, metadata={"serializable": True})
width: int = field(kw_only=True, metadata={"serializable": True})
height: int = field(kw_only=True, metadata={"serializable": True})

@property
def base64(self) -> str:
return base64.b64encode(self.value).decode("utf-8")

@property
def mime_type(self) -> str:
return f"image/{self.format}"

def to_text(self) -> str:
return self.base64
7 changes: 2 additions & 5 deletions griptape/artifacts/info_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.artifacts import BaseSystemArtifact


@define
class InfoArtifact(BaseArtifact):
class InfoArtifact(BaseSystemArtifact):
value: str = field(converter=str, metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> InfoArtifact:
return InfoArtifact(self.value + other.value)
23 changes: 13 additions & 10 deletions griptape/artifacts/json_artifact.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from __future__ import annotations

import json
from typing import Union
from typing import Any, Union

from attrs import define, field
from attrs import Converter, define, field

from griptape.artifacts import BaseArtifact

Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None]
from griptape.artifacts.text_artifact import TextArtifact


@define
class JsonArtifact(BaseArtifact):
value: Json = field(converter=lambda v: json.loads(json.dumps(v)), metadata={"serializable": True})
class JsonArtifact(TextArtifact):
Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None]

value: Json = field(
converter=Converter(lambda value: JsonArtifact.value_to_dict(value)), metadata={"serializable": True}
)

@classmethod
def value_to_dict(cls, value: Any) -> dict:
return json.loads(json.dumps(value))

def to_text(self) -> str:
return json.dumps(self.value)

def __add__(self, other: BaseArtifact) -> JsonArtifact:
raise NotImplementedError
Loading

0 comments on commit 15bb112

Please sign in to comment.