Skip to content

Commit

Permalink
Replace get_bucket_location with head_bucket (mlflow#10731)
Browse files Browse the repository at this point in the history
  • Loading branch information
kriscon-db authored Jan 6, 2024
1 parent bdee1d2 commit 2782a01
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
8 changes: 6 additions & 2 deletions mlflow/store/artifact/optimized_s3_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,19 @@ def __init__(
self._region_name = self._get_region_name()

def _get_region_name(self):
# note: s3 client enforces path addressing style for get_bucket_location
from botocore.exceptions import ClientError

temp_client = _get_s3_client(
addressing_style="path",
access_key_id=self._access_key_id,
secret_access_key=self._secret_access_key,
session_token=self._session_token,
s3_endpoint_url=self._s3_endpoint_url,
)
return temp_client.get_bucket_location(Bucket=self.bucket)["LocationConstraint"]
try:
return temp_client.head_bucket(Bucket=self.bucket)["BucketRegion"]
except ClientError as error:
return error.response["ResponseMetadata"]["HTTPHeaders"]["x-amz-bucket-region"]

def _get_s3_client(self):
return _get_s3_client(
Expand Down
4 changes: 3 additions & 1 deletion requirements/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@ openai<1.0
# Required for showing pytest stats
psutil
# SQLAlchemy == 2.0.25 requires typing_extensions >= 4.6.0
typing_extensions>=4.6.0
typing_extensions>=4.6.0
# Required for importing boto3 ClientError directly for testing
botocore>=1.34
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def run(self):
"requests-auth-aws-sigv4",
# Required to log artifacts and models to AWS S3 artifact locations
"boto3",
"botocore",
# Required to log artifacts and models to GCS artifact locations
"google-cloud-storage>=1.30.0",
"azureml-core>=1.2.0",
Expand All @@ -169,6 +170,7 @@ def run(self):
"azure-storage-file-datalake>12",
"google-cloud-storage>=1.30.0",
"boto3>1",
"botocore>1.34",
],
"gateway": GATEWAY_REQUIREMENTS,
"genai": GATEWAY_REQUIREMENTS,
Expand Down
21 changes: 18 additions & 3 deletions tests/store/artifact/test_optimized_s3_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_get_s3_client_hits_cache(s3_artifact_root, monkeypatch):
with mock.patch("boto3.client") as mock_get_s3_client:
s3_client_mock = mock.Mock()
mock_get_s3_client.return_value = s3_client_mock
s3_client_mock.get_bucket_location.return_value = {"LocationConstraint": "us-west-2"}
s3_client_mock.head_bucket.return_value = {"BucketRegion": "us-west-2"}

# pylint: disable=no-value-for-parameter
repo = OptimizedS3ArtifactRepository(posixpath.join(s3_artifact_root, "some/path"))
Expand Down Expand Up @@ -98,12 +98,27 @@ def test_get_s3_client_verify_param_set_correctly(
)


def test_get_s3_client_region_name_set_correctly(s3_artifact_root):
@pytest.mark.parametrize("client_throws", [True, False])
def test_get_s3_client_region_name_set_correctly(s3_artifact_root, client_throws):
region_name = "us_random_region_42"
with mock.patch("boto3.client") as mock_get_s3_client:
from botocore.exceptions import ClientError

s3_client_mock = mock.Mock()
mock_get_s3_client.return_value = s3_client_mock
s3_client_mock.get_bucket_location.return_value = {"LocationConstraint": region_name}
if client_throws:
error = ClientError(
{
"Error": {"Code": "403", "Message": "Forbidden"},
"ResponseMetadata": {
"HTTPHeaders": {"x-amz-bucket-region": region_name},
},
},
"head_bucket",
)
s3_client_mock.head_bucket.side_effect = error
else:
s3_client_mock.head_bucket.return_value = {"BucketRegion": region_name}

repo = OptimizedS3ArtifactRepository(posixpath.join(s3_artifact_root, "some/path"))
repo._get_s3_client()
Expand Down

0 comments on commit 2782a01

Please sign in to comment.