Add GCS authentication for service accounts (#315)
b-chu authored Jun 29, 2023
1 parent 1259545 commit da53bc5
Showing 8 changed files with 251 additions and 74 deletions.
2 changes: 1 addition & 1 deletion STYLE_GUIDE.md
@@ -27,7 +27,7 @@ Streaming uses the [yapf](https://github.com/google/yapf) formatter for general
(see section 2.2). These checks can also be run manually via:

```
pre-commit run yapf --all-files # for yahp
pre-commit run yapf --all-files # for yapf
pre-commit run isort --all-files # for isort
```

18 changes: 17 additions & 1 deletion docs/source/how_to_guides/configure_cloud_storage_credentials.md
@@ -97,7 +97,23 @@ export S3_ENDPOINT_URL='https://<accountid>.r2.cloudflarestorage.com'

For [MosaicML platform](https://www.mosaicml.com/cloud) users, follow the steps mentioned in the [Google Cloud Storage](https://mcli.docs.mosaicml.com/en/latest/secrets/gcp.html) MCLI doc on how to configure the cloud provider credentials.

### Others

### GCP Service Account Credentials Mounted as Environment Variables

Users must set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to point to their service account credentials file in the run environment.

````{tabs}
```{code-tab} py
import os
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'KEY_FILE'
```
```{code-tab} sh
export GOOGLE_APPLICATION_CREDENTIALS='KEY_FILE'
```
````

### GCP User Auth Credentials Mounted as Environment Variables

Streaming supports [GCP user credentials](https://cloud.google.com/storage/docs/authentication#user_accounts) or [HMAC keys for a user account](https://cloud.google.com/storage/docs/authentication/hmackeys). Users must set their GCP `user access key` and GCP `user access secret` in the run environment.
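
For example, a minimal sketch following the same tabs convention as above (the key and secret values here are placeholders):

````{tabs}
```{code-tab} py
import os
os.environ['GCS_KEY'] = 'EXAMPLE_HMAC_ACCESS_KEY'
os.environ['GCS_SECRET'] = 'EXAMPLE_HMAC_SECRET'
```
```{code-tab} sh
export GCS_KEY='EXAMPLE_HMAC_ACCESS_KEY'
export GCS_SECRET='EXAMPLE_HMAC_SECRET'
```
````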

1 change: 1 addition & 0 deletions setup.py
@@ -44,6 +44,7 @@
install_requires = [
'boto3>=1.21.45,<2',
'Brotli>=1.0.9',
'google-cloud-storage>=2.9.0',
'matplotlib>=3.5.2,<4',
'paramiko>=2.11.0,<4',
'python-snappy>=0.6.1,<1',
42 changes: 38 additions & 4 deletions streaming/base/storage/download.py
@@ -158,15 +158,33 @@ def download_from_gcs(remote: str, local: str) -> None:
remote (str): Remote path (GCS).
local (str): Local path (local filesystem).
"""
import boto3
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError

obj = urllib.parse.urlparse(remote)
if obj.scheme != 'gs':
raise ValueError(
f'Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote}')

if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
_gcs_with_service_account(local, obj)
elif 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
_gcs_with_hmac(remote, local, obj)
else:
raise ValueError(f'Either GOOGLE_APPLICATION_CREDENTIALS needs to be set for ' +
f'service level accounts or GCS_KEY and GCS_SECRET needs to be ' +
f'set for HMAC authentication')


def _gcs_with_hmac(remote: str, local: str, obj: urllib.parse.ParseResult) -> None:
"""Download a file from remote GCS to local using user level credentials.
Args:
remote (str): Remote path (GCS).
local (str): Local path (local filesystem).
obj (ParseResult): ParseResult object of remote.
"""
import boto3
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError

# Create a new session per thread
session = boto3.session.Session()
# Create a resource client using a thread's session object
@@ -190,6 +208,22 @@ def download_from_gcs(remote: str, local: str) -> None:
raise


def _gcs_with_service_account(local: str, obj: urllib.parse.ParseResult) -> None:
"""Download a file from remote GCS to local using service account credentials.
Args:
local (str): Local path (local filesystem).
obj (ParseResult): ParseResult object of remote path (GCS).
"""
from google.cloud.storage import Blob, Bucket, Client

service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
gcs_client = Client.from_service_account_json(service_account_path)

blob = Blob(obj.path.lstrip('/'), Bucket(gcs_client, obj.netloc))
blob.download_to_filename(local)


def download_from_oci(remote: str, local: str) -> None:
"""Download a file from remote OCI to local.
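For reference, a minimal sketch of exercising the new download dispatch (the bucket, object, and key-file paths here are hypothetical):

```python
import os

from streaming.base.storage.download import download_from_gcs

# With GOOGLE_APPLICATION_CREDENTIALS set, download_from_gcs takes the
# service-account branch (_gcs_with_service_account) via google-cloud-storage.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/service_account_key.json'
download_from_gcs('gs://my-bucket/data/shard.00000.mds', '/tmp/shard.00000.mds')
```

With `GCS_KEY` and `GCS_SECRET` set instead, the same call takes the HMAC branch (`_gcs_with_hmac`) through the boto3 S3-compatible client.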
119 changes: 80 additions & 39 deletions streaming/base/storage/upload.py
@@ -8,6 +8,7 @@
import shutil
import sys
import urllib.parse
from enum import Enum
from tempfile import mkdtemp
from typing import Any, Tuple, Union

@@ -16,8 +17,12 @@
from streaming.base.storage.download import BOTOCORE_CLIENT_ERROR_CODES

__all__ = [
'CloudUploader', 'S3Uploader', 'GCSUploader', 'OCIUploader', 'AzureUploader',
'AzureDataLakeUploader', 'LocalUploader'
'CloudUploader',
'S3Uploader',
'GCSUploader',
'OCIUploader',
'AzureUploader',
'LocalUploader',
]

logger = logging.getLogger(__name__)
Expand All @@ -32,6 +37,11 @@
}


class GCSAuthentication(Enum):
HMAC = 1
SERVICE_ACCOUNT = 2


class CloudUploader:
"""Upload local files to a cloud storage."""

@@ -84,10 +94,9 @@ def _validate(self, out: Union[str, Tuple[str, str]]) -> None:
obj = urllib.parse.urlparse(out)
else:
if len(out) != 2:
raise ValueError(''.join([
f'Invalid `out` argument. It is either a string of local/remote directory ',
'or a list of two strings with [local, remote].'
]))
raise ValueError(f'Invalid `out` argument. It is either a string of ' +
f'local/remote directory or a list of two strings with ' +
f'[local, remote].')
obj = urllib.parse.urlparse(out[1])
if obj.scheme not in UPLOADERS:
raise ValueError(f'Invalid Cloud provider prefix: {obj.scheme}.')
@@ -183,6 +192,7 @@ def __init__(self,

import boto3
from botocore.config import Config

config = Config()
# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
@@ -261,19 +271,34 @@ def __init__(self,
progress_bar: bool = False) -> None:
super().__init__(out, keep_local, progress_bar)

import boto3
if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
from google.cloud.storage import Client

service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
self.gcs_client = Client.from_service_account_json(service_account_path)
self.authentication = GCSAuthentication.SERVICE_ACCOUNT
elif 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
import boto3

# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
session = boto3.session.Session()
self.gcs_client = session.client(
's3',
region_name='auto',
endpoint_url='https://storage.googleapis.com',
aws_access_key_id=os.environ['GCS_KEY'],
aws_secret_access_key=os.environ['GCS_SECRET'],
)
self.authentication = GCSAuthentication.HMAC
else:
raise ValueError(f'Either GOOGLE_APPLICATION_CREDENTIALS needs to be set for ' +
f'service level accounts or GCS_KEY and GCS_SECRET needs to ' +
f'be set for HMAC authentication')

# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
session = boto3.session.Session()
self.gcs_client = session.client('s3',
region_name='auto',
endpoint_url='https://storage.googleapis.com',
aws_access_key_id=os.environ['GCS_KEY'],
aws_secret_access_key=os.environ['GCS_SECRET'])
self.check_bucket_exists(self.remote) # pyright: ignore

def upload_file(self, filename: str):
def upload_file(self, filename: str) -> None:
"""Upload file from local instance to Google Cloud Storage bucket.
Args:
Expand All @@ -283,21 +308,31 @@ def upload_file(self, filename: str):
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
obj = urllib.parse.urlparse(remote_filename)
logger.debug(f'Uploading to {remote_filename}')
file_size = os.stat(local_filename).st_size
with tqdm.tqdm(total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar)) as pbar:
self.gcs_client.upload_file(
local_filename,
obj.netloc,
obj.path.lstrip('/'),
Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
)

if self.authentication == GCSAuthentication.HMAC:
file_size = os.stat(local_filename).st_size
with tqdm.tqdm(
total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar),
) as pbar:
self.gcs_client.upload_file(
local_filename,
obj.netloc,
obj.path.lstrip('/'),
Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
)
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
from google.cloud.storage import Blob, Bucket

blob = Blob(obj.path.lstrip('/'), Bucket(self.gcs_client, obj.netloc))
blob.upload_from_filename(local_filename)

self.clear_local(local=local_filename)

def check_bucket_exists(self, remote: str):
def check_bucket_exists(self, remote: str) -> None:
"""Raise an exception if the bucket does not exist.
Args:
Expand All @@ -306,16 +341,20 @@ def check_bucket_exists(self, remote: str):
Raises:
error: Bucket does not exist.
"""
from botocore.exceptions import ClientError

bucket_name = urllib.parse.urlparse(remote).netloc
try:
self.gcs_client.head_bucket(Bucket=bucket_name)
except ClientError as error:
if error.response['Error']['Code'] == BOTOCORE_CLIENT_ERROR_CODES:
error.args = (f'Either bucket `{bucket_name}` does not exist! ' +
f'or check the bucket permission.',)
raise error

if self.authentication == GCSAuthentication.HMAC:
from botocore.exceptions import ClientError

try:
self.gcs_client.head_bucket(Bucket=bucket_name)
except ClientError as error:
if (error.response['Error']['Code'] == BOTOCORE_CLIENT_ERROR_CODES):
error.args = (f'Either bucket `{bucket_name}` does not exist! ' +
f'or check the bucket permission.',)
raise error
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
self.gcs_client.get_bucket(bucket_name)


class OCIUploader(CloudUploader):
@@ -343,6 +382,7 @@ def __init__(self,
super().__init__(out, keep_local, progress_bar)

import oci

config = oci.config.from_file()
self.client = oci.object_storage.ObjectStorageClient(
config=config, retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY)
@@ -430,7 +470,8 @@ def __init__(self,
# clients are generally thread-safe.
self.azure_service = BlobServiceClient(
account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net",
credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'])
credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'],
)
self.check_bucket_exists(self.remote) # pyright: ignore

def upload_file(self, filename: str):
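A corresponding sketch for the uploader (hypothetical bucket and file names; unsetting `GOOGLE_APPLICATION_CREDENTIALS` and exporting `GCS_KEY`/`GCS_SECRET` would select the HMAC branch instead):

```python
import os

from streaming.base.storage.upload import GCSUploader

# Service-account branch: __init__ sets authentication to
# GCSAuthentication.SERVICE_ACCOUNT, and check_bucket_exists()
# verifies the bucket via gcs_client.get_bucket().
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/service_account_key.json'
uploader = GCSUploader(out='gs://my-bucket/dir', keep_local=False, progress_bar=False)

# upload_file() expects the file to exist under the uploader's local directory.
uploader.upload_file('shard.00000.mds')
```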
3 more changed files not shown.
