Skip to content

Commit

Permalink
Add BaseFileLoader.save()
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 16, 2024
1 parent 5f1aebd commit 343f64a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Structure.conversation_memory_strategy` for setting whether Conversation Memory Runs should be created on a per-Structure or per-Task basis. Default is `Structure.ConversationMemoryStrategy.PER_STRUCTURE`.
- `BranchTask` for selecting which Tasks (if any) to run based on a condition.
- Support for `BranchTask` in `StructureVisualizer`.
- `BaseFileLoader.save()` method for saving an Artifact to a destination.

### Changed

Expand Down
6 changes: 6 additions & 0 deletions griptape/loaders/base_file_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ def fetch(self, source: str | PathLike) -> bytes:
return data.encode(self.encoding)
else:
return data

def save(self, destination: str | PathLike, artifact: A) -> None:
"""Saves the Artifact to a destination."""
data = artifact.value if isinstance(artifact.value, bytes) else artifact.value.encode(self.encoding)

self.file_manager_driver.save_file(str(destination), data)
37 changes: 37 additions & 0 deletions tests/unit/loaders/test_base_file_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from griptape.loaders.text_loader import TextLoader


class TestBaseFileLoader:
@pytest.fixture(params=["ascii", "utf-8", None])
def loader(self, request):
encoding = request.param
if encoding is None:
return TextLoader()
else:
return TextLoader(encoding=encoding)

@pytest.fixture(params=["path_from_resource_path"])
def create_source(self, request):
return request.getfixturevalue(request.param)

def test_fetch(self, loader, create_source):
source = create_source("test.txt")

data = loader.fetch(source)

assert data.startswith(b"foobar foobar foobar")

def test_save(self, loader, create_source):
source = create_source("test.txt")

data = loader.load(source)

destination = create_source("test_copy.txt")

loader.save(destination, data)

data_copy = loader.load(destination)

assert data.value == data_copy.value

0 comments on commit 343f64a

Please sign in to comment.