diff --git a/cellpack/autopack/AWSHandler.py b/cellpack/autopack/AWSHandler.py
index 0bbecf0b..48f7fd62 100644
--- a/cellpack/autopack/AWSHandler.py
+++ b/cellpack/autopack/AWSHandler.py
@@ -19,7 +19,7 @@ def __init__(
         self,
         bucket_name,
         sub_folder_name=None,
-        region_name="us-west-2",
+        region_name=None,
     ):
         self.bucket_name = bucket_name
         self.folder_name = sub_folder_name
@@ -69,6 +69,22 @@ def upload_file(self, file_path):
             return False
         return file_name

+    def download_file(self, key, local_file_path):
+        """
+        Download a file from S3
+        :param key: S3 object key
+        :param local_file_path: Local file path to save the downloaded file
+        """
+
+        try:
+            self.s3_client.download_file(self.bucket_name, key, local_file_path)
+            print("File downloaded successfully.")
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "404":
+                print("The object does not exist.")
+            else:
+                print("An error occurred while downloading the file.")
+
     def create_presigned_url(self, object_name, expiration=3600):
         """Generate a presigned URL to share an S3 object
         :param object_name: string
diff --git a/cellpack/autopack/__init__.py b/cellpack/autopack/__init__.py
index 3878458a..3ff238da 100755
--- a/cellpack/autopack/__init__.py
+++ b/cellpack/autopack/__init__.py
@@ -47,8 +47,6 @@
 from collections import OrderedDict
 from pathlib import Path

-import boto3
-import botocore
 from cellpack.autopack.DBRecipeHandler import DBRecipeLoader
 from cellpack.autopack.interface_objects.database_ids import DATABASE_IDS

@@ -250,20 +248,6 @@ def updateReplacePath(newPaths):
             REPLACE_PATH[w[0]] = w[1]


-def download_file_from_s3(s3_uri, local_file_path):
-    s3_client = boto3.client("s3")
-    bucket_name, key = parse_s3_uri(s3_uri)
-
-    try:
-        s3_client.download_file(bucket_name, key, local_file_path)
-        print("File downloaded successfully.")
-    except botocore.exceptions.ClientError as e:
-        if e.response["Error"]["Code"] == "404":
-            print("The object does not exist.")
-        else:
-            print("An error occurred while downloading the file.")
-
-
 def parse_s3_uri(s3_uri):
     # Remove the "s3://" prefix and split the remaining string into bucket name and key
     s3_uri = s3_uri.replace("s3://", "")
@@ -275,23 +259,18 @@ def parse_s3_uri(s3_uri):
     return bucket_name, folder, key


-def download_file(url, local_file_path, reporthook):
+def is_s3_url(file_path):
+    return file_path.find("s3://") != -1
+
+
+def download_file(url, local_file_path, reporthook, database_name="aws"):
     if is_s3_url(url):
-        # download from s3
-        # bucket_name, folder, key = parse_s3_uri(url)
-        # s3_handler = DATABASE_IDS.handlers().get(DATABASE_IDS.AWS)
-        # s3_handler = s3_handler(bucket_name, folder)
-        s3_client = boto3.client("s3")
+        db = DATABASE_IDS.handlers().get(database_name)
         bucket_name, folder, key = parse_s3_uri(url)
-        try:
-            s3_client.download_file(bucket_name, f"{folder}/{key}", local_file_path)
-            print("File downloaded successfully.")
-        except botocore.exceptions.ClientError as e:
-            if e.response["Error"]["Code"] == "404":
-                print("The object does not exist.")
-            else:
-                print("An error occurred while downloading the file.")
-
+        initialize_db = db(
+            bucket_name=bucket_name, sub_folder_name=folder, region_name="us-west-2"
+        )
+        initialize_db.download_file(f"{folder}/{key}", local_file_path)
     elif url_exists(url):
         try:
             urllib.urlretrieve(url, local_file_path, reporthook=reporthook)
@@ -308,10 +287,6 @@ def is_full_url(file_path):
     return re.match(url_regex, file_path) is not None


-def is_s3_url(file_path):
-    return file_path.find("s3://") != -1
-
-
 def is_remote_path(file_path):
     """
     @param file_path: str