diff --git a/aid/provide/routers/trains.py b/aid/provide/routers/trains.py index 750ddc4..8b192e7 100644 --- a/aid/provide/routers/trains.py +++ b/aid/provide/routers/trains.py @@ -5,9 +5,10 @@ import pandas as pd from fastapi import APIRouter, Security, status from pydantic import BaseModel +from starlette.responses import Response from aid.provide.dependencies import get_api_key -from aid.provide.ns_trains import TrainRecord, NSTrainProvider, TrainType +from aid.provide.ns_trains import TrainRecord, NSTrainProvider from aid.provide.response import CSVResponse router = APIRouter(prefix="/trains", tags=["trains"], dependencies=[Security(get_api_key)]) @@ -67,7 +68,7 @@ def get_records_keyed_by_timestamp(start: datetime | None = None, end: datetime @router.get("/pivoted", response_class=CSVResponse) -def get_pivoted_data(start: datetime | None = None, end: datetime | None = None): +def get_pivoted_data(start: datetime | None = None, end: datetime | None = None) -> str: """ Get train locations pivoted for use in TouchDesigner. Start and end parameters work the same as in `/records`. @@ -75,7 +76,7 @@ def get_pivoted_data(start: datetime | None = None, end: datetime | None = None) records = _records(start, end) if not records: print("No records") - return status.HTTP_204_NO_CONTENT + return Response(status_code=status.HTTP_204_NO_CONTENT) # type: ignore df = pd.DataFrame.from_records([rec.model_dump() for rec in records]) @@ -96,14 +97,17 @@ def get_pivoted_data(start: datetime | None = None, end: datetime | None = None) return df.to_csv(index=False) -@router.get("/types") -def get_train_types(start: datetime | None = None, end: datetime | None = None) -> list[TrainType]: +@router.get("/types", response_class=CSVResponse) +def get_train_types(start: datetime | None = None, end: datetime | None = None) -> str: """ Get train types for the requested period. """ end = end or datetime.now() start = start or end - timedelta(seconds=10) - return NSTrainProvider().get_train_types(start, end) + records = NSTrainProvider().get_train_types(start, end) + + df = pd.DataFrame.from_records([rec.model_dump() for rec in records]) + return df.to_csv(index=False) if __name__ == "__main__":