diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index dafa512..6d1152e 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -7,6 +7,8 @@ "features": { "ghcr.io/devcontainers/features/azure-cli:1": {} }, + // uncomment the next line for LAN model access + // "runArgs": ["--network=host"], // Features to add to the dev container. More info: https://containers.dev/features. // "features": {}, @@ -15,7 +17,7 @@ // "forwardPorts": [], // Use 'postCreateCommand' to run commands after the container is created. - "postCreateCommand": "pip3 install --user -r requirements.txt && pip install --upgrade pydantic", + "postCreateCommand": "pip3 install --user -r requirements.txt", // && pip install --upgrade pydantic", // Configure tool-specific properties. "customizations": { diff --git a/.env.template b/.env.template index 5a8a7c0..80b2e9e 100644 --- a/.env.template +++ b/.env.template @@ -12,3 +12,6 @@ NACHET_MAX_CONTENT_LENGTH= NACHET_VALID_EXTENSION= NACHET_VALID_DIMENSION= DEV_USER_EMAIL= +NACHET_ENV= +NACHET_FRONTEND_PUBLIC_URL= +NACHET_FRONTEND_DEV_URL= diff --git a/app.py b/app.py index 6f70626..55d3b8c 100644 --- a/app.py +++ b/app.py @@ -88,6 +88,10 @@ class MaxContentLengthWarning(APIWarnings): PIPELINE_BLOB_NAME = os.getenv("NACHET_BLOB_PIPELINE_NAME") NACHET_DATA = os.getenv("NACHET_DATA") +ENVIRONMENT = os.getenv("NACHET_ENV") +NACHET_FRONTEND_DEV_URL = os.getenv("NACHET_FRONTEND_DEV_URL") +NACHET_FRONTEND_PUBLIC_URL = os.getenv("NACHET_FRONTEND_PUBLIC_URL") +ALLOWED_URL = NACHET_FRONTEND_DEV_URL if ENVIRONMENT == "local" else NACHET_FRONTEND_PUBLIC_URL try: VALID_EXTENSION = json.loads(os.getenv("NACHET_VALID_EXTENSION")) @@ -129,8 +133,15 @@ class MaxContentLengthWarning(APIWarnings): CACHE = {"seeds": None, "endpoints": None, "pipelines": {}, "validators": []} +cors_settings = { + "allow_origin": [ALLOWED_URL], + "allow_methods": ["GET", "POST", "OPTIONS"], + "allow_credentials": True, + "max_age": 86400 +} + app = Quart(__name__) -app = cors(app, allow_origin="*", allow_methods=["GET", "POST", "OPTIONS"]) +app = cors(app, **cors_settings) app.config["MAX_CONTENT_LENGTH"] = MAX_CONTENT_LENGTH_MEGABYTES * 1024 * 1024 @@ -191,12 +202,12 @@ async def get_user_id(): """ try: email = None - internal = True # set flag to false if developing locally if "jxVouchCookie" in request.cookies: - email = decode_vouch_cookie(request.cookies["jxVouchCookie"]) + decoded_cookie = decode_vouch_cookie(request.cookies["jxVouchCookie"]) + email = decoded_cookie["CustomClaims"]["email"] - if not internal and not email: # only allow internal requests to bypass email + if ENVIRONMENT == "local" and not email: # only allow local dev requests to bypass email data = await request.get_json() email = data.get("email") diff --git a/model/__init__.py b/model/__init__.py index 6debf77..211759a 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -6,6 +6,8 @@ request_function = { "swinv1-base-dataaugv2-1": request_inference_from_swin, "seed-detector-1": request_inference_from_seed_detector, + "swinv1-base-dataaugv2-2": request_inference_from_swin, + "seed-detector-rcnn-1": request_inference_from_seed_detector, "test": request_inference_from_test, "m-14of15seeds-6seedsmag": request_inference_from_nachet_6seeds } diff --git a/model/seed_detector.py b/model/seed_detector.py index a0504b0..b008f90 100644 --- a/model/seed_detector.py +++ b/model/seed_detector.py @@ -92,11 +92,12 @@ async def request_inference_from_seed_detector(model: namedtuple, previous_resul } body = str.encode(json.dumps(data)) - req = Request(model.endpoint, body, headers) + req = Request(model.endpoint, body, headers, method="POST") + # req = Request("http://192.168.x.x:12380/score", body, headers, method="POST") response = urlopen(req) result = response.read() - result_object = json.loads(result.decode("utf8")) + result_object = [json.loads(result.decode("utf8"))] print(json.dumps(result_object[0].get("boxes"), indent=4)) #TODO Transform into logging return { diff --git a/model/swin.py b/model/swin.py index 8058e58..8164d1a 100644 --- a/model/swin.py +++ b/model/swin.py @@ -56,7 +56,8 @@ async def request_inference_from_swin(model: namedtuple, previous_result: 'list[ model.deployment_platform: model.name } body = img - req = Request(model.endpoint, body, headers) + req = Request(model.endpoint, body, headers, method="POST") + # req = Request("http://192.168.x.x:12390/score", body, headers, method="POST") response = urlopen(req) result = response.read() results.append(json.loads(result.decode("utf8"))) diff --git a/requirements.txt b/requirements.txt index 488f660..82b218d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,16 @@ -numpy +nachet-datastore @git+https://github.com/ai-cfia/ailab-datastore.git@v1.0.0-nachet-datastore +numpy==1.26.4 azure-storage-blob azure-identity -quart +flask==3.0.3 +quart==0.19.6 quart-cors python-dotenv hypercorn -Pillow +Pillow==10.3.0 cryptography pyyaml -pydantic +pydantic==2.7.1 +pydantic-core==2.18.2 python-magic PyJWT -nachet-datastore @git+https://github.com/ai-cfia/ailab-datastore.git@v1.0.0-nachet-datastore