Skip to content

Commit

Permalink
Refactor UserAgent setup for extensibility (#161)
Browse files Browse the repository at this point in the history
* Refactor UserAgent setup for extensibility

Update MockClient, store just User-Agent prefix string
and add assertions for MountpointClient prefix.
Add hypothesis test for starts with package/version.
  • Loading branch information
dnanuti authored Feb 29, 2024
1 parent 7ac85ef commit 6f2b9a8
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 13 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

### New features

### Bug Fixes
### Bug Fixes / Improvements
* Fix deadlock when enabling CRT debug logs. Removed former experimental method _enable_debug_logging().

* Refactor User-Agent setup for extensibility.

## v1.1.4 (February 26, 2024)

Expand Down
18 changes: 15 additions & 3 deletions s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)

from . import S3Client
from .._user_agent import UserAgent

"""
_mock_s3client.py
Expand All @@ -15,9 +16,20 @@


class MockS3Client(S3Client):
def __init__(self, region: str, bucket: str, part_size: int = 8 * 1024 * 1024):
super().__init__(region)
self._mock_client = MockMountpointS3Client(region, bucket, part_size=part_size)
def __init__(
self,
region: str,
bucket: str,
part_size: int = 8 * 1024 * 1024,
user_agent: UserAgent = None,
):
super().__init__(region, user_agent=user_agent)
self._mock_client = MockMountpointS3Client(
region,
bucket,
part_size=part_size,
user_agent_prefix=self.user_agent_prefix,
)

def add_object(self, key: str, data: bytes) -> None:
self._mock_client.add_object(key, data)
Expand Down
12 changes: 9 additions & 3 deletions s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Optional, Any

from s3torchconnector import S3Reader, S3Writer
from s3torchconnector._version import user_agent_prefix

from s3torchconnectorclient._mountpoint_s3_client import (
MountpointS3Client,
Expand All @@ -16,6 +15,7 @@
GetObjectStream,
)

from s3torchconnector._user_agent import UserAgent

"""
_s3client.py
Expand All @@ -32,11 +32,13 @@ def _identity(obj: Any) -> Any:


class S3Client:
def __init__(self, region: str, endpoint: str = None):
def __init__(self, region: str, endpoint: str = None, user_agent: UserAgent = None):
self._region = region
self._endpoint = endpoint
self._real_client = None
self._client_pid = None
user_agent = user_agent or UserAgent()
self._user_agent_prefix = user_agent.prefix

@property
def _client(self) -> MountpointS3Client:
Expand All @@ -50,11 +52,15 @@ def _client(self) -> MountpointS3Client:
def region(self) -> str:
return self._region

@property
def user_agent_prefix(self) -> str:
return self._user_agent_prefix

def _client_builder(self) -> MountpointS3Client:
return MountpointS3Client(
region=self._region,
endpoint=self._endpoint,
user_agent_prefix=user_agent_prefix,
user_agent_prefix=self._user_agent_prefix,
)

