diff --git a/server/base/source.py b/server/base/source.py index 88a04a3..a3f5e51 100644 --- a/server/base/source.py +++ b/server/base/source.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, date, timedelta, time +from datetime import datetime, date, timedelta from sqlalchemy import select, func, and_ from sqlalchemy.dialects.postgresql import insert @@ -93,20 +93,12 @@ def search_stops(self, name=None, lat=None, lon=None, page=1, limit=4, all_sourc found = limit_hits if limit_hits else results['found'] return stations, found - def get_stop_times(self, stops_ids, line, start_time, day, offset: int | tuple[int], count=False, - limit: int | None = None, direction=1) -> list[StopTime] | list[str]: - day_start = datetime.combine(day, time(0)) + def get_stop_times(self, stops_ids, line, start_dt: datetime, offset: int | tuple[int], count=False, + limit: int | None = None, direction=1, end_dt: datetime = None) -> list[StopTime] | list[str]: if limit is None: limit = self.LIMIT - if start_time == '': - start_dt = day_start - else: - start_dt = datetime.combine(day, start_time) - - end_dt = day_start + timedelta(days=1) if direction == 1 else day_start - stops_ids = stops_ids.split(',') if count: @@ -114,14 +106,19 @@ def get_stop_times(self, stops_ids, line, start_time, day, offset: int | tuple[i else: stmt = select(StopTime) + day = start_dt.date() day_minus_one = day - timedelta(days=1) stmt = stmt.filter(StopTime.orig_dep_date.between(day_minus_one, day), StopTime.stop_id.in_(stops_ids)) if direction == 1: - stmt = stmt.filter(StopTime.sched_dep_dt >= start_dt, StopTime.sched_dep_dt < end_dt) + stmt = stmt.filter(StopTime.sched_dep_dt >= start_dt) + if end_dt: + stmt = stmt.filter(StopTime.sched_dep_dt < end_dt) else: - stmt = stmt.filter(StopTime.sched_dep_dt <= start_dt, StopTime.sched_dep_dt >= end_dt) + stmt = stmt.filter(StopTime.sched_dep_dt <= start_dt) + if end_dt: + stmt = stmt.filter(StopTime.sched_dep_dt >= end_dt) # if we are offsetting by ids of stop times (tuple[int]) if isinstance(offset, tuple): @@ -156,22 +153,14 @@ def get_stop_times(self, stops_ids, line, start_time, day, offset: int | tuple[i return stop_times - def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start_time, - offset: int | tuple[int], day, - count=False, limit: int | None = None, direction=1) \ + def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start_dt: datetime, + offset: int | tuple[int], + count=False, limit: int | None = None, direction=1, end_dt: datetime = None) \ -> list[tuple[StopTime, StopTime]] | list[str]: - day_start = datetime.combine(day, time(0)) if limit is None: limit = self.LIMIT - if start_time == '': - start_dt = day_start - else: - start_dt = datetime.combine(day, start_time) - - end_dt = day_start + timedelta(days=1) if direction == 1 else day_start - dep_stops_ids = dep_stops_ids.split(',') arr_stops_ids = arr_stops_ids.split(',') @@ -184,6 +173,7 @@ def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start else: stmt = select(d_stop_times, a_stop_times) + day = start_dt.date() day_minus_one = day - timedelta(days=1) stmt = stmt \ @@ -197,9 +187,13 @@ def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start d_stop_times.sched_dep_dt < a_stop_times.sched_arr_dt) if direction == 1: - stmt = stmt.filter(d_stop_times.sched_dep_dt >= start_dt, d_stop_times.sched_dep_dt < end_dt) + stmt = stmt.filter(d_stop_times.sched_dep_dt >= start_dt) + if end_dt: + stmt = stmt.filter(d_stop_times.sched_dep_dt < end_dt) else: - stmt = stmt.filter(d_stop_times.sched_dep_dt <= start_dt, d_stop_times.sched_dep_dt >= end_dt) + stmt = stmt.filter(d_stop_times.sched_dep_dt <= start_dt) + if end_dt: + stmt = stmt.filter(d_stop_times.sched_dep_dt >= end_dt) # if we are offsetting by ids of stop times (tuple[int]) if isinstance(offset, tuple): diff --git a/server/routes.py b/server/routes.py index 6c338de..a9cf19e 100644 --- a/server/routes.py +++ b/server/routes.py @@ -1,4 +1,4 @@ -from datetime import date, datetime +from datetime import datetime from sqlalchemy import text from starlette.requests import Request @@ -47,12 +47,16 @@ async def get_stop_times(request: Request) -> Response: source_name = request.query_params.get('source') if not source_name: return Response(status_code=400, content='Missing source') - day = request.query_params.get('day') - if not day: - return Response(status_code=400, content='Missing day') - start_time = request.query_params.get('start_time', '') - if start_time != '': - start_time = datetime.strptime(start_time, '%H:%M').time() + + start_dt_str = request.query_params.get('start_dt') + if not start_dt_str: + return Response(status_code=400, content='Missing start_dt') + start_dt = datetime.fromisoformat(start_dt_str) + + end_dt_str = request.query_params.get('end_dt') + end_dt = None + if end_dt_str: + end_dt = datetime.fromisoformat(end_dt_str) str_offset = request.query_params.get('offset_by_ids', '') @@ -66,19 +70,18 @@ async def get_stop_times(request: Request) -> Response: if limit > 15: limit = 15 - day = date.fromisoformat(day) - source: Source = sources[source_name] if arr_stops_ids: stop_times: list[tuple[StopTime, StopTime]] = source.get_stop_times_between_stops(dep_stops_ids, arr_stops_ids, - '', start_time, - offset, day, limit=limit, - direction=direction) + '', start_dt, + offset, limit=limit, + direction=direction, + end_dt=end_dt) return JSONResponse([[stop_time[0].as_dict(), stop_time[1].as_dict()] for stop_time in stop_times]) else: - stop_times: list[StopTime] = source.get_stop_times(dep_stops_ids, '', start_time, day, offset, limit=limit, - direction=direction) + stop_times: list[StopTime] = source.get_stop_times(dep_stops_ids, '', start_dt, offset, limit=limit, + direction=direction, end_dt=end_dt) return JSONResponse([[stop_time.as_dict()] for stop_time in stop_times]) diff --git a/tgbot/stop_times_filter.py b/tgbot/stop_times_filter.py index add9b9b..57d24f0 100644 --- a/tgbot/stop_times_filter.py +++ b/tgbot/stop_times_filter.py @@ -80,18 +80,26 @@ def get_times(self, db_file: Source) -> list[Liner]: start_time = self.start_time - if start_time != '' and self.first_time: - start_time = datetime.combine(self.day, start_time) - timedelta(minutes=self.source.MINUTES_TOLERANCE) - start_time = start_time.time() + if start_time == '': + start_dt = datetime.combine(self.day, time()) + else: + start_dt = datetime.combine(self.day, start_time) + + if self.first_time: + start_dt -= timedelta(minutes=self.source.MINUTES_TOLERANCE) + if start_dt.day < self.day.day: + start_dt = datetime.combine(self.day, time()) + + end_dt = datetime.combine(self.day + timedelta(days=1), time()) if self.arr_stop_ids: arr_stop = Station(name=self.arr_cluster_name, ids=self.arr_stop_ids) stop_times_tuples: list[tuple[StopTime, StopTime]] = db_file.get_stop_times_between_stops(dep_stop.ids, arr_stop.ids, self.line, - start_time, + start_dt, self.offset_times, - self.day) + end_dt=end_dt) results: list[Direction] = [] for stop_time_tuple in stop_times_tuples: dep_stop_time, arr_stop_time = stop_time_tuple @@ -100,16 +108,16 @@ def get_times(self, db_file: Source) -> list[Liner]: results.append(Direction([Route(dep_named_stop_time, arr_named_stop_time)])) if self.lines is None: self.lines: list[str] = db_file.get_stop_times_between_stops(dep_stop.ids, arr_stop.ids, self.line, - start_time, self.offset_times, - self.day, count=True) + start_dt, self.offset_times, count=True, + end_dt=end_dt) return results - stop_times: list[StopTime] = db_file.get_stop_times(dep_stop.ids, self.line, start_time, self.day, - self.offset_times) + stop_times: list[StopTime] = db_file.get_stop_times(dep_stop.ids, self.line, start_dt, + self.offset_times, end_dt=end_dt) results: list[NamedStopTime] = [NamedStopTime(stop_time, self.dep_cluster_name) for stop_time in stop_times] if self.lines is None: - self.lines: list[str] = db_file.get_stop_times(dep_stop.ids, self.line, start_time, self.day, - self.offset_times, count=True) + self.lines: list[str] = db_file.get_stop_times(dep_stop.ids, self.line, start_dt, + self.offset_times, count=True, end_dt=end_dt) return results