Skip to content

Commit

Permalink
Move converters into class
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 9, 2024
1 parent 9fdd49b commit f7b515c
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 25 deletions.
3 changes: 2 additions & 1 deletion griptape/artifacts/base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class BaseArtifact(SerializableMixin, ABC):
meta: The metadata associated with the Artifact. Defaults to an empty dictionary.
name: The name of the Artifact. Defaults to the id.
value: The value of the Artifact.
encoding: The encoding of the Artifact when converting to bytes.
"""

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
Expand All @@ -47,7 +48,7 @@ def __len__(self) -> int:
return len(self.value)

def to_bytes(self) -> bytes:
return self.to_text().encode()
return self.to_text().encode(self.encoding)

@abstractmethod
def to_text(self) -> str: ...
16 changes: 8 additions & 8 deletions griptape/artifacts/blob_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,6 @@
from griptape.artifacts import BaseArtifact


def value_to_bytes(value: Any) -> bytes:
if isinstance(value, bytes):
return value
else:
return str(value).encode()


@define
class BlobArtifact(BaseArtifact):
"""Stores arbitrary binary data.
Expand All @@ -24,14 +17,21 @@ class BlobArtifact(BaseArtifact):
encoding_error_handler: The error handler to use when converting the binary data to text.
"""

value: bytes = field(converter=value_to_bytes, metadata={"serializable": True})
value: bytes = field(converter=lambda value: BlobArtifact.value_to_bytes(value), metadata={"serializable": True})
encoding: str = field(default="utf-8", kw_only=True)
encoding_error_handler: str = field(default="strict", kw_only=True)

@property
def mime_type(self) -> str:
return "application/octet-stream"

@classmethod
def value_to_bytes(cls, value: Any) -> bytes:
if isinstance(value, bytes):
return value
else:
return str(value).encode()

def to_bytes(self) -> bytes:
return self.value

Expand Down
16 changes: 8 additions & 8 deletions griptape/artifacts/csv_row_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,6 @@
from griptape.artifacts import BaseArtifact, TextArtifact


def value_to_str(value: Any) -> str:
if isinstance(value, dict):
return "\n".join(f"{key}: {val}" for key, val in value.items())
else:
return str(value)


@define
class CsvRowArtifact(TextArtifact):
"""Stores a row of a CSV file.
Expand All @@ -22,7 +15,14 @@ class CsvRowArtifact(TextArtifact):
value: The row of the CSV file. If a dictionary is passed, the keys and values converted to a string.
"""

value: str = field(converter=value_to_str, metadata={"serializable": True})
value: str = field(converter=lambda value: CsvRowArtifact.value_to_str(value), metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> TextArtifact:
return TextArtifact(self.value + "\n" + other.value)

@classmethod
def value_to_str(cls, value: Any) -> str:
if isinstance(value, dict):
return "\n".join(f"{key}: {val}" for key, val in value.items())
else:
return str(value)
16 changes: 8 additions & 8 deletions griptape/artifacts/json_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None]


def value_to_json(value: Any) -> Json:
if isinstance(value, str):
return json.loads(value)
else:
return json.loads(json.dumps(value))


@define
class JsonArtifact(BaseArtifact):
"""Stores JSON data.
Expand All @@ -25,7 +18,14 @@ class JsonArtifact(BaseArtifact):
value: The JSON data. Values will automatically be converted to a JSON-compatible format.
"""

value: Json = field(converter=value_to_json, metadata={"serializable": True})
value: Json = field(converter=lambda value: JsonArtifact.value_to_json(value), metadata={"serializable": True})

@classmethod
def value_to_json(cls, value: Any) -> Json:
if isinstance(value, str):
return json.loads(value)
else:
return json.loads(json.dumps(value))

def to_text(self) -> str:
return json.dumps(self.value)

0 comments on commit f7b515c

Please sign in to comment.