Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
robinklaassen committed May 29, 2024
2 parents e19e890 + a0b813d commit a2b5f33
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 6 deletions.
3 changes: 2 additions & 1 deletion aid/collect/ns_trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aid.logger import logger

NS_VIRTUAL_TRAIN_URL = "https://gateway.apiportal.ns.nl/virtual-train-api/api/vehicle"
REQUEST_TIMEOUT: int = 5


class TrainModel(BaseModel):
Expand Down Expand Up @@ -52,7 +53,7 @@ def _get_trains(self) -> TrainResponse | None:
response = requests.get(
NS_VIRTUAL_TRAIN_URL,
headers={"Ocp-Apim-Subscription-Key": self._api_key},
timeout=10,
timeout=REQUEST_TIMEOUT,
)

if not response.ok:
Expand Down
21 changes: 20 additions & 1 deletion aid/provide/ns_trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@ class TrainRecord(BaseModel):
speed: float
direction: float
accuracy: float
type: str


class TrainType(BaseModel):
id: int
type: str


class NSTrainProvider(BaseProvider):
def get_trains(self, start: datetime, end: datetime) -> list[TrainRecord]:
query = """
select timestamp, id, round(x) as x, round(y) as y, speed, direction, accuracy
select timestamp, id, round(x) as x, round(y) as y, speed, direction, accuracy, type
from std.trains
where timestamp between %s and %s
order by timestamp asc, id asc
Expand All @@ -31,6 +37,19 @@ def get_trains(self, start: datetime, end: datetime) -> list[TrainRecord]:

return results

def get_train_types(self, start: datetime, end: datetime) -> list[TrainType]:
query = """
select distinct id, type
from std.trains
where timestamp between %s and %s
order by id asc
"""
with self._pg_conn as conn:
with conn.cursor(row_factory=class_row(TrainType)) as cur:
results = cur.execute(query, (start, end)).fetchall()

return results

def get_current_count(self) -> int:
query = """
select count(*) as cnt
Expand Down
20 changes: 17 additions & 3 deletions aid/provide/routers/trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
from fastapi import APIRouter, Security, status
from pydantic import BaseModel
from starlette.responses import Response

from aid.provide.dependencies import get_api_key
from aid.provide.ns_trains import TrainRecord, NSTrainProvider
Expand Down Expand Up @@ -67,15 +68,15 @@ def get_records_keyed_by_timestamp(start: datetime | None = None, end: datetime


@router.get("/pivoted", response_class=CSVResponse)
def get_pivoted_data(start: datetime | None = None, end: datetime | None = None):
def get_pivoted_data(start: datetime | None = None, end: datetime | None = None) -> str:
"""
Get train locations pivoted for use in TouchDesigner.
Start and end parameters work the same as in `/records`.
"""
records = _records(start, end)
if not records:
print("No records")
return status.HTTP_204_NO_CONTENT
return Response(status_code=status.HTTP_204_NO_CONTENT) # type: ignore

df = pd.DataFrame.from_records([rec.model_dump() for rec in records])

Expand All @@ -88,14 +89,27 @@ def get_pivoted_data(start: datetime | None = None, end: datetime | None = None)
df = df.round(5)

# Pivot to requested format
df = df.melt(id_vars=["timestamp", "id"], value_vars=["x", "y", "speed"], var_name="var")
df = df.melt(id_vars=["timestamp", "id"], value_vars=["x", "y", "speed", "type"], var_name="var")
df = df.pivot(columns="id", index=["timestamp", "var"], values="value")
df = df.reset_index()
df["timestamp"] = df["timestamp"].dt.strftime("%H:%M:%S")

return df.to_csv(index=False)


@router.get("/types", response_class=CSVResponse)
def get_train_types(start: datetime | None = None, end: datetime | None = None) -> str:
"""
Get train types for the requested period.
"""
end = end or datetime.now()
start = start or end - timedelta(seconds=10)
records = NSTrainProvider().get_train_types(start, end)

df = pd.DataFrame.from_records([rec.model_dump() for rec in records])
return df.to_csv(index=False)


if __name__ == "__main__":
start = datetime.now() - timedelta(minutes=1)
print(get_pivoted_data(start=start))
2 changes: 1 addition & 1 deletion ddl/raw/ns_trains.sql
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ create table if not exists ns_trains (

select addgeometrycolumn('ns_trains', 'location', 4326, 'POINT', 2);

select create_hypertable('ns_trains', 'timestamp');
select create_hypertable('ns_trains', by_range('timestamp', INTERVAL '1 day'));
create index on ns_trains(rit_id, timestamp desc);

0 comments on commit a2b5f33

Please sign in to comment.