Skip to content

Commit

Permalink
add train type endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Robin Klaassen committed May 17, 2024
1 parent 055266b commit e315c63
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
18 changes: 18 additions & 0 deletions aid/provide/ns_trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ class TrainRecord(BaseModel):
type: str


class TrainType(BaseModel):
id: int
type: str


class NSTrainProvider(BaseProvider):
def get_trains(self, start: datetime, end: datetime) -> list[TrainRecord]:
query = """
Expand All @@ -32,6 +37,19 @@ def get_trains(self, start: datetime, end: datetime) -> list[TrainRecord]:

return results

def get_train_types(self, start: datetime, end: datetime) -> list[TrainType]:
query = """
select id, type
from std.trains
where timestamp between %s and %s
order by id asc
"""
with self._pg_conn as conn:
with conn.cursor(row_factory=class_row(TrainType)) as cur:
results = cur.execute(query, (start, end)).fetchall()

return results

def get_current_count(self) -> int:
query = """
select count(*) as cnt
Expand Down
12 changes: 11 additions & 1 deletion aid/provide/routers/trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel

from aid.provide.dependencies import get_api_key
from aid.provide.ns_trains import TrainRecord, NSTrainProvider
from aid.provide.ns_trains import TrainRecord, NSTrainProvider, TrainType
from aid.provide.response import CSVResponse

router = APIRouter(prefix="/trains", tags=["trains"], dependencies=[Security(get_api_key)])
Expand Down Expand Up @@ -96,6 +96,16 @@ def get_pivoted_data(start: datetime | None = None, end: datetime | None = None)
return df.to_csv(index=False)


@router.get("/types")
def get_train_types(start: datetime | None = None, end: datetime | None = None) -> list[TrainType]:
"""
Get train types for the requested period.
"""
end = end or datetime.now()
start = start or end - timedelta(seconds=10)
return NSTrainProvider().get_train_types(start, end)


if __name__ == "__main__":
start = datetime.now() - timedelta(minutes=1)
print(get_pivoted_data(start=start))

0 comments on commit e315c63

Please sign in to comment.