Skip to content

Commit

Permalink
Add path-style addressing support
Browse files Browse the repository at this point in the history
This PR adds support for other S3-compatible object storage services, such as
MinIO, that use path-style addressing to access buckets and objects.

For example:
```py
from s3torchconnector import S3MapDataset, S3IterableDataset

DATASET_URI="s3://<BUCKET>/<PREFIX>"
REGION = "us-east-1"

iterable_dataset = S3IterableDataset.from_prefix(DATASET_URI, region=REGION, path_style=True)

for item in iterable_dataset:
  print(item.key)

map_dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION, path_style=True)

item = map_dataset[0]

bucket = item.bucket
key = item.key
content = item.read()
len(content)
```

And

```py
from s3torchconnector import S3Checkpoint

import torchvision
import torch

CHECKPOINT_URI="s3://<BUCKET>/<KEY>/"
REGION = "us-east-1"
checkpoint = S3Checkpoint(region=REGION, path_style=True)

model = torchvision.models.resnet18()

with checkpoint.writer(CHECKPOINT_URI + "epoch0.ckpt") as writer:
    torch.save(model.state_dict(), writer)

with checkpoint.reader(CHECKPOINT_URI + "epoch0.ckpt") as reader:
    state_dict = torch.load(reader)

model.load_state_dict(state_dict)
```

Fixes awslabs#208

Signed-off-by: Bala.FA <[email protected]>
  • Loading branch information