def get_object(
Expand Down
22 changes: 22 additions & 0 deletions s3torchconnector/src/s3torchconnector/_user_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from typing import List

from ._version import __version__

# https://www.rfc-editor.org/rfc/rfc9110#name-user-agent


class UserAgent:
def __init__(self, comments: List[str] = None):
if comments is not None and not isinstance(comments, list):
raise ValueError("Argument comments must be a List[str]")
self._user_agent_prefix = f"{__package__}/{__version__}"
self._comments = comments or []

@property
def prefix(self):
comments_str = "; ".join(filter(None, self._comments))
if comments_str:
return f"{self._user_agent_prefix} ({comments_str})"
return self._user_agent_prefix
1 change: 0 additions & 1 deletion s3torchconnector/src/s3torchconnector/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@

# __package__ is 's3torchconnector'
__version__ = importlib.metadata.version(__package__)
user_agent_prefix = f"{__package__}/{__version__}"
38 changes: 36 additions & 2 deletions s3torchconnector/tst/unit/test_s3_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
import logging
from unittest.mock import MagicMock

import pytest

from hypothesis import given
from hypothesis.strategies import lists, text
from unittest.mock import MagicMock

from s3torchconnector._user_agent import UserAgent
from s3torchconnector._version import __version__
from s3torchconnector._s3client import S3Client, MockS3Client

TEST_BUCKET = "test-bucket"
Expand Down Expand Up @@ -48,3 +52,33 @@ def test_list_objects_log(s3_client: S3Client, caplog):
with caplog.at_level(logging.DEBUG):
s3_client.list_objects(TEST_BUCKET, TEST_KEY)
assert f"ListObjects {S3_URI}" in caplog.messages


def test_s3_client_default_user_agent():
s3_client = S3Client(region=TEST_REGION)
expected_user_agent = f"s3torchconnector/{__version__}"
assert s3_client.user_agent_prefix == expected_user_agent
assert s3_client._client.user_agent_prefix == expected_user_agent


def test_s3_client_custom_user_agent():
s3_client = S3Client(
region=TEST_REGION, user_agent=UserAgent(["component/version", "metadata"])
)
expected_user_agent = (
f"s3torchconnector/{__version__} (component/version; metadata)"
)
assert s3_client.user_agent_prefix == expected_user_agent
assert s3_client._client.user_agent_prefix == expected_user_agent


@given(lists(text()))
def test_user_agent_always_starts_with_package_version(comments):
s3_client = S3Client(region=TEST_REGION, user_agent=UserAgent(comments))
expected_prefix = f"s3torchconnector/{__version__}"
assert s3_client.user_agent_prefix.startswith(expected_prefix)
assert s3_client._client.user_agent_prefix.startswith(expected_prefix)
comments_str = "; ".join(filter(None, comments))
if comments_str:
assert comments_str in s3_client.user_agent_prefix
assert comments_str in s3_client._client.user_agent_prefix
41 changes: 41 additions & 0 deletions s3torchconnector/tst/unit/test_user_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from __future__ import annotations

from typing import List

import pytest

from s3torchconnector._version import __version__
from s3torchconnector._user_agent import UserAgent

DEFAULT_PREFIX = f"s3torchconnector/{__version__}"


@pytest.mark.parametrize(
"comments, expected_prefix",
[
(None, DEFAULT_PREFIX),
([], DEFAULT_PREFIX),
([""], DEFAULT_PREFIX),
(["", ""], DEFAULT_PREFIX),
(
["component/version", "metadata"],
f"{DEFAULT_PREFIX} (component/version; metadata)",
),
],
)
def test_user_agent_creation(comments: List[str] | None, expected_prefix: str):
user_agent = UserAgent(comments)
assert user_agent.prefix == expected_prefix


def test_default_user_agent_creation():
user_agent = UserAgent()
assert user_agent.prefix == DEFAULT_PREFIX


@pytest.mark.parametrize("invalid_comment", [0, "string"])
def test_invalid_comments_argument(invalid_comment):
with pytest.raises(ValueError, match="Argument comments must be a List\[str\]"):
UserAgent(invalid_comment)
8 changes: 6 additions & 2 deletions s3torchconnectorclient/rust/src/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,20 @@ pub struct PyMockClient {
pub(crate) region: String,
#[pyo3(get)]
pub(crate) part_size: usize,
#[pyo3(get)]
pub(crate) user_agent_prefix: String,
}

#[pymethods]
impl PyMockClient {
#[new]
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024))]
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string()))]
pub fn new(
region: String,
bucket: String,
throughput_target_gbps: f64,
part_size: usize,
user_agent_prefix: String,
) -> PyMockClient {
let unordered_list_seed: Option<u64> = None;
let config = MockClientConfig { bucket, part_size, unordered_list_seed };
Expand All @@ -45,13 +48,14 @@ impl PyMockClient {
region,
throughput_target_gbps,
part_size,
user_agent_prefix
}
}

fn create_mocked_client(&self) -> MountpointS3Client {
MountpointS3Client::new(
self.region.clone(),
"mock-client".to_string(),
self.user_agent_prefix.to_string(),
self.throughput_target_gbps,
self.part_size,
None,
Expand Down

0 comments on commit 6f2b9a8

Please sign in to comment.