From 4f9bdc3726876c28566593be5047ca04a3574e2d Mon Sep 17 00:00:00 2001 From: Ruge Li Date: Mon, 23 Oct 2023 12:58:42 -0700 Subject: [PATCH] refactor AWS and firebase handler --- cellpack/autopack/AWSHandler.py | 19 +++- cellpack/autopack/FirebaseHandler.py | 160 +++++++++++++++++---------- 2 files changed, 117 insertions(+), 62 deletions(-) diff --git a/cellpack/autopack/AWSHandler.py b/cellpack/autopack/AWSHandler.py index 878638bd0..c93a6098f 100644 --- a/cellpack/autopack/AWSHandler.py +++ b/cellpack/autopack/AWSHandler.py @@ -10,6 +10,10 @@ class AWSHandler(object): Handles all the AWS S3 operations """ + # class attributes + _session_created = False + _s3_client = None + def __init__( self, bucket_name, @@ -18,12 +22,21 @@ def __init__( ): self.bucket_name = bucket_name self.folder_name = sub_folder_name - session = boto3.Session() - self.s3_client = session.client( + # Create a session if one does not exist + if not AWSHandler._session_created: + self._create_session(region_name) + AWSHandler._session_created = True + else: + # use the existing session + self.s3_client = AWSHandler._s3_client + + def _create_session(self, region_name): + AWSHandler._s3_client = boto3.client( "s3", endpoint_url=f"https://s3.{region_name}.amazonaws.com", region_name=region_name, ) + self.s3_client = AWSHandler._s3_client def get_aws_object_key(self, object_name): if self.folder_name is not None: @@ -82,4 +95,4 @@ def save_file(self, file_path): """ file_name = self.upload_file(file_path) if file_name: - return self.create_presigned_url(file_name) + return file_name, self.create_presigned_url(file_name) diff --git a/cellpack/autopack/FirebaseHandler.py b/cellpack/autopack/FirebaseHandler.py index 469fd7871..8f1942388 100644 --- a/cellpack/autopack/FirebaseHandler.py +++ b/cellpack/autopack/FirebaseHandler.py @@ -1,6 +1,7 @@ -import firebase_admin import ast +import firebase_admin from firebase_admin import credentials, firestore +from google.cloud.exceptions import NotFound from cellpack.autopack.loaders.utils import read_json_file, write_json_file @@ -9,42 +10,27 @@ class FirebaseHandler(object): Retrieve data and perform common tasks when working with firebase. """ + # use class attributes to maintain a consistent state across all instances + _initialized = False + _db = None + def __init__(self): - cred_path = FirebaseHandler.get_creds() - login = credentials.Certificate(cred_path) - firebase_admin.initialize_app(login) - self.db = firestore.client() + # check if firebase is already initialized + if not FirebaseHandler._initialized: + cred_path = FirebaseHandler.get_creds() + login = credentials.Certificate(cred_path) + firebase_admin.initialize_app(login) + FirebaseHandler._initialized = True + FirebaseHandler._db = firestore.client() + + self.db = FirebaseHandler._db self.name = "firebase" + # common utility methods @staticmethod def doc_to_dict(doc): return doc.to_dict() - @staticmethod - def write_creds_path(): - path = ast.literal_eval(input("provide path to firebase credentials: ")) - data = read_json_file(path) - if data is None: - raise ValueError("The path to your credentials doesn't exist") - firebase_cred = {"firebase": data} - creds = read_json_file("./.creds") - if creds is None: - write_json_file("./.creds", firebase_cred) - else: - creds["firebase"] = data - write_json_file("./.creds", creds) - return firebase_cred - - @staticmethod - def get_creds(): - creds = read_json_file("./.creds") - if creds is None or "firebase" not in creds: - creds = FirebaseHandler.write_creds_path() - return creds["firebase"] - - def db_name(self): - return self.name - @staticmethod def doc_id(doc): return doc.id @@ -53,6 +39,10 @@ def doc_id(doc): def create_path(collection, doc_id): return f"firebase:{collection}/{doc_id}" + @staticmethod + def create_timestamp(): + return firestore.SERVER_TIMESTAMP + @staticmethod def get_path_from_ref(doc): return doc.path @@ -65,24 +55,41 @@ def get_collection_id_from_path(path): id = components[1] return collection, id - @staticmethod - def update_reference_on_doc(doc_ref, index, new_item_ref): - doc_ref.update({index: new_item_ref}) + # Create methods + def set_doc(self, collection, id, data): + doc, doc_ref = self.get_doc_by_id(collection, id) + if not doc: + doc_ref = self.db.collection(collection).document(id) + doc_ref.set(data) + print(f"successfully uploaded to path: {doc_ref.path}") + return doc_ref + else: + print( + f"ERROR: {doc_ref.path} already exists. If uploading new data, provide a unique recipe name." + ) + return + + def upload_doc(self, collection, data): + return self.db.collection(collection).add(data) + # Read methods @staticmethod - def update_elements_in_array(doc_ref, index, new_item_ref, remove_item): - doc_ref.update({index: firestore.ArrayRemove([remove_item])}) - doc_ref.update({index: firestore.ArrayUnion([new_item_ref])}) + def get_creds(): + creds = read_json_file("./.creds") + if creds is None or "firebase" not in creds: + creds = FirebaseHandler.write_creds_path() + return creds["firebase"] @staticmethod - def is_reference(path): - if not isinstance(path, str): - return False - if path is None: - return False - if path.startswith("firebase:"): - return True - return False + def get_username(): + creds = read_json_file("./.creds") + try: + return creds["username"] + except KeyError: + raise ValueError("No username found in .creds file") + + def db_name(self): + return self.name def get_doc_by_name(self, collection, name): db = self.db @@ -90,9 +97,9 @@ def get_doc_by_name(self, collection, name): docs = data_ref.where("name", "==", name).get() # docs is an array return docs - # `doc` is a DocumentSnapshot object - # `doc_ref` is a DocumentReference object to perform operations on the doc def get_doc_by_id(self, collection, id): + # `doc` is a DocumentSnapshot object + # `doc_ref` is a DocumentReference object to perform operations on the doc doc_ref = self.db.collection(collection).document(id) doc = doc_ref.get() if doc.exists: @@ -104,21 +111,56 @@ def get_doc_by_ref(self, path): collection, id = FirebaseHandler.get_collection_id_from_path(path) return self.get_doc_by_id(collection, id) - def set_doc(self, collection, id, data): - doc, doc_ref = self.get_doc_by_id(collection, id) - if not doc: - doc_ref = self.db.collection(collection).document(id) - doc_ref.set(data) - print(f"successfully uploaded to path: {doc_ref.path}") - return doc_ref + # Update methods + def update_doc(self, collection, id, data): + doc_ref = self.db.collection(collection).document(id) + doc_ref.update(data) + print(f"successfully updated to path: {doc_ref.path}") + return doc_ref + + @staticmethod + def update_reference_on_doc(doc_ref, index, new_item_ref): + doc_ref.update({index: new_item_ref}) + + @staticmethod + def update_elements_in_array(doc_ref, index, new_item_ref, remove_item): + doc_ref.update({index: firestore.ArrayRemove([remove_item])}) + doc_ref.update({index: firestore.ArrayUnion([new_item_ref])}) + + def update_or_create(self, collection, id, data): + """ + If the input id exists, update the doc. If not, create a new file. + """ + try: + self.update_doc(collection, id, data) + except NotFound: + self.set_doc(collection, id, data) + + # other utils + @staticmethod + def write_creds_path(): + path = ast.literal_eval(input("provide path to firebase credentials: ")) + data = read_json_file(path) + if data is None: + raise ValueError("The path to your credentials doesn't exist") + firebase_cred = {"firebase": data} + creds = read_json_file("./.creds") + if creds is None: + write_json_file("./.creds", firebase_cred) else: - print( - f"ERROR: {doc_ref.path} already exists. If uploading new data, provide a unique recipe name." - ) - return + creds["firebase"] = data + write_json_file("./.creds", creds) + return firebase_cred - def upload_doc(self, collection, data): - return self.db.collection(collection).add(data) + @staticmethod + def is_reference(path): + if not isinstance(path, str): + return False + if path is None: + return False + if path.startswith("firebase:"): + return True + return False @staticmethod def is_firebase_obj(obj):