diff --git a/app.py b/app.py index 99f7b04..46a9ca9 100644 --- a/app.py +++ b/app.py @@ -185,6 +185,23 @@ async def before_serving(): raise +@app.get("/get-user-id") +async def get_user_id() : + """ + Returns the user id + """ + try: + data = await request.get_json() + email = data["email"] + + user_id = datastore.get_user_id(email) + + return jsonify(user_id), 200 + except (KeyError, TypeError, ValueError, datastore.DatastoreError) as error: + print(error) + return jsonify([f"GetUserIdError: {str(error)}"]), 400 + + @app.post("/del") async def delete_directory(): """ @@ -342,8 +359,7 @@ async def inference_request(): container_name = data["container_name"] imageDims = data["imageDims"] image_base64 = data["image"] - #user_id = data.get["user_id"] - email = "example@gmail.com" + user_id = data["userId"] area_ratio = data.get("area_ratio", 0.5) color_format = data.get("color_format", "hex") @@ -377,12 +393,10 @@ async def inference_request(): # Open db connection connection = datastore.get_connection() cursor = datastore.get_cursor(connection) - - user = await datastore.validate_user(cursor, email, CONNECTION_STRING) image_hash_value = await azure_storage.generate_hash(image_bytes) picture_id = await datastore.get_picture_id( - cursor, user.id, image_hash_value, container_client + cursor, user_id, image_hash_value, container_client ) pipeline = pipelines_endpoints.get(pipeline_name) @@ -409,7 +423,7 @@ async def inference_request(): result_json_string, image_hash_value, ) - saved_result_json = await datastore.save_inference_result(cursor, user.id, processed_result_json[0], picture_id, pipeline_name, 1) + saved_result_json = await datastore.save_inference_result(cursor, user_id, processed_result_json[0], picture_id, pipeline_name, 1) datastore.end_query(connection, cursor) diff --git a/storage/datastore_storage_api.py b/storage/datastore_storage_api.py index 24d7745..5c25e10 100644 --- a/storage/datastore_storage_api.py +++ b/storage/datastore_storage_api.py @@ -24,6 +24,9 @@ class SeedNotFoundError(DatastoreError): class GetPipelinesError(DatastoreError): pass +class UserNotFoundError(DatastoreError): + pass + def get_connection() : return db.connect_db() @@ -58,7 +61,6 @@ def get_all_seeds_names() -> list: except Exception as error: # TODO modify Exception for more specific exception raise SeedNotFoundError(error.args[0]) - def get_seeds(expression: str) -> list: """ Return a list of all seed that contains the expression @@ -67,7 +69,17 @@ def get_seeds(expression: str) -> list: cursor = get_cursor(connection) return list(filter(lambda x: expression in x, get_all_seeds_names(cursor))) - +def get_user_id(email: str) -> str: + """ + Return the user_id of the user + """ + connection = get_connection() + cursor = get_cursor(connection) + if user_datastore.is_user_registered(cursor, email): + return user_datastore.get_user_id(cursor, email) + else : + raise UserNotFoundError("User not found") + async def validate_user(cursor, email: str, connection_string) -> datastore.User: """ Return True if user is valid, False otherwise