Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhanced simularium url and firebase recipe path validation #230

Merged
merged 10 commits into from
Mar 21, 2024
42 changes: 33 additions & 9 deletions cellpack/autopack/AWSHandler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from pathlib import Path
from urllib.parse import parse_qs, urlparse, urlunparse

import boto3
from botocore.exceptions import ClientError
Expand Down Expand Up @@ -40,7 +41,7 @@ def _create_session(self, region_name):

def get_aws_object_key(self, object_name):
    """
    Return the S3 object key for *object_name*.

    When the handler was configured with a sub-folder, the key is
    "{folder_name}/{object_name}"; otherwise the name is used as-is.
    """
    if self.folder_name is not None:
        object_name = f"{self.folder_name}/{object_name}"
    # no else needed: without a folder the incoming name is already the key
    return object_name
Expand Down Expand Up @@ -76,23 +77,46 @@ def create_presigned_url(self, object_name, expiration=3600):
"""
object_name = self.get_aws_object_key(object_name)
# Generate a presigned URL for the S3 object
# The response contains the presigned URL
# https://{self.bucket_name}.s3.{region}.amazonaws.com/{object_key}
try:
url = self.s3_client.generate_presigned_url(
"get_object",
Params={"Bucket": self.bucket_name, "Key": object_name},
ExpiresIn=expiration,
)
base_url = urlunparse(urlparse(url)._replace(query="", fragment=""))
return base_url
except ClientError as e:
logging.error(e)
logging.error(f"Error generating presigned URL: {e}")
return None
# The response contains the presigned URL
# https://{self.bucket_name}.s3.{region}.amazonaws.com/{object_key}
return url

def save_file(self, file_path):
def is_url_valid(self, url):
    """
    Check whether *url* is an acceptable base S3 url for this handler.

    Valid means: https scheme, path rooted at this handler's bucket, and
    no leftover presigned-auth query parameters (AWSAccessKeyId,
    Signature, Expires).
    """
    parts = urlparse(url)
    uses_https = parts.scheme == "https"
    in_bucket = parts.path.startswith(f"/{self.bucket_name}/")
    if not (uses_https and in_bucket):
        return False
    # An empty query trivially has no unwanted parameters.
    query_params = parse_qs(parts.query) if parts.query else {}
    has_auth_params = any(
        key in query_params
        for key in ("AWSAccessKeyId", "Signature", "Expires")
    )
    return not has_auth_params

def save_file_and_get_url(self, file_path):
    """
    Upload a file to S3 and return its name and validated base url.

    :param file_path: local path of the file to upload
    :return: (file_name, base_url) when the upload succeeded and the
        generated url passed validation; implicitly None otherwise.

    NOTE(review): the implicit None on failure will break callers that
    tuple-unpack the result — consider returning (None, None) instead;
    verify all call sites before changing.
    """
    file_name = self.upload_file(file_path)
    base_url = self.create_presigned_url(file_name)
    # Only hand back a url that passed scheme/bucket/query validation.
    if file_name and base_url and self.is_url_valid(base_url):
        return file_name, base_url
15 changes: 14 additions & 1 deletion cellpack/autopack/DBRecipeHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ def upload_recipe(self, recipe_meta_data, recipe_data):
print(f"{recipe_id} is already in firestore")
return
recipe_to_save = self.upload_collections(recipe_meta_data, recipe_data)
recipe_to_save["recipe_path"] = self.db.create_path("recipes", recipe_id)
self.upload_data("recipes", recipe_to_save, recipe_id)

def upload_result_metadata(self, file_name, url):
Expand All @@ -584,7 +585,7 @@ def upload_result_metadata(self, file_name, url):
self.db.update_or_create(
"results",
file_name,
{"user": username, "timestamp": timestamp, "url": url.split("?")[0]},
{"user": username, "timestamp": timestamp, "url": url},
)


Expand Down Expand Up @@ -630,6 +631,18 @@ def prep_db_doc_for_download(self, db_doc):
def collect_docs_by_id(self, collection, id):
return self.db.get_doc_by_id(collection, id)

def validate_input_recipe_path(self, path):
    """
    Ensure *path* refers to a recipe that exists in the database.

    Format of a recipe path: firebase:recipes/[RECIPE-ID]. A recipe
    counts as present when its document stores a "recipe_path" value.

    :raises ValueError: when no recipe_path is stored for the document.
    """
    collection, id = self.db.get_collection_id_from_path(path)
    stored_path = self.db.get_value(collection, id, "recipe_path")
    if stored_path:
        return
    raise ValueError(
        f"No recipe found at the input path: '{path}'. Please ensure the recipe exists in the database and is spelled correctly. Expected path format: 'firebase:recipes/[RECIPE-ID]'"
    )

@staticmethod
def _get_grad_and_obj(obj_data, obj_dict, grad_dict):
"""
Expand Down
25 changes: 21 additions & 4 deletions cellpack/autopack/FirebaseHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from dotenv import load_dotenv
from google.cloud.exceptions import NotFound
from cellpack.autopack.loaders.utils import read_json_file, write_json_file
from cellpack.autopack.interface_objects.default_values import (
default_firebase_collection_names,
)


class FirebaseHandler(object):
Expand Down Expand Up @@ -65,10 +68,18 @@ def get_path_from_ref(doc):

@staticmethod
def get_collection_id_from_path(path):
    """
    Split a database path of the form firebase:[COLLECTION]/[ID] into
    (collection, id).

    :raises ValueError: when the path is malformed, or when the
        collection is not one of the known firebase collections.
    """
    try:
        # path example = firebase:composition/uid_1
        components = path.split(":")[1].split("/")
        collection = components[0]
        id = components[1]
        if collection not in default_firebase_collection_names:
            raise ValueError(
                f"Invalid collection name: '{collection}'. Choose from: {default_firebase_collection_names}"
            )
    except IndexError:
        # Missing ":" separator or missing "/id" segment.
        raise ValueError(
            "Invalid path provided. Expected format: firebase:collection/id"
        )
    return collection, id

# Create methods
Expand Down Expand Up @@ -141,6 +152,12 @@ def get_doc_by_ref(self, path):
collection, id = FirebaseHandler.get_collection_id_from_path(path)
return self.get_doc_by_id(collection, id)

def get_value(self, collection, id, field):
    """
    Fetch a single field from a document.

    :return: the field's value, or None when the document does not exist
        or does not contain the field.
    """
    doc, _ = self.get_doc_by_id(collection, id)
    if doc is None:
        return None
    # .get() instead of [] so documents written before a field existed
    # (e.g. older recipes without "recipe_path") yield None instead of
    # raising KeyError — callers here treat a falsy value as "missing".
    # Assumes doc is dict-like, as the original subscripting implied.
    return doc.get(field)

# Update methods
def update_doc(self, collection, id, data):
doc_ref = self.db.collection(collection).document(id)
Expand Down
1 change: 1 addition & 0 deletions cellpack/autopack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def load_file(filename, destination="", cache="geometries", force=None):
if database_name == "firebase":
db = DATABASE_IDS.handlers().get(database_name)
db_handler = DBRecipeLoader(db)
db_handler.validate_input_recipe_path(filename)
recipe_id = file_path.split("/")[-1]
db_doc, _ = db_handler.collect_docs_by_id(
collection="recipes", id=recipe_id
Expand Down
8 changes: 8 additions & 0 deletions cellpack/autopack/interface_objects/default_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@
"mode_settings": {},
"weight_mode_settings": {},
}

# Collection names accepted in firebase paths ("firebase:[COLLECTION]/[ID]");
# used to validate user-supplied paths before querying the database.
default_firebase_collection_names = [
    "composition",
    "objects",
    "gradients",
    "recipes",
    "results",
]
2 changes: 1 addition & 1 deletion cellpack/autopack/loaders/recipe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def _read(self, resolve_inheritance=True):
atomic=reps.get("atomic", None),
packing=reps.get("packing", None),
)
# the key "all_partners" exists in obj["partners"] if the recipe is downloaded from a remote db
# the key "all_partners" already exists in obj["partners"] if the recipe is downloaded from firebase
partner_settings = (
[]
if (
Expand Down
4 changes: 2 additions & 2 deletions cellpack/autopack/upy/simularium/simularium_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,10 +1413,10 @@ def store_result_file(file_path, storage=None):
handler = DATABASE_IDS.handlers().get(storage)
initialized_handler = handler(
bucket_name="cellpack-results",
sub_folder_name="simularium/",
sub_folder_name="simularium",
region_name="us-west-2",
)
file_name, url = initialized_handler.save_file(file_path)
file_name, url = initialized_handler.save_file_and_get_url(file_path)
simulariumHelper.store_metadata(file_name, url, db="firebase")
return file_name, url

Expand Down
107 changes: 107 additions & 0 deletions cellpack/tests/test_aws_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import boto3
from unittest.mock import patch
from moto import mock_aws
from cellpack.autopack.AWSHandler import AWSHandler


@patch("cellpack.autopack.AWSHandler.boto3.client")
def test_create_session(mock_client):
    # AWSHandler.__init__ should build an S3 client pinned to the given
    # region's endpoint; boto3.client is patched so no real AWS call is made.
    with mock_aws():
        aws_handler = AWSHandler(
            bucket_name="test_bucket",
            sub_folder_name="test_folder",
            region_name="us-west-2",
        )
        assert aws_handler.s3_client is not None
        mock_client.assert_called_once_with(
            "s3",
            endpoint_url="https://s3.us-west-2.amazonaws.com",
            region_name="us-west-2",
        )


def test_get_aws_object_key():
    # The configured sub-folder must be prepended to the object name,
    # separated by "/".
    with mock_aws():
        handler = AWSHandler(
            bucket_name="test_bucket",
            sub_folder_name="test_folder",
            region_name="us-west-2",
        )
        assert handler.get_aws_object_key("test_file") == "test_folder/test_file"


def test_upload_file():
    # Round-trip against moto: create the (mocked) bucket, write a local
    # file, and check upload_file returns the bare file name.
    # NOTE(review): writes test_file.txt into the CWD and never removes it —
    # consider pytest's tmp_path or an explicit cleanup.
    with mock_aws():
        aws_handler = AWSHandler(
            bucket_name="test_bucket",
            sub_folder_name="test_folder",
            region_name="us-west-2",
        )
        s3 = boto3.client("s3", region_name="us-west-2")
        s3.create_bucket(
            Bucket="test_bucket",
            CreateBucketConfiguration={"LocationConstraint": "us-west-2"},
        )
        with open("test_file.txt", "w") as file:
            file.write("test file")
        file_name = aws_handler.upload_file("test_file.txt")
        assert file_name == "test_file.txt"


def test_create_presigned_url():
    # The handler should strip the query string from the presigned url and
    # return only the base url.
    # NOTE(review): patch.object(AWSHandler, "_s3_client") presumes the
    # client is cached on a class attribute named _s3_client — verify
    # against AWSHandler's implementation.
    with mock_aws(), patch.object(AWSHandler, "_s3_client") as mock_client:
        presigned_url = "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt?query=string"
        mock_client.generate_presigned_url.return_value = presigned_url
        aws_handler = AWSHandler(
            bucket_name="test_bucket",
            sub_folder_name="test_folder",
            region_name="us-west-2",
        )
        s3 = boto3.client("s3", region_name="us-west-2")
        s3.create_bucket(
            Bucket="test_bucket",
            CreateBucketConfiguration={"LocationConstraint": "us-west-2"},
        )
        with open("test_file.txt", "w") as file:
            file.write("test file")
        aws_handler.upload_file("test_file.txt")
        url = aws_handler.create_presigned_url("test_file.txt")
        assert url is not None
        assert url.startswith(
            "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt"
        )


def test_is_url_valid():
    # is_url_valid should accept https path-style urls rooted at the bucket
    # (with or without benign query params) and reject anything else,
    # including urls that still carry presigned-auth parameters.
    # NOTE(review): same _s3_client patch-target assumption as
    # test_create_presigned_url — verify attribute name.
    with mock_aws(), patch.object(AWSHandler, "_s3_client") as mock_client:
        presigned_url = "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt?query=string"
        mock_client.generate_presigned_url.return_value = presigned_url
        aws_handler = AWSHandler(
            bucket_name="test_bucket",
            sub_folder_name="test_folder",
            region_name="us-west-2",
        )
        s3 = boto3.client("s3", region_name="us-west-2")
        s3.create_bucket(
            Bucket="test_bucket",
            CreateBucketConfiguration={"LocationConstraint": "us-west-2"},
        )
        with open("test_file.txt", "w") as file:
            file.write("test file")
        aws_handler.upload_file("test_file.txt")
        url = aws_handler.create_presigned_url("test_file.txt")
        assert aws_handler.is_url_valid(url) is True
        assert aws_handler.is_url_valid("invalid_url") is False
        assert (
            aws_handler.is_url_valid(
                "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt"
            )
            is True
        )
        assert (
            aws_handler.is_url_valid(
                "https://s3.us-west-2.amazonaws.com/test_bucket/test_folder/test_file.txt?AWSAccessKeyId=1234"
            )
            is False
        )
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"trimesh>=3.9.34",
"deepdiff>=5.5.0",
"python-dotenv>=1.0.0",
"moto>=5.0.2",
]

extra_requirements = {
Expand Down
Loading