balamurugana committed Jul 31, 2024
1 parent b40ab8b commit 8f7a059
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ def __init__(
bucket: str,
user_agent: Optional[UserAgent] = None,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
super().__init__(
region,
user_agent=user_agent,
s3client_config=s3client_config,
path_style=path_style,
)
self._mock_client = MockMountpointS3Client(
region,
Expand All @@ -38,6 +40,7 @@ def __init__(
part_size=self.s3client_config.part_size,
user_agent_prefix=self.user_agent_prefix,
unsigned=self.s3client_config.unsigned,
path_style=path_style,
)

def add_object(self, key: str, data: bytes) -> None:
Expand Down
3 changes: 3 additions & 0 deletions s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
endpoint: Optional[str] = None,
user_agent: Optional[UserAgent] = None,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
self._region = region
self._endpoint = endpoint
Expand All @@ -48,6 +49,7 @@ def __init__(
user_agent = user_agent or UserAgent()
self._user_agent_prefix = user_agent.prefix
self._s3client_config = s3client_config or S3ClientConfig()
self._path_style = path_style

@property
def _client(self) -> MountpointS3Client:
Expand Down Expand Up @@ -78,6 +80,7 @@ def _client_builder(self) -> MountpointS3Client:
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
part_size=self._s3client_config.part_size,
unsigned=self._s3client_config.unsigned,
path_style=self._path_style,
)

def get_object(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
region: str,
s3client_config: Optional[S3ClientConfig] = None,
endpoint: Optional[str] = None,
path_style: Optional[bool] = False,
):
self.region = region
user_agent = UserAgent(["lightning", lightning.__version__])
Expand All @@ -29,6 +30,7 @@ def __init__(
user_agent=user_agent,
s3client_config=s3client_config,
endpoint=endpoint,
path_style=path_style,
)

def save_checkpoint(
Expand Down
6 changes: 5 additions & 1 deletion s3torchconnector/src/s3torchconnector/s3checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ def __init__(
region: str,
endpoint: Optional[str] = None,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
self.region = region
self.endpoint = endpoint
self._client = S3Client(
region, endpoint=endpoint, s3client_config=s3client_config
region,
endpoint=endpoint,
s3client_config=s3client_config,
path_style=path_style,
)

def reader(self, s3_uri: str) -> S3Reader:
Expand Down
9 changes: 9 additions & 0 deletions s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ def __init__(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
self._get_dataset_objects = get_dataset_objects
self._transform = transform
self._region = region
self._endpoint = endpoint
self._s3client_config = s3client_config
self._path_style = path_style
self._client = None

@property
Expand All @@ -57,6 +59,7 @@ def from_objects(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
"""Returns an instance of S3IterableDataset using the S3 URI(s) provided.
Expand All @@ -66,6 +69,7 @@ def from_objects(
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
path_style: Optional path_style for S3 client.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
Expand All @@ -80,6 +84,7 @@ def from_objects(
endpoint,
transform=transform,
s3client_config=s3client_config,
path_style=path_style,
)

@classmethod
Expand All @@ -91,6 +96,7 @@ def from_prefix(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
"""Returns an instance of S3IterableDataset using the S3 URI provided.
Expand All @@ -100,6 +106,7 @@ def from_prefix(
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
path_style: Optional path_style for S3 client.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
Expand All @@ -114,6 +121,7 @@ def from_prefix(
endpoint,
transform=transform,
s3client_config=s3client_config,
path_style=path_style,
)

def _get_client(self):
Expand All @@ -122,6 +130,7 @@ def _get_client(self):
self.region,
endpoint=self.endpoint,
s3client_config=self._s3client_config,
path_style=self._path_style,
)
return self._client

Expand Down
9 changes: 9 additions & 0 deletions s3torchconnector/src/s3torchconnector/s3map_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ def __init__(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
self._get_dataset_objects = get_dataset_objects
self._transform = transform
self._region = region
self._endpoint = endpoint
self._s3client_config = s3client_config
self._path_style = path_style
self._client = None
self._bucket_key_pairs: Optional[List[S3BucketKeyData]] = None

Expand Down Expand Up @@ -66,6 +68,7 @@ def from_objects(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
"""Returns an instance of S3MapDataset using the S3 URI(s) provided.
Expand All @@ -75,6 +78,7 @@ def from_objects(
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
path_style: Optional path_style for S3 client.
Returns:
S3MapDataset: A Map-Style dataset created from S3 objects.
Expand All @@ -89,6 +93,7 @@ def from_objects(
endpoint,
transform=transform,
s3client_config=s3client_config,
path_style=path_style,
)

@classmethod
Expand All @@ -100,6 +105,7 @@ def from_prefix(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
path_style: Optional[bool] = False,
):
"""Returns an instance of S3MapDataset using the S3 URI provided.
Expand All @@ -109,6 +115,7 @@ def from_prefix(
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
path_style: Optional path_style for S3 client.
Returns:
S3MapDataset: A Map-Style dataset created from S3 objects.
Expand All @@ -123,6 +130,7 @@ def from_prefix(
endpoint,
transform=transform,
s3client_config=s3client_config,
path_style=path_style,
)

def _get_client(self):
Expand All @@ -131,6 +139,7 @@ def _get_client(self):
self.region,
endpoint=self.endpoint,
s3client_config=self._s3client_config,
path_style=self._path_style,
)
return self._client

Expand Down
11 changes: 8 additions & 3 deletions s3torchconnectorclient/rust/src/mountpoint_s3_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::Arc;

use mountpoint_s3_crt::common::uri::Uri;
use mountpoint_s3_crt::common::allocator::Allocator;
use mountpoint_s3_client::config::{EndpointConfig, S3ClientAuthConfig, S3ClientConfig};
use mountpoint_s3_client::config::{AddressingStyle, EndpointConfig, S3ClientAuthConfig, S3ClientConfig};
use mountpoint_s3_client::types::PutObjectParams;
use mountpoint_s3_client::user_agent::UserAgent;
use mountpoint_s3_client::{ObjectClient, S3CrtClient};
Expand Down Expand Up @@ -53,7 +53,8 @@ pub struct MountpointS3Client {
#[pymethods]
impl MountpointS3Client {
#[new]
#[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, unsigned=false, endpoint=None))]
#[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, unsigned=false, endpoint=None, path_style=false))]
#[allow(clippy::too_many_arguments)]
pub fn new_s3_client(
region: String,
user_agent_prefix: String,
Expand All @@ -62,16 +63,20 @@ impl MountpointS3Client {
profile: Option<String>,
unsigned: bool,
endpoint: Option<String>,
path_style: bool,
) -> PyResult<Self> {
// TODO: Mountpoint has logic for guessing based on instance type. It may be worth having
// similar logic if we want to exceed 10Gbps reading for larger instances

let endpoint_str = endpoint.as_deref().unwrap_or("");
let endpoint_config = if endpoint_str.is_empty() {
let mut endpoint_config = if endpoint_str.is_empty() {
EndpointConfig::new(&region)
} else {
EndpointConfig::new(&region).endpoint(Uri::new_from_str(&Allocator::default(), endpoint_str).unwrap())
};
if path_style {
endpoint_config = endpoint_config.addressing_style(AddressingStyle::Path);
}
let auth_config = auth_config(profile.as_deref(), unsigned);

let user_agent_suffix =
Expand Down

0 comments on commit 8f7a059

Please sign in to comment.