Skip to content

Commit

Permalink
Make stop times functions not use day (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarrco authored Dec 9, 2023
2 parents 0a13b88 + 7b178c6 commit 4116ebf
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 51 deletions.
46 changes: 20 additions & 26 deletions server/base/source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from datetime import datetime, date, timedelta, time
from datetime import datetime, date, timedelta

from sqlalchemy import select, func, and_
from sqlalchemy.dialects.postgresql import insert
Expand Down Expand Up @@ -93,35 +93,32 @@ def search_stops(self, name=None, lat=None, lon=None, page=1, limit=4, all_sourc
found = limit_hits if limit_hits else results['found']
return stations, found

def get_stop_times(self, stops_ids, line, start_time, day, offset: int | tuple[int], count=False,
limit: int | None = None, direction=1) -> list[StopTime] | list[str]:
day_start = datetime.combine(day, time(0))
def get_stop_times(self, stops_ids, line, start_dt: datetime, offset: int | tuple[int], count=False,
limit: int | None = None, direction=1, end_dt: datetime = None) -> list[StopTime] | list[str]:

if limit is None:
limit = self.LIMIT

if start_time == '':
start_dt = day_start
else:
start_dt = datetime.combine(day, start_time)

end_dt = day_start + timedelta(days=1) if direction == 1 else day_start

stops_ids = stops_ids.split(',')

if count:
stmt = select(StopTime.route_name)
else:
stmt = select(StopTime)

day = start_dt.date()
day_minus_one = day - timedelta(days=1)

stmt = stmt.filter(StopTime.orig_dep_date.between(day_minus_one, day), StopTime.stop_id.in_(stops_ids))

if direction == 1:
stmt = stmt.filter(StopTime.sched_dep_dt >= start_dt, StopTime.sched_dep_dt < end_dt)
stmt = stmt.filter(StopTime.sched_dep_dt >= start_dt)
if end_dt:
stmt = stmt.filter(StopTime.sched_dep_dt < end_dt)
else:
stmt = stmt.filter(StopTime.sched_dep_dt <= start_dt, StopTime.sched_dep_dt >= end_dt)
stmt = stmt.filter(StopTime.sched_dep_dt <= start_dt)
if end_dt:
stmt = stmt.filter(StopTime.sched_dep_dt >= end_dt)

# if we are offsetting by ids of stop times (tuple[int])
if isinstance(offset, tuple):
Expand Down Expand Up @@ -156,22 +153,14 @@ def get_stop_times(self, stops_ids, line, start_time, day, offset: int | tuple[i

return stop_times

def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start_time,
offset: int | tuple[int], day,
count=False, limit: int | None = None, direction=1) \
def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start_dt: datetime,
offset: int | tuple[int],
count=False, limit: int | None = None, direction=1, end_dt: datetime = None) \
-> list[tuple[StopTime, StopTime]] | list[str]:
day_start = datetime.combine(day, time(0))

if limit is None:
limit = self.LIMIT

if start_time == '':
start_dt = day_start
else:
start_dt = datetime.combine(day, start_time)

end_dt = day_start + timedelta(days=1) if direction == 1 else day_start

dep_stops_ids = dep_stops_ids.split(',')
arr_stops_ids = arr_stops_ids.split(',')

Expand All @@ -184,6 +173,7 @@ def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start
else:
stmt = select(d_stop_times, a_stop_times)

day = start_dt.date()
day_minus_one = day - timedelta(days=1)

stmt = stmt \
Expand All @@ -197,9 +187,13 @@ def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start
d_stop_times.sched_dep_dt < a_stop_times.sched_arr_dt)

if direction == 1:
stmt = stmt.filter(d_stop_times.sched_dep_dt >= start_dt, d_stop_times.sched_dep_dt < end_dt)
stmt = stmt.filter(d_stop_times.sched_dep_dt >= start_dt)
if end_dt:
stmt = stmt.filter(d_stop_times.sched_dep_dt < end_dt)
else:
stmt = stmt.filter(d_stop_times.sched_dep_dt <= start_dt, d_stop_times.sched_dep_dt >= end_dt)
stmt = stmt.filter(d_stop_times.sched_dep_dt <= start_dt)
if end_dt:
stmt = stmt.filter(d_stop_times.sched_dep_dt >= end_dt)

# if we are offsetting by ids of stop times (tuple[int])
if isinstance(offset, tuple):
Expand Down
31 changes: 17 additions & 14 deletions server/routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import date, datetime
from datetime import datetime

from sqlalchemy import text
from starlette.requests import Request
Expand Down Expand Up @@ -47,12 +47,16 @@ async def get_stop_times(request: Request) -> Response:
source_name = request.query_params.get('source')
if not source_name:
return Response(status_code=400, content='Missing source')
day = request.query_params.get('day')
if not day:
return Response(status_code=400, content='Missing day')
start_time = request.query_params.get('start_time', '')
if start_time != '':
start_time = datetime.strptime(start_time, '%H:%M').time()

start_dt_str = request.query_params.get('start_dt')
if not start_dt_str:
return Response(status_code=400, content='Missing start_dt')
start_dt = datetime.fromisoformat(start_dt_str)

end_dt_str = request.query_params.get('end_dt')
end_dt = None
if end_dt_str:
end_dt = datetime.fromisoformat(end_dt_str)

str_offset = request.query_params.get('offset_by_ids', '')

Expand All @@ -66,19 +70,18 @@ async def get_stop_times(request: Request) -> Response:
if limit > 15:
limit = 15

day = date.fromisoformat(day)

source: Source = sources[source_name]

if arr_stops_ids:
stop_times: list[tuple[StopTime, StopTime]] = source.get_stop_times_between_stops(dep_stops_ids, arr_stops_ids,
'', start_time,
offset, day, limit=limit,
direction=direction)
'', start_dt,
offset, limit=limit,
direction=direction,
end_dt=end_dt)
return JSONResponse([[stop_time[0].as_dict(), stop_time[1].as_dict()] for stop_time in stop_times])
else:
stop_times: list[StopTime] = source.get_stop_times(dep_stops_ids, '', start_time, day, offset, limit=limit,
direction=direction)
stop_times: list[StopTime] = source.get_stop_times(dep_stops_ids, '', start_dt, offset, limit=limit,
direction=direction, end_dt=end_dt)
return JSONResponse([[stop_time.as_dict()] for stop_time in stop_times])


Expand Down
30 changes: 19 additions & 11 deletions tgbot/stop_times_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,26 @@ def get_times(self, db_file: Source) -> list[Liner]:

start_time = self.start_time

if start_time != '' and self.first_time:
start_time = datetime.combine(self.day, start_time) - timedelta(minutes=self.source.MINUTES_TOLERANCE)
start_time = start_time.time()
if start_time == '':
start_dt = datetime.combine(self.day, time())
else:
start_dt = datetime.combine(self.day, start_time)

if self.first_time:
start_dt -= timedelta(minutes=self.source.MINUTES_TOLERANCE)
if start_dt.day < self.day.day:
start_dt = datetime.combine(self.day, time())

end_dt = datetime.combine(self.day + timedelta(days=1), time())

if self.arr_stop_ids:
arr_stop = Station(name=self.arr_cluster_name, ids=self.arr_stop_ids)
stop_times_tuples: list[tuple[StopTime, StopTime]] = db_file.get_stop_times_between_stops(dep_stop.ids,
arr_stop.ids,
self.line,
start_time,
start_dt,
self.offset_times,
self.day)
end_dt=end_dt)
results: list[Direction] = []
for stop_time_tuple in stop_times_tuples:
dep_stop_time, arr_stop_time = stop_time_tuple
Expand All @@ -100,16 +108,16 @@ def get_times(self, db_file: Source) -> list[Liner]:
results.append(Direction([Route(dep_named_stop_time, arr_named_stop_time)]))
if self.lines is None:
self.lines: list[str] = db_file.get_stop_times_between_stops(dep_stop.ids, arr_stop.ids, self.line,
start_time, self.offset_times,
self.day, count=True)
start_dt, self.offset_times, count=True,
end_dt=end_dt)
return results

stop_times: list[StopTime] = db_file.get_stop_times(dep_stop.ids, self.line, start_time, self.day,
self.offset_times)
stop_times: list[StopTime] = db_file.get_stop_times(dep_stop.ids, self.line, start_dt,
self.offset_times, end_dt=end_dt)
results: list[NamedStopTime] = [NamedStopTime(stop_time, self.dep_cluster_name) for stop_time in stop_times]
if self.lines is None:
self.lines: list[str] = db_file.get_stop_times(dep_stop.ids, self.line, start_time, self.day,
self.offset_times, count=True)
self.lines: list[str] = db_file.get_stop_times(dep_stop.ids, self.line, start_dt,
self.offset_times, count=True, end_dt=end_dt)

return results

Expand Down

0 comments on commit 4116ebf

Please sign in to comment.