From 8f7a05962a2eb366b735f6034d45e97032ac96fb Mon Sep 17 00:00:00 2001 From: "Bala.FA" Date: Sat, 27 Jul 2024 04:55:10 +0000 Subject: [PATCH] Add path-style addressing support This PR extends support to other S3 object storage like MinIO which has path-style addressing to access bucket/object. An example is like ```py from s3torchconnector import S3MapDataset, S3IterableDataset DATASET_URI="s3:///" REGION = "us-east-1" iterable_dataset = S3IterableDataset.from_prefix(DATASET_URI, region=REGION, path_style=True) for item in iterable_dataset: print(item.key) map_dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION, path_style=True) item = map_dataset[0] bucket = item.bucket key = item.key content = item.read() len(content) ``` And ```py from s3torchconnector import S3Checkpoint import torchvision import torch CHECKPOINT_URI="s3:////" REGION = "us-east-1" checkpoint = S3Checkpoint(region=REGION, path_style=True) model = torchvision.models.resnet18() with checkpoint.writer(CHECKPOINT_URI + "epoch0.ckpt") as writer: torch.save(model.state_dict(), writer) with checkpoint.reader(CHECKPOINT_URI + "epoch0.ckpt") as reader: state_dict = torch.load(reader) model.load_state_dict(state_dict) ``` Fixes #208 Signed-off-by: Bala.FA --- .../src/s3torchconnector/_s3client/_mock_s3client.py | 3 +++ .../src/s3torchconnector/_s3client/_s3client.py | 3 +++ .../lightning/s3_lightning_checkpoint.py | 2 ++ s3torchconnector/src/s3torchconnector/s3checkpoint.py | 6 +++++- .../src/s3torchconnector/s3iterable_dataset.py | 9 +++++++++ .../src/s3torchconnector/s3map_dataset.py | 9 +++++++++ .../rust/src/mountpoint_s3_client.rs | 11 ++++++++--- 7 files changed, 39 insertions(+), 4 deletions(-) diff --git a/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py b/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py index 3ebbddd5..e77afa1a 100644 --- a/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py +++ b/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py @@ -25,11 +25,13 @@ def __init__( bucket: str, user_agent: Optional[UserAgent] = None, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): super().__init__( region, user_agent=user_agent, s3client_config=s3client_config, + path_style=path_style, ) self._mock_client = MockMountpointS3Client( region, @@ -38,6 +40,7 @@ def __init__( part_size=self.s3client_config.part_size, user_agent_prefix=self.user_agent_prefix, unsigned=self.s3client_config.unsigned, + path_style=path_style, ) def add_object(self, key: str, data: bytes) -> None: diff --git a/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py b/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py index 3466bafb..229750f8 100644 --- a/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py +++ b/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py @@ -40,6 +40,7 @@ def __init__( endpoint: Optional[str] = None, user_agent: Optional[UserAgent] = None, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): self._region = region self._endpoint = endpoint @@ -48,6 +49,7 @@ def __init__( user_agent = user_agent or UserAgent() self._user_agent_prefix = user_agent.prefix self._s3client_config = s3client_config or S3ClientConfig() + self._path_style = path_style @property def _client(self) -> MountpointS3Client: @@ -78,6 +80,7 @@ def _client_builder(self) -> MountpointS3Client: throughput_target_gbps=self._s3client_config.throughput_target_gbps, part_size=self._s3client_config.part_size, unsigned=self._s3client_config.unsigned, + path_style=self._path_style, ) def get_object( diff --git a/s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py b/s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py index 15bbe28a..634afb29 100644 --- a/s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py +++ b/s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py @@ -21,6 +21,7 @@ def __init__( region: str, s3client_config: Optional[S3ClientConfig] = None, endpoint: Optional[str] = None, + path_style: Optional[bool] = False, ): self.region = region user_agent = UserAgent(["lightning", lightning.__version__]) @@ -29,6 +30,7 @@ def __init__( user_agent=user_agent, s3client_config=s3client_config, endpoint=endpoint, + path_style=path_style, ) def save_checkpoint( diff --git a/s3torchconnector/src/s3torchconnector/s3checkpoint.py b/s3torchconnector/src/s3torchconnector/s3checkpoint.py index d4ec9b27..334cdd67 100644 --- a/s3torchconnector/src/s3torchconnector/s3checkpoint.py +++ b/s3torchconnector/src/s3torchconnector/s3checkpoint.py @@ -22,11 +22,15 @@ def __init__( region: str, endpoint: Optional[str] = None, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): self.region = region self.endpoint = endpoint self._client = S3Client( - region, endpoint=endpoint, s3client_config=s3client_config + region, + endpoint=endpoint, + s3client_config=s3client_config, + path_style=path_style, ) def reader(self, s3_uri: str) -> S3Reader: diff --git a/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py b/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py index d61487a8..c4eb16ba 100644 --- a/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py +++ b/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py @@ -32,12 +32,14 @@ def __init__( endpoint: Optional[str] = None, transform: Callable[[S3Reader], Any] = identity, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): self._get_dataset_objects = get_dataset_objects self._transform = transform self._region = region self._endpoint = endpoint self._s3client_config = s3client_config + self._path_style = path_style self._client = None @property @@ -57,6 +59,7 @@ def from_objects( endpoint: Optional[str] = None, transform: Callable[[S3Reader], Any] = identity, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): """Returns an instance of S3IterableDataset using the S3 URI(s) provided. @@ -66,6 +69,7 @@ def from_objects( endpoint(str): AWS endpoint of the S3 bucket where the objects are stored. transform: Optional callable which is used to transform an S3Reader into the desired type. s3client_config: Optional S3ClientConfig with parameters for S3 client. + path_style: Optional path_style for S3 client. Returns: S3IterableDataset: An IterableStyle dataset created from S3 objects. @@ -80,6 +84,7 @@ def from_objects( endpoint, transform=transform, s3client_config=s3client_config, + path_style=path_style, ) @classmethod @@ -91,6 +96,7 @@ def from_prefix( endpoint: Optional[str] = None, transform: Callable[[S3Reader], Any] = identity, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): """Returns an instance of S3IterableDataset using the S3 URI provided. @@ -100,6 +106,7 @@ def from_prefix( endpoint(str): AWS endpoint of the S3 bucket where the objects are stored. transform: Optional callable which is used to transform an S3Reader into the desired type. s3client_config: Optional S3ClientConfig with parameters for S3 client. + path_style: Optional path_style for S3 client. Returns: S3IterableDataset: An IterableStyle dataset created from S3 objects. @@ -114,6 +121,7 @@ def from_prefix( endpoint, transform=transform, s3client_config=s3client_config, + path_style=path_style, ) def _get_client(self): @@ -122,6 +130,7 @@ def _get_client(self): self.region, endpoint=self.endpoint, s3client_config=self._s3client_config, + path_style=self._path_style, ) return self._client diff --git a/s3torchconnector/src/s3torchconnector/s3map_dataset.py b/s3torchconnector/src/s3torchconnector/s3map_dataset.py index 6d6b837f..a350f744 100644 --- a/s3torchconnector/src/s3torchconnector/s3map_dataset.py +++ b/s3torchconnector/src/s3torchconnector/s3map_dataset.py @@ -33,12 +33,14 @@ def __init__( endpoint: Optional[str] = None, transform: Callable[[S3Reader], Any] = identity, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): self._get_dataset_objects = get_dataset_objects self._transform = transform self._region = region self._endpoint = endpoint self._s3client_config = s3client_config + self._path_style = path_style self._client = None self._bucket_key_pairs: Optional[List[S3BucketKeyData]] = None @@ -66,6 +68,7 @@ def from_objects( endpoint: Optional[str] = None, transform: Callable[[S3Reader], Any] = identity, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): """Returns an instance of S3MapDataset using the S3 URI(s) provided. @@ -75,6 +78,7 @@ def from_objects( endpoint(str): AWS endpoint of the S3 bucket where the objects are stored. transform: Optional callable which is used to transform an S3Reader into the desired type. s3client_config: Optional S3ClientConfig with parameters for S3 client. + path_style: Optional path_style for S3 client. Returns: S3MapDataset: A Map-Style dataset created from S3 objects. @@ -89,6 +93,7 @@ def from_objects( endpoint, transform=transform, s3client_config=s3client_config, + path_style=path_style, ) @classmethod @@ -100,6 +105,7 @@ def from_prefix( endpoint: Optional[str] = None, transform: Callable[[S3Reader], Any] = identity, s3client_config: Optional[S3ClientConfig] = None, + path_style: Optional[bool] = False, ): """Returns an instance of S3MapDataset using the S3 URI provided. @@ -109,6 +115,7 @@ def from_prefix( endpoint(str): AWS endpoint of the S3 bucket where the objects are stored. transform: Optional callable which is used to transform an S3Reader into the desired type. s3client_config: Optional S3ClientConfig with parameters for S3 client. + path_style: Optional path_style for S3 client. Returns: S3MapDataset: A Map-Style dataset created from S3 objects. @@ -123,6 +130,7 @@ def from_prefix( endpoint, transform=transform, s3client_config=s3client_config, + path_style=path_style, ) def _get_client(self): @@ -131,6 +139,7 @@ def _get_client(self): self.region, endpoint=self.endpoint, s3client_config=self._s3client_config, + path_style=self._path_style, ) return self._client diff --git a/s3torchconnectorclient/rust/src/mountpoint_s3_client.rs b/s3torchconnectorclient/rust/src/mountpoint_s3_client.rs index 4c496c5d..f10fd5c7 100644 --- a/s3torchconnectorclient/rust/src/mountpoint_s3_client.rs +++ b/s3torchconnectorclient/rust/src/mountpoint_s3_client.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use mountpoint_s3_crt::common::uri::Uri; use mountpoint_s3_crt::common::allocator::Allocator; -use mountpoint_s3_client::config::{EndpointConfig, S3ClientAuthConfig, S3ClientConfig}; +use mountpoint_s3_client::config::{AddressingStyle, EndpointConfig, S3ClientAuthConfig, S3ClientConfig}; use mountpoint_s3_client::types::PutObjectParams; use mountpoint_s3_client::user_agent::UserAgent; use mountpoint_s3_client::{ObjectClient, S3CrtClient}; @@ -53,7 +53,8 @@ pub struct MountpointS3Client { #[pymethods] impl MountpointS3Client { #[new] - #[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, unsigned=false, endpoint=None))] + #[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, unsigned=false, endpoint=None, path_style=false))] + #[allow(clippy::too_many_arguments)] pub fn new_s3_client( region: String, user_agent_prefix: String, @@ -62,16 +63,20 @@ impl MountpointS3Client { profile: Option, unsigned: bool, endpoint: Option, + path_style: bool, ) -> PyResult { // TODO: Mountpoint has logic for guessing based on instance type. It may be worth having // similar logic if we want to exceed 10Gbps reading for larger instances let endpoint_str = endpoint.as_deref().unwrap_or(""); - let endpoint_config = if endpoint_str.is_empty() { + let mut endpoint_config = if endpoint_str.is_empty() { EndpointConfig::new(®ion) } else { EndpointConfig::new(®ion).endpoint(Uri::new_from_str(&Allocator::default(), endpoint_str).unwrap()) }; + if path_style { + endpoint_config = endpoint_config.addressing_style(AddressingStyle::Path); + } let auth_config = auth_config(profile.as_deref(), unsigned); let user_agent_suffix =