Skip to content

Commit

Permalink
fixes #51: Add inference request test
Browse files Browse the repository at this point in the history
fixes #51: inference test with Quart.test_client

fixes #51: Correct lint ruff error and tests
  • Loading branch information
MaxenceGui committed Apr 8, 2024
1 parent e581703 commit e340108
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 171 deletions.
7 changes: 0 additions & 7 deletions .env.template
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
NACHET_AZURE_STORAGE_CONNECTION_STRING=

NACHET_MODEL_ENDPOINT_REST_URL=
NACHET_MODEL_ENDPOINT_ACCESS_KEY=
NACHET_SWIN_ENDPOINT=
NACHET_SWIN_ACCESS_KEY=
NACHET_SEED_DETECTOR_ENDPOINT=
NACHET_SEED_DETECTOR_ACCESS_KEY=

NACHET_DATA=
NACHET_SUBSCRIPTION_ID=
NACHET_RESOURCE_GROUP=
Expand Down
99 changes: 49 additions & 50 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,53 +30,9 @@
PIPELINE_VERSION = os.getenv("NACHET_BLOB_PIPELINE_VERSION")
PIPELINE_BLOB_NAME = os.getenv("NACHET_BLOB_PIPELINE_NAME")

endpoint_url_regex = r"^https://.*\/score$"
endpoint_url = os.getenv("NACHET_MODEL_ENDPOINT_REST_URL")
sd_endpoint = os.getenv("NACHET_SEED_DETECTOR_ENDPOINT")
swin_endpoint = os.getenv("NACHET_SWIN_ENDPOINT")

endpoint_api_key = os.getenv("NACHET_MODEL_ENDPOINT_ACCESS_KEY")
sd_api_key = os.getenv("NACHET_SEED_DETECTOR_ACCESS_KEY")
swin_api_key = os.getenv("NACHET_SWIN_ACCESS_KEY")

NACHET_DATA = os.getenv("NACHET_DATA")
NACHET_MODEL = os.getenv("NACHET_MODEL")

# Check: do environment variables exist?
if connection_string is None:
raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if endpoint_url is None:
raise ServerError("Missing environment variable: NACHET_MODEL_ENDPOINT_REST_URL")

if sd_endpoint is None:
raise ServerError("Missing environment variable: NACHET_SEED_DETECTOR")

if swin_endpoint is None:
raise ServerError("Missing environment variable: NACHET_SWIN_ENDPOINT")

if endpoint_api_key is None:
raise ServerError("Missing environment variables: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

if sd_api_key is None:
raise ServerError("Missing environment variables: NACHET_SEED_DETECTOR_ACCESS_KEY")

if swin_api_key is None:
raise ServerError("Missing environment variables: NACHET_SWIN_ACCESS_KEY")

# Check: are environment variables correct?
if not bool(re.match(connection_string_regex, connection_string)):
raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if not bool(re.match(endpoint_url_regex, endpoint_url)):
raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

if not bool(re.match(endpoint_url_regex, sd_endpoint)):
raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

if not bool(re.match(endpoint_url_regex, swin_endpoint)):
raise ServerError("Incorrect environment variable: NACHET_MODEL_ENDPOINT_ACCESS_KEY")

Model = namedtuple(
'Model',
[
Expand All @@ -87,7 +43,8 @@
'inference_function',
'content_type',
'deployment_platform',
])
]
)

CACHE = {
"seeds": None,
Expand Down Expand Up @@ -216,7 +173,7 @@ async def inference_request():
# Validate image header
if not header.startswith("data:image/"):
return jsonify(["Invalid image header"]), 400

image_bytes = base64.b64decode(encoded_data)
container_client = await azure_storage_api.mount_container(
connection_string, container_name, create_container=True
Expand Down Expand Up @@ -245,15 +202,13 @@ async def inference_request():
cache_json_result[-1], imageDims
)

with open("inference_result.json", "w+") as f:
json.dump(processed_result_json, f, indent=4)

except urllib.error.HTTPError as error:
print(error)
return jsonify(["endpoint cannot be reached" + str(error.code)]), 400

# upload the inference results to the user's container as async task
result_json_string = json.dumps(processed_result_json)

app.add_background_task(
azure_storage_api.upload_inference_result,
container_client,
Expand Down Expand Up @@ -313,6 +268,31 @@ async def health():
return "ok", 200


@app.get("/test")
async def test():
# Build test pipeline
CACHE["endpoints"] = [
{
"pipeline_name": "test_pipeline",
"models": ["test_model1"]
}
]
# Built test model
m = Model(
model_module.request_inference_from_test,
"test_model1",
"http://localhost:8080/test_model1",
"test_api_key",
None,
"application/json",
"test_platform"
)

CACHE["pipelines"]["test_pipeline"] = (m,)

return CACHE["endpoints"], 200


async def fetch_json(repo_URL, key, file_path, mock=False):
"""
Fetches JSON document from a GitHub repository and caches it
Expand All @@ -330,6 +310,7 @@ async def fetch_json(repo_URL, key, file_path, mock=False):
except Exception as e:
raise ValueError(str(e))


async def get_pipelines(mock:bool = False):
"""
Retrieves the pipelines from the Azure storage API.
Expand Down Expand Up @@ -369,17 +350,35 @@ async def get_pipelines(mock:bool = False):

return result_json.get("pipelines")


async def data_factory(**kwargs):
return {
"input_data": kwargs,
}


@app.before_serving
async def before_serving():
try:
# Check: do environment variables exist?
if connection_string is None:
raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if FERNET_KEY is None:
raise ServerError("Missing environment variable: FERNET_KEY")

# Check: are environment variables correct?
if not bool(re.match(connection_string_regex, connection_string)):
raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

CACHE["seeds"] = await fetch_json(NACHET_DATA, "seeds", "seeds/all.json")
# CACHE["endpoints"] = await fetch_json(NACHET_MODEL, "endpoints", "model_endpoints_metadata.json")
CACHE["endpoints"] = await get_pipelines() # mock=True
CACHE["endpoints"] = await get_pipelines()

except ServerError as e:
print(e)
raise ServerError(str(e))

except Exception as e:
print(e)
raise ServerError("Failed to retrieve data from the repository")
Expand Down
46 changes: 45 additions & 1 deletion model_inference/model_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,48 @@ async def request_inference_from_nachet_6seeds(model: namedtuple, previous_resul

except Exception as e:
raise InferenceRequestError(f"An error occurred while processing the request:\n {str(e)}")


async def request_inference_from_test(model: namedtuple, previous_result: str):
"""
Requests a test case inference.
Args:
model (namedtuple): The model to use for the test inference.
previous_result (str): The previous result to pass to the model.
Returns:
dict: The result of the inference as a JSON object.
Raises:
InferenceRequestError: If an error occurs while processing the request.
"""
try:
if previous_result == '':
raise Exception("Test error")
print(f"processing test request for {model.name} with {type(previous_result)} arguments")
return [
{
"filename": "test_image.jpg",
"boxes": [
{
"box": {
"topX": 0.078,
"topY": 0.068,
"bottomX": 0.86,
"bottomY": 0.56
},
"label": "test_label",
"score": 1.0,
"topN": [
{
"label": "test_label",
"score": 1.0,
},
],
}
]
}
]

except Exception as e:
raise Exception(f"An error occurred while processing the request:\n {str(e)}")
20 changes: 12 additions & 8 deletions run_tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tests.test_azure_storage_api import TestMountContainerFunction, TestGetBlob, testGetPipeline
from tests.test_inference_request import TestInferenceRequest
from tests.test_health_request import TestQuartHealth

a = TestMountContainerFunction()
a.test_mount_existing_container()
Expand All @@ -14,11 +15,14 @@
c.test_get_pipeline_info_unsuccessful()
c.test_get_pipeline_info_successful()

d = TestInferenceRequest()
d.setUp()
d.test_inference_request_successful()
d.test_inference_request_unsuccessfull()
d.test_inference_request_missing_argument()
d.test_inference_request_wrong_pipeline_name()
d.test_inference_request_wrong_header()
d.tearDown()
d = TestQuartHealth()
d.test_health()

e = TestInferenceRequest()
e.setUp()
e.test_inference_request_successful()
e.test_inference_request_unsuccessfull()
e.test_inference_request_missing_argument()
e.test_inference_request_wrong_pipeline_name()
e.test_inference_request_wrong_header()
e.tearDown()
Binary file added tests/1310_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 18 additions & 0 deletions tests/test_health_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest
import asyncio

from app import app

class TestQuartHealth(unittest.TestCase):
def test_health(self):
test = app.test_client()

loop = asyncio.get_event_loop()
response = loop.run_until_complete(
test.get('/health')
)
print(response.status_code)
self.assertEqual(response.status_code, 200)

if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit e340108

Please sign in to comment.