From 564df3e24b6b88e55a9e227b23054dff3d6f673b Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Mon, 14 Oct 2024 14:51:41 +0300 Subject: [PATCH] feat(fal): better endpoint error --- projects/fal/src/fal/app.py | 27 ++++++++++++++++++++------- projects/fal/tests/test_apps.py | 14 +++++++------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/projects/fal/src/fal/app.py b/projects/fal/src/fal/app.py index eaab139d..e151a837 100644 --- a/projects/fal/src/fal/app.py +++ b/projects/fal/src/fal/app.py @@ -9,6 +9,7 @@ import time import typing from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass from typing import Any, Callable, ClassVar, Literal, TypeVar import httpx @@ -17,7 +18,7 @@ import fal.api from fal._serialization import include_modules_from from fal.api import RouteSignature -from fal.exceptions import RequestCancelledException +from fal.exceptions import FalServerlessException, RequestCancelledException from fal.logging import get_logger from fal.toolkit.file import get_lifecycle_preference from fal.toolkit.file.providers.fal import GLOBAL_LIFECYCLE_PREFERENCE @@ -76,6 +77,12 @@ def initialize_and_serve(): return fn +@dataclass +class AppClientError(FalServerlessException): + message: str + status_code: int + + class EndpointClient: def __init__(self, url, endpoint, signature, timeout: int | None = None): self.url = url @@ -88,17 +95,19 @@ def __init__(self, url, endpoint, signature, timeout: int | None = None): def __call__(self, data): with httpx.Client() as client: + url = self.url + self.signature.path resp = client.post( self.url + self.signature.path, json=data.dict() if hasattr(data, "dict") else dict(data), timeout=self.timeout, ) - try: - resp.raise_for_status() - except httpx.HTTPStatusError: + if not resp.is_success: # allow logs to be printed before raising the exception time.sleep(1) - raise + raise AppClientError( + f"Failed to POST {url}: {resp.status_code} {resp.text}", + status_code=resp.status_code, + ) resp_dict = resp.json() if not self.return_type: @@ -151,12 +160,16 @@ def _print_logs(): with httpx.Client() as client: retries = 100 for _ in range(retries): - resp = client.get(info.url + "/health", timeout=60) + url = info.url + "/health" + resp = client.get(url, timeout=60) if resp.is_success: break elif resp.status_code not in (500, 404): - resp.raise_for_status() + raise AppClientError( + f"Failed to GET {url}: {resp.status_code} {resp.text}", + status_code=resp.status_code, + ) time.sleep(0.1) client = cls(app_cls, info.url) diff --git a/projects/fal/tests/test_apps.py b/projects/fal/tests/test_apps.py index 3c541f72..114bdc71 100644 --- a/projects/fal/tests/test_apps.py +++ b/projects/fal/tests/test_apps.py @@ -12,7 +12,7 @@ import httpx import pytest from fal import apps -from fal.app import AppClient +from fal.app import AppClient, AppClientError from fal.cli.deploy import _get_user from fal.container import ContainerImage from fal.exceptions import AppException, FieldException, RequestCancelledException @@ -692,7 +692,7 @@ def test_workflows(test_app: str): def test_traceback_logs(test_exception_app: AppClient): date = datetime.utcnow().isoformat() - with pytest.raises(HTTPStatusError): + with pytest.raises(AppClientError): test_exception_app.fail({}) with httpx.Client( @@ -714,17 +714,17 @@ def test_traceback_logs(test_exception_app: AppClient): def test_app_exceptions(test_exception_app: AppClient): - with pytest.raises(HTTPStatusError) as app_exc: + with pytest.raises(AppClientError) as app_exc: test_exception_app.app_exception({}) - assert app_exc.value.response.status_code == 401 + assert app_exc.status_code == 401 - with pytest.raises(HTTPStatusError) as field_exc: + with pytest.raises(AppClientError) as field_exc: test_exception_app.field_exception({"lhs": 1, "rhs": "2"}) - assert field_exc.value.response.status_code == 422 + assert field_exc.status_code == 422 - with pytest.raises(HTTPStatusError) as cuda_exc: + with pytest.raises(AppClientError) as cuda_exc: test_exception_app.cuda_exception({}) assert cuda_exc.value.response.status_code == _CUDA_OOM_STATUS_CODE