Skip to content

Commit

Permalink
Addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
martinhoyer committed Dec 14, 2024
1 parent 817fb25 commit 37f0c72
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 44 deletions.
175 changes: 144 additions & 31 deletions src/tmt_web/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import platform
Expand All @@ -6,13 +7,13 @@
from typing import Annotated, Literal

from celery.result import AsyncResult
from fastapi import FastAPI, Request, status
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.params import Query
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse, RedirectResponse
from pydantic import BaseModel
from tmt import Logger
from tmt import __version__ as tmt_version
from tmt.utils import GeneralError
from tmt.utils import GeneralError, dict_to_yaml

from tmt_web import service, settings
from tmt_web.generators import html_generator
Expand All @@ -39,23 +40,59 @@ class TaskOut(BaseModel):
status_callback_url: str | None = None


class VersionInfo(BaseModel):
"""Version information model."""

api: str
python: str
tmt: str


class DependencyStatus(BaseModel):
"""Dependency status model."""

celery: str
redis: str


class SystemInfo(BaseModel):
"""System information model."""

platform: str
hostname: str
python_implementation: str


class HealthStatus(BaseModel):
"""Health check response model."""

status: str
timestamp: datetime
uptime_seconds: float
version: dict[str, str]
dependencies: dict[str, str]
system: dict[str, str]
version: VersionInfo
dependencies: DependencyStatus
system: SystemInfo


@app.exception_handler(GeneralError)
async def general_exception_handler(request: Request, exc: GeneralError):
"""Global exception handler for all tmt errors."""
logger.fail(str(exc))

# Map specific error messages to appropriate status codes
if "not found" in str(exc).lower():
status_code = status.HTTP_404_NOT_FOUND
elif any(msg in str(exc).lower() for msg in [
"must be provided together",
"missing required",
"invalid combination",
]):
status_code = status.HTTP_400_BAD_REQUEST
else:
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR

return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
status_code=status_code,
content={"detail": str(exc)},
)

Expand Down Expand Up @@ -152,13 +189,22 @@ def root(
logger.debug("Validating request parameters")
if (test_url is None and test_name is not None) or (test_url is not None and test_name is None):
logger.fail("Both test-url and test-name must be provided together")
raise GeneralError("Both test-url and test-name must be provided together")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Both test-url and test-name must be provided together",
)
if (plan_url is None and plan_name is not None) or (plan_url is not None and plan_name is None):
logger.fail("Both plan-url and plan-name must be provided together")
raise GeneralError("Both plan-url and plan-name must be provided together")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Both plan-url and plan-name must be provided together",
)
if plan_url is None and plan_name is None and test_url is None and test_name is None:
logger.fail("At least one of test or plan parameters must be provided")
raise GeneralError("At least one of test or plan parameters must be provided")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one of test or plan parameters must be provided",
)

service_args = {
"test_url": test_url,
Expand Down Expand Up @@ -194,10 +240,10 @@ def root(
return HTMLResponse(
content=html_generator.generate_status_callback(r, status_callback_url, logger),
)
if out_format == "yaml":
task_out = _to_task_out(r)
return PlainTextResponse(content=task_out.model_dump_json())
return _to_task_out(r)

# For both JSON and YAML formats, return JSON initially with appropriate callback URL
task_out = _to_task_out(r, out_format)
return JSONResponse(content=task_out.model_dump())


@app.get("/status", response_model=TaskOut)
Expand All @@ -212,12 +258,68 @@ def get_task_status(task_id: Annotated[str | None,
logger.debug(f"Getting task status for {task_id}")
if not task_id:
logger.fail("task-id is required")
raise GeneralError("task-id is required")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="task-id is required",
)

r = service.main.app.AsyncResult(task_id)

# Check for specific error conditions in the task result
if r.failed():
error_message = str(r.result)
if "not found" in error_message.lower():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=error_message,
)

return _to_task_out(r)


@app.get("/status/yaml", response_class=PlainTextResponse)
def get_task_status_yaml(task_id: Annotated[str | None,
Query(
alias="task-id",
title="Task ID",
description="ID of the task to check status for",
),
]) -> PlainTextResponse:
"""Get the status of an asynchronous task in YAML format."""
logger.debug(f"Getting YAML task status for {task_id}")
if not task_id:
logger.fail("task-id is required")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="task-id is required",
)

r = service.main.app.AsyncResult(task_id)

# Check for specific error conditions in the task result
if r.failed():
error_message = str(r.result)
if "not found" in error_message.lower():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=error_message,
)

# For YAML format, we only return the result in YAML format when the task is successful
if r.successful() and r.result:
try:
# The result might be a JSON string, so we need to parse it first
result_dict = json.loads(r.result)
return PlainTextResponse(dict_to_yaml(result_dict))
except json.JSONDecodeError:
# If it's not JSON, return the raw result
return PlainTextResponse(r.result)

# For pending or failed tasks, return a JSON response
task_out = _to_task_out(r, "yaml")
return JSONResponse(content=task_out.model_dump())


@app.get("/status/html", response_class=HTMLResponse)
def get_task_status_html(task_id: Annotated[str | None,
Query(
Expand All @@ -230,7 +332,10 @@ def get_task_status_html(task_id: Annotated[str | None,
logger.debug(f"Getting HTML task status for {task_id}")
if not task_id:
logger.fail("task-id is required")
raise GeneralError("task-id is required")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="task-id is required",
)

r = service.main.app.AsyncResult(task_id)
if r.successful() and r.result:
Expand All @@ -243,13 +348,21 @@ def get_task_status_html(task_id: Annotated[str | None,
)


def _to_task_out(r: AsyncResult) -> TaskOut: # type: ignore [type-arg]
def _to_task_out(r: AsyncResult, out_format: str = "json") -> TaskOut: # type: ignore [type-arg]
"""Convert a Celery AsyncResult to a TaskOut response model."""
# Use the appropriate status callback URL based on the requested format
status_callback_url = f"{settings.API_HOSTNAME}/status"
if out_format == "yaml":
status_callback_url += "/yaml"
elif out_format == "html":
status_callback_url += "/html"
status_callback_url += f"?task-id={r.task_id}"

return TaskOut(
id=r.task_id,
status=r.status,
result=r.traceback if r.failed() else r.result,
status_callback_url=f"{settings.API_HOSTNAME}/status?task-id={r.task_id}",
status_callback_url=status_callback_url,
)


Expand Down Expand Up @@ -286,18 +399,18 @@ def health_check() -> HealthStatus:
status="ok",
timestamp=datetime.now(UTC),
uptime_seconds=time.time() - START_TIME,
version={
"api": app.version,
"python": platform.python_version(),
"tmt": tmt_version,
},
dependencies={
"celery": celery_status,
"redis": redis_status,
},
system={
"platform": platform.platform(),
"hostname": platform.node(),
"python_implementation": platform.python_implementation(),
},
version=VersionInfo(
api=app.version,
python=platform.python_version(),
tmt=tmt_version,
),
dependencies=DependencyStatus(
celery=celery_status,
redis=redis_status,
),
system=SystemInfo(
platform=platform.platform(),
hostname=platform.node(),
python_implementation=platform.python_implementation(),
),
)
9 changes: 0 additions & 9 deletions src/tmt_web/generators/json_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@ def from_fmf_id(cls, fmf_id: Any) -> "FmfIdModel":
ref=fmf_id.ref,
)

def model_dump(self, **kwargs: Any) -> dict[str, Any]:
"""Custom serialization for FmfId."""
return {
"name": self.name,
"url": self.url,
"path": self.path,
"ref": self.ref,
}


class ObjectModel(BaseModel):
"""Common structure for both Test and Plan objects in JSON output."""
Expand Down
Loading

0 comments on commit 37f0c72

Please sign in to comment.