Refactor: load from s3 (#267)
* Update AWSHandler.py and __init__.py

Add download_file_from_s3 and parse_s3_uri functions

Refactor download_file function to handle S3 URLs

* Add S3 file download functionality

* Add grid cache directory

* simplify grid loading logic and allow loading from URL

* add test recipe and config for URL loading

* Update cache directory to be created in local repo

* Add kwargs parameter to pack_grid method and clean_grid_cache option

* Update Environment.py with grid cache cleaning functionality

* Add clean_grid_cache option to default_values

* Add clean.py script to clean local cache directory

* Add clean_grid_cache option to test_url_load_config.json

* Update clean_grid_cache flag to false

* Linting: remove unused imports

* add back sys import

* Sort imports

* remove unused function

* move aws methods to AWSHandler

* move s3 url check back to autopack

* Update recipe and config

* rename function and add docstring

* remove unused imports

---------

Co-authored-by: Saurabh Mogre <[email protected]>
rugeli and mogres authored Jun 13, 2024
1 parent 6feaa9a commit 91e18a9
Showing 2 changed files with 27 additions and 36 deletions.
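In caller terms, the net effect of the refactor is that cellpack.autopack.download_file now recognizes s3:// URLs and routes them through the registered AWS handler instead of calling boto3 directly, with a new database_name parameter selecting the handler. A minimal sketch of the new call, assuming valid AWS credentials are configured; the URI and local path are made up for illustration:

    from cellpack import autopack

    # Hypothetical S3 URI and cache path, not taken from the repository.
    autopack.download_file(
        "s3://demo-bucket/grids/example_grid.dx",
        "grid_cache/example_grid.dx",
        reporthook=None,  # only used by the non-S3 urllib branch
        database_name="aws",
    )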
18 changes: 17 additions & 1 deletion cellpack/autopack/AWSHandler.py
@@ -19,7 +19,7 @@ def __init__(
         self,
         bucket_name,
         sub_folder_name=None,
-        region_name="us-west-2",
+        region_name=None,
     ):
         self.bucket_name = bucket_name
         self.folder_name = sub_folder_name
@@ -69,6 +69,22 @@ def upload_file(self, file_path):
             return False
         return file_name
 
+    def download_file(self, key, local_file_path):
+        """
+        Download a file from S3
+        :param key: S3 object key
+        :param local_file_path: Local file path to save the downloaded file
+        """
+
+        try:
+            self.s3_client.download_file(self.bucket_name, key, local_file_path)
+            print("File downloaded successfully.")
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "404":
+                print("The object does not exist.")
+            else:
+                print("An error occurred while downloading the file.")
+
     def create_presigned_url(self, object_name, expiration=3600):
         """Generate a presigned URL to share an S3 object
         :param object_name: string
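The new AWSHandler.download_file wraps the boto3 call that previously lived inline in cellpack/autopack/__init__.py. A usage sketch, assuming the handler's constructor (partially shown above) creates self.s3_client internally and that ClientError is imported from botocore.exceptions elsewhere in the module; the bucket, folder, and key are hypothetical:

    from cellpack.autopack.AWSHandler import AWSHandler

    # Hypothetical bucket/folder/key, for illustration only.
    handler = AWSHandler(
        bucket_name="demo-bucket",
        sub_folder_name="grids",
        region_name="us-west-2",
    )
    # key is the full object key; failures are printed rather than raised.
    handler.download_file("grids/example_grid.dx", "grid_cache/example_grid.dx")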
45 changes: 10 additions & 35 deletions cellpack/autopack/__init__.py
@@ -47,8 +47,6 @@
 from collections import OrderedDict
 from pathlib import Path
 
-import boto3
-import botocore
 
 from cellpack.autopack.DBRecipeHandler import DBRecipeLoader
 from cellpack.autopack.interface_objects.database_ids import DATABASE_IDS
@@ -250,20 +248,6 @@ def updateReplacePath(newPaths):
         REPLACE_PATH[w[0]] = w[1]
 
 
-def download_file_from_s3(s3_uri, local_file_path):
-    s3_client = boto3.client("s3")
-    bucket_name, key = parse_s3_uri(s3_uri)
-
-    try:
-        s3_client.download_file(bucket_name, key, local_file_path)
-        print("File downloaded successfully.")
-    except botocore.exceptions.ClientError as e:
-        if e.response["Error"]["Code"] == "404":
-            print("The object does not exist.")
-        else:
-            print("An error occurred while downloading the file.")
-
-
 def parse_s3_uri(s3_uri):
     # Remove the "s3://" prefix and split the remaining string into bucket name and key
     s3_uri = s3_uri.replace("s3://", "")
@@ -275,23 +259,18 @@ def parse_s3_uri(s3_uri):
     return bucket_name, folder, key
 
 
-def download_file(url, local_file_path, reporthook):
+def is_s3_url(file_path):
+    return file_path.find("s3://") != -1
+
+
+def download_file(url, local_file_path, reporthook, database_name="aws"):
     if is_s3_url(url):
-        # download from s3
-        # bucket_name, folder, key = parse_s3_uri(url)
-        # s3_handler = DATABASE_IDS.handlers().get(DATABASE_IDS.AWS)
-        # s3_handler = s3_handler(bucket_name, folder)
-        s3_client = boto3.client("s3")
+        db = DATABASE_IDS.handlers().get(database_name)
         bucket_name, folder, key = parse_s3_uri(url)
-        try:
-            s3_client.download_file(bucket_name, f"{folder}/{key}", local_file_path)
-            print("File downloaded successfully.")
-        except botocore.exceptions.ClientError as e:
-            if e.response["Error"]["Code"] == "404":
-                print("The object does not exist.")
-            else:
-                print("An error occurred while downloading the file.")
-
+        initialize_db = db(
+            bucket_name=bucket_name, sub_folder_name=folder, region_name="us-west-2"
+        )
+        initialize_db.download_file(f"{folder}/{key}", local_file_path)
     elif url_exists(url):
         try:
             urllib.urlretrieve(url, local_file_path, reporthook=reporthook)
@@ -308,10 +287,6 @@ def is_full_url(file_path):
     return re.match(url_regex, file_path) is not None
 
 
-def is_s3_url(file_path):
-    return file_path.find("s3://") != -1
-
-
 def is_remote_path(file_path):
     """
     @param file_path: str
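Pieced together, the new flow is: is_s3_url detects an s3:// URI, parse_s3_uri splits it into (bucket_name, folder, key), and download_file rejoins folder and key into the object key it hands to the handler. The middle of parse_s3_uri is collapsed in the diff above, so the following is an assumed reconstruction of the split for illustration, not the repository's exact code:

    # Assumed reconstruction, consistent with the visible signature and return value.
    def parse_s3_uri_sketch(s3_uri):
        # "s3://bucket/folder/file.json" -> ("bucket", "folder", "file.json")
        path = s3_uri.replace("s3://", "")
        bucket_name, _, object_path = path.partition("/")
        folder, _, key = object_path.rpartition("/")
        return bucket_name, folder, key

    bucket, folder, key = parse_s3_uri_sketch("s3://demo-bucket/recipes/grid.dx")
    assert (bucket, folder, key) == ("demo-bucket", "recipes", "grid.dx")
    # download_file() then requests the object key f"{folder}/{key}",
    # i.e. "recipes/grid.dx", and saves it to the local grid cache.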
