diff --git a/cellpack/autopack/AWSHandler.py b/cellpack/autopack/AWSHandler.py index c93a6098..b4c6397e 100644 --- a/cellpack/autopack/AWSHandler.py +++ b/cellpack/autopack/AWSHandler.py @@ -1,5 +1,6 @@ import logging from pathlib import Path +from urllib.parse import parse_qs, urlparse, urlunparse import boto3 from botocore.exceptions import ClientError @@ -40,7 +41,7 @@ def _create_session(self, region_name): def get_aws_object_key(self, object_name): if self.folder_name is not None: - object_name = self.folder_name + object_name + object_name = f"{self.folder_name}/{object_name}" else: object_name = object_name return object_name @@ -76,23 +77,46 @@ def create_presigned_url(self, object_name, expiration=3600): """ object_name = self.get_aws_object_key(object_name) # Generate a presigned URL for the S3 object + # The response contains the presigned URL + # https://{self.bucket_name}.s3.{region}.amazonaws.com/{object_key} try: url = self.s3_client.generate_presigned_url( "get_object", Params={"Bucket": self.bucket_name, "Key": object_name}, ExpiresIn=expiration, ) + base_url = urlunparse(urlparse(url)._replace(query="", fragment="")) + return base_url except ClientError as e: - logging.error(e) + logging.error(f"Error generating presigned URL: {e}") return None - # The response contains the presigned URL - # https://{self.bucket_name}.s3.{region}.amazonaws.com/{object_key} - return url - def save_file(self, file_path): + def is_url_valid(self, url): + """ + Validate the url's scheme, bucket name, and query parameters, etc. + """ + parsed_url = urlparse(url) + # Check the scheme + if parsed_url.scheme != "https": + return False + # Check the bucket name + if not parsed_url.path.startswith(f"/{self.bucket_name}/"): + return False + # Check unwanted query parameters + unwanted_query_params = ["AWSAccessKeyId", "Signature", "Expires"] + if parsed_url.query: + query_params = parse_qs(parsed_url.query) + for param in unwanted_query_params: + if param in query_params: + return False + return True + + def save_file_and_get_url(self, file_path): """ - Uploads a file to S3 and returns the presigned url + Uploads a file to S3 and returns the base url """ file_name = self.upload_file(file_path) - if file_name: - return file_name, self.create_presigned_url(file_name) + base_url = self.create_presigned_url(file_name) + if file_name and base_url: + if self.is_url_valid(base_url): + return file_name, base_url diff --git a/cellpack/autopack/DBRecipeHandler.py b/cellpack/autopack/DBRecipeHandler.py index af4e97e5..d1018203 100644 --- a/cellpack/autopack/DBRecipeHandler.py +++ b/cellpack/autopack/DBRecipeHandler.py @@ -572,6 +572,7 @@ def upload_recipe(self, recipe_meta_data, recipe_data): print(f"{recipe_id} is already in firestore") return recipe_to_save = self.upload_collections(recipe_meta_data, recipe_data) + recipe_to_save["recipe_path"] = self.db.create_path("recipes", recipe_id) self.upload_data("recipes", recipe_to_save, recipe_id) def upload_result_metadata(self, file_name, url): @@ -584,7 +585,7 @@ def upload_result_metadata(self, file_name, url): self.db.update_or_create( "results", file_name, - {"user": username, "timestamp": timestamp, "url": url.split("?")[0]}, + {"user": username, "timestamp": timestamp, "url": url}, ) @@ -630,6 +631,18 @@ def prep_db_doc_for_download(self, db_doc): def collect_docs_by_id(self, collection, id): return self.db.get_doc_by_id(collection, id) + def validate_input_recipe_path(self, path): + """ + Validates if the input path corresponds to a recipe path in the database. + Format of a recipe path: firebase:recipes/[RECIPE-ID] + """ + collection, id = self.db.get_collection_id_from_path(path) + recipe_path = self.db.get_value(collection, id, "recipe_path") + if not recipe_path: + raise ValueError( + f"No recipe found at the input path: '{path}'. Please ensure the recipe exists in the database and is spelled correctly. Expected path format: 'firebase:recipes/[RECIPE-ID]'" + ) + @staticmethod def _get_grad_and_obj(obj_data, obj_dict, grad_dict): """ diff --git a/cellpack/autopack/FirebaseHandler.py b/cellpack/autopack/FirebaseHandler.py index 69e1f0fe..96b91532 100644 --- a/cellpack/autopack/FirebaseHandler.py +++ b/cellpack/autopack/FirebaseHandler.py @@ -5,6 +5,9 @@ from dotenv import load_dotenv from google.cloud.exceptions import NotFound from cellpack.autopack.loaders.utils import read_json_file, write_json_file +from cellpack.autopack.interface_objects.default_values import ( + default_firebase_collection_names, +) class FirebaseHandler(object): @@ -65,10 +68,18 @@ def get_path_from_ref(doc): @staticmethod def get_collection_id_from_path(path): - # path example = firebase:composition/uid_1 - components = path.split(":")[1].split("/") - collection = components[0] - id = components[1] + try: + components = path.split(":")[1].split("/") + collection = components[0] + id = components[1] + if collection not in default_firebase_collection_names: + raise ValueError( + f"Invalid collection name: '{collection}'. Choose from: {default_firebase_collection_names}" + ) + except IndexError: + raise ValueError( + "Invalid path provided. Expected format: firebase:collection/id" + ) return collection, id # Create methods @@ -141,6 +152,12 @@ def get_doc_by_ref(self, path): collection, id = FirebaseHandler.get_collection_id_from_path(path) return self.get_doc_by_id(collection, id) + def get_value(self, collection, id, field): + doc, _ = self.get_doc_by_id(collection, id) + if doc is None: + return None + return doc[field] + # Update methods def update_doc(self, collection, id, data): doc_ref = self.db.collection(collection).document(id) diff --git a/cellpack/autopack/__init__.py b/cellpack/autopack/__init__.py index 18c7667b..2e5e44af 100755 --- a/cellpack/autopack/__init__.py +++ b/cellpack/autopack/__init__.py @@ -387,6 +387,7 @@ def load_file(filename, destination="", cache="geometries", force=None): if database_name == "firebase": db = DATABASE_IDS.handlers().get(database_name) db_handler = DBRecipeLoader(db) + db_handler.validate_input_recipe_path(filename) recipe_id = file_path.split("/")[-1] db_doc, _ = db_handler.collect_docs_by_id( collection="recipes", id=recipe_id diff --git a/cellpack/autopack/interface_objects/default_values.py b/cellpack/autopack/interface_objects/default_values.py index 18d82716..0e7d9d98 100644 --- a/cellpack/autopack/interface_objects/default_values.py +++ b/cellpack/autopack/interface_objects/default_values.py @@ -13,3 +13,11 @@ "mode_settings": {}, "weight_mode_settings": {}, } + +default_firebase_collection_names = [ + "composition", + "objects", + "gradients", + "recipes", + "results", +] diff --git a/cellpack/autopack/loaders/recipe_loader.py b/cellpack/autopack/loaders/recipe_loader.py index c8677ba0..fd21717e 100644 --- a/cellpack/autopack/loaders/recipe_loader.py +++ b/cellpack/autopack/loaders/recipe_loader.py @@ -196,7 +196,7 @@ def _read(self, resolve_inheritance=True): atomic=reps.get("atomic", None), packing=reps.get("packing", None), ) - # the key "all_partners" exists in obj["partners"] if the recipe is downloaded from a remote db + # the key "all_partners" already exists in obj["partners"] if the recipe is downloaded from firebase partner_settings = ( [] if ( diff --git a/cellpack/autopack/upy/simularium/simularium_helper.py b/cellpack/autopack/upy/simularium/simularium_helper.py index 031dccd8..1caa908e 100644 --- a/cellpack/autopack/upy/simularium/simularium_helper.py +++ b/cellpack/autopack/upy/simularium/simularium_helper.py @@ -1413,10 +1413,10 @@ def store_result_file(file_path, storage=None): handler = DATABASE_IDS.handlers().get(storage) initialized_handler = handler( bucket_name="cellpack-results", - sub_folder_name="simularium/", + sub_folder_name="simularium", region_name="us-west-2", ) - file_name, url = initialized_handler.save_file(file_path) + file_name, url = initialized_handler.save_file_and_get_url(file_path) simulariumHelper.store_metadata(file_name, url, db="firebase") return file_name, url diff --git a/cellpack/tests/test_aws_handler.py b/cellpack/tests/test_aws_handler.py new file mode 100644 index 00000000..6aaad87c --- /dev/null +++ b/cellpack/tests/test_aws_handler.py @@ -0,0 +1,107 @@ +import boto3 +from unittest.mock import patch +from moto import mock_aws +from cellpack.autopack.AWSHandler import AWSHandler + + +@patch("cellpack.autopack.AWSHandler.boto3.client") +def test_create_session(mock_client): + with mock_aws(): + aws_handler = AWSHandler( + bucket_name="test_bucket", + sub_folder_name="test_folder", + region_name="us-west-2", + ) + assert aws_handler.s3_client is not None + mock_client.assert_called_once_with( + "s3", + endpoint_url="https://s3.us-west-2.amazonaws.com", + region_name="us-west-2", + ) + + +def test_get_aws_object_key(): + with mock_aws(): + aws_handler = AWSHandler( + bucket_name="test_bucket", + sub_folder_name="test_folder", + region_name="us-west-2", + ) + object_key = aws_handler.get_aws_object_key("test_file") + assert object_key == "test_folder/test_file" + + +def test_upload_file(): + with mock_aws(): + aws_handler = AWSHandler( + bucket_name="test_bucket", + sub_folder_name="test_folder", + region_name="us-west-2", + ) + s3 = boto3.client("s3", region_name="us-west-2") + s3.create_bucket( + Bucket="test_bucket", + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + with open("test_file.txt", "w") as file: + file.write("test file") + file_name = aws_handler.upload_file("test_file.txt") + assert file_name == "test_file.txt" + + +def test_create_presigned_url(): + with mock_aws(), patch.object(AWSHandler, "_s3_client") as mock_client: + presigned_url = "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt?query=string" + mock_client.generate_presigned_url.return_value = presigned_url + aws_handler = AWSHandler( + bucket_name="test_bucket", + sub_folder_name="test_folder", + region_name="us-west-2", + ) + s3 = boto3.client("s3", region_name="us-west-2") + s3.create_bucket( + Bucket="test_bucket", + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + with open("test_file.txt", "w") as file: + file.write("test file") + aws_handler.upload_file("test_file.txt") + url = aws_handler.create_presigned_url("test_file.txt") + assert url is not None + assert url.startswith( + "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt" + ) + + +def test_is_url_valid(): + with mock_aws(), patch.object(AWSHandler, "_s3_client") as mock_client: + presigned_url = "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt?query=string" + mock_client.generate_presigned_url.return_value = presigned_url + aws_handler = AWSHandler( + bucket_name="test_bucket", + sub_folder_name="test_folder", + region_name="us-west-2", + ) + s3 = boto3.client("s3", region_name="us-west-2") + s3.create_bucket( + Bucket="test_bucket", + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + with open("test_file.txt", "w") as file: + file.write("test file") + aws_handler.upload_file("test_file.txt") + url = aws_handler.create_presigned_url("test_file.txt") + assert aws_handler.is_url_valid(url) is True + assert aws_handler.is_url_valid("invalid_url") is False + assert ( + aws_handler.is_url_valid( + "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt" + ) + is True + ) + assert ( + aws_handler.is_url_valid( + "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt?AWSAccessKeyId=1234" + ) + is False + ) diff --git a/setup.py b/setup.py index 8726cc88..cbe54590 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,7 @@ "trimesh>=3.9.34", "deepdiff>=5.5.0", "python-dotenv>=1.0.0", + "moto>=5.0.2", ] extra_requirements = {