Skip to content

Commit

Permalink
changelog, missed tool
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Sep 16, 2024
1 parent 0592c43 commit 63ebe4f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
20 changes: 20 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Removed `BlobArtifact.dir_name`.
- **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`.
- **BREAKING**: `ImageArtifact.format` is now required.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver` to `client`.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockTitanEmbeddingDriver` to `client`.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockImageGenerationDriver` to `client`.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockImageQueryDriver` to `client`.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockPromptDriver` to `client`.
- **BREAKING**: Renamed parameter `sagemaker_client` on `AmazonSageMakerJumpstartEmbeddingDriver` to `client`.
- **BREAKING**: Renamed parameter `sagemaker_client` on `AmazonSageMakerJumpstartPromptDriver` to `client`.
- **BREAKING**: Renamed parameter `sqs_client` on `AmazonSqsEventListenerDriver` to `client`.
- **BREAKING**: Renamed parameter `iotdata_client` on `AwsIotCoreEventListenerDriver` to `client`.
- **BREAKING**: Renamed parameter `s3_client` on `AmazonS3FileManagerDriver` to `client`.
- **BREAKING**: Renamed parameter `s3_client` on `AwsS3Tool` to `client`.
- **BREAKING**: Renamed parameter `pusher_client` on `PusherEventListenerDriver` to `client`.
- **BREAKING**: Renamed parameter `model_client` on `GooglePromptDriver` to `client`.
- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `client`.
- **BREAKING**: Renamed parameter `collection` on `AstraDbVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `mq` on `MarqoVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `engine` on `PgVectorVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `index` on `PineconeVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `model_client` on `GoogleTokenizer` to `client`.
- Updated `JsonArtifact` value converter to properly handle more types.
- `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
- `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
- Removed `__add__` method from `BaseArtifact`, implemented it where necessary.
- Generic type support to `ListArtifact`.
- Iteration support to `ListArtifact`.
- The `client` parameter on `Driver`s that use a client are now lazily initialized.

## [0.31.0] - 2024-09-03

Expand Down
26 changes: 15 additions & 11 deletions griptape/tools/aws_s3/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,24 @@
import io
from typing import TYPE_CHECKING, Any

from attrs import Factory, define, field
from attrs import define, field
from schema import Literal, Schema

from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact
from griptape.tools import BaseAwsTool
from griptape.utils.decorators import activity
from griptape.utils.decorators import activity, lazy_property

if TYPE_CHECKING:
from mypy_boto3_s3 import Client


@define
class AwsS3Tool(BaseAwsTool):
s3_client: Client = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True)
_client: Client = field(default=None, kw_only=True)

@lazy_property()
def client(self) -> Client:
return self.session.client("s3")

@activity(
config={
Expand All @@ -33,7 +37,7 @@ class AwsS3Tool(BaseAwsTool):
)
def get_bucket_acl(self, params: dict) -> TextArtifact | ErrorArtifact:
try:
acl = self.s3_client.get_bucket_acl(Bucket=params["values"]["bucket_name"])
acl = self.client.get_bucket_acl(Bucket=params["values"]["bucket_name"])
return TextArtifact(acl)
except Exception as e:
return ErrorArtifact(f"error getting bucket acl: {e}")
Expand All @@ -48,7 +52,7 @@ def get_bucket_acl(self, params: dict) -> TextArtifact | ErrorArtifact:
)
def get_bucket_policy(self, params: dict) -> TextArtifact | ErrorArtifact:
try:
policy = self.s3_client.get_bucket_policy(Bucket=params["values"]["bucket_name"])
policy = self.client.get_bucket_policy(Bucket=params["values"]["bucket_name"])
return TextArtifact(policy)
except Exception as e:
return ErrorArtifact(f"error getting bucket policy: {e}")
Expand All @@ -66,7 +70,7 @@ def get_bucket_policy(self, params: dict) -> TextArtifact | ErrorArtifact:
)
def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact:
try:
acl = self.s3_client.get_object_acl(
acl = self.client.get_object_acl(
Bucket=params["values"]["bucket_name"],
Key=params["values"]["object_key"],
)
Expand All @@ -77,7 +81,7 @@ def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact:
@activity(config={"description": "Can be used to list all AWS S3 buckets."})
def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact:
try:
buckets = self.s3_client.list_buckets()
buckets = self.client.list_buckets()

return ListArtifact([TextArtifact(str(b)) for b in buckets["Buckets"]])
except Exception as e:
Expand All @@ -91,7 +95,7 @@ def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact:
)
def list_objects(self, params: dict) -> ListArtifact | ErrorArtifact:
try:
objects = self.s3_client.list_objects_v2(Bucket=params["values"]["bucket_name"])
objects = self.client.list_objects_v2(Bucket=params["values"]["bucket_name"])

if "Contents" not in objects:
return ErrorArtifact("no objects found in the bucket")
Expand Down Expand Up @@ -192,7 +196,7 @@ def download_objects(self, params: dict) -> ListArtifact | ErrorArtifact:
artifacts = []
for object_info in objects:
try:
obj = self.s3_client.get_object(Bucket=object_info["bucket_name"], Key=object_info["object_key"])
obj = self.client.get_object(Bucket=object_info["bucket_name"], Key=object_info["object_key"])

content = obj["Body"].read()
artifacts.append(BlobArtifact(content, name=object_info["object_key"]))
Expand All @@ -203,9 +207,9 @@ def download_objects(self, params: dict) -> ListArtifact | ErrorArtifact:
return ListArtifact(artifacts)

def _upload_object(self, bucket_name: str, object_name: str, value: Any) -> None:
self.s3_client.create_bucket(Bucket=bucket_name)
self.client.create_bucket(Bucket=bucket_name)

self.s3_client.upload_fileobj(
self.client.upload_fileobj(

Check warning on line 212 in griptape/tools/aws_s3/tool.py

View check run for this annotation

Codecov / codecov/patch

griptape/tools/aws_s3/tool.py#L212

Added line #L212 was not covered by tests
Fileobj=io.BytesIO(value.encode() if isinstance(value, str) else value),
Bucket=bucket_name,
Key=object_name,
Expand Down

0 comments on commit 63ebe4f

Please sign in to comment.