diff --git a/server/base/source.py b/server/base/source.py index 783daca..88a04a3 100644 --- a/server/base/source.py +++ b/server/base/source.py @@ -4,7 +4,6 @@ from sqlalchemy import select, func, and_ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import aliased -from telegram.ext import ContextTypes from tgbot.formatting import Liner from .models import Station, Stop, StopTime @@ -94,8 +93,8 @@ def search_stops(self, name=None, lat=None, lon=None, page=1, limit=4, all_sourc found = limit_hits if limit_hits else results['found'] return stations, found - def get_stop_times(self, stops_ids, line, start_time, day, - offset_times, count=False, limit: int | None = None) -> list[StopTime] | list[str]: + def get_stop_times(self, stops_ids, line, start_time, day, offset: int | tuple[int], count=False, + limit: int | None = None, direction=1) -> list[StopTime] | list[str]: day_start = datetime.combine(day, time(0)) if limit is None: @@ -104,9 +103,9 @@ def get_stop_times(self, stops_ids, line, start_time, day, if start_time == '': start_dt = day_start else: - start_dt = datetime.combine(day, start_time) - timedelta(minutes=self.MINUTES_TOLERANCE) + start_dt = datetime.combine(day, start_time) - end_dt = day_start + timedelta(days=1) + end_dt = day_start + timedelta(days=1) if direction == 1 else day_start stops_ids = stops_ids.split(',') @@ -117,15 +116,16 @@ def get_stop_times(self, stops_ids, line, start_time, day, day_minus_one = day - timedelta(days=1) - stmt = stmt \ - .filter( - and_( - StopTime.orig_dep_date.between(day_minus_one, day), - StopTime.stop_id.in_(stops_ids), - StopTime.sched_dep_dt >= start_dt, - StopTime.sched_dep_dt < end_dt - ) - ) + stmt = stmt.filter(StopTime.orig_dep_date.between(day_minus_one, day), StopTime.stop_id.in_(stops_ids)) + + if direction == 1: + stmt = stmt.filter(StopTime.sched_dep_dt >= start_dt, StopTime.sched_dep_dt < end_dt) + else: + stmt = stmt.filter(StopTime.sched_dep_dt <= start_dt, StopTime.sched_dep_dt >= end_dt) + + # if we are offsetting by ids of stop times (tuple[int]) + if isinstance(offset, tuple): + stmt = stmt.filter(StopTime.id.notin_(offset)) if line != '': stmt = stmt.filter(StopTime.route_name == line) @@ -136,25 +136,41 @@ def get_stop_times(self, stops_ids, line, start_time, day, .order_by(func.count(StopTime.route_name).desc()) stop_times = self.session.execute(stmt).all() else: - stmt = stmt.order_by(StopTime.sched_dep_dt).limit(limit).offset(offset_times) + if direction == 1: + stmt = stmt.order_by(StopTime.sched_dep_dt.asc()) + else: + stmt = stmt.order_by(StopTime.sched_arr_dt.desc()) + + if isinstance(offset, int): + stmt = stmt.offset(offset) + + stmt = stmt.limit(limit) + stop_times = self.session.scalars(stmt).all() + if direction == -1: + stop_times.reverse() + if count: return [train.route_name for train in stop_times] return stop_times def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start_time, - offset_times, day, context: ContextTypes.DEFAULT_TYPE | None = None, count=False) \ + offset: int | tuple[int], day, + count=False, limit: int | None = None, direction=1) \ -> list[tuple[StopTime, StopTime]] | list[str]: day_start = datetime.combine(day, time(0)) + if limit is None: + limit = self.LIMIT + if start_time == '': start_dt = day_start else: - start_dt = datetime.combine(day, start_time) - timedelta(minutes=self.MINUTES_TOLERANCE) + start_dt = datetime.combine(day, start_time) - end_dt = day_start + timedelta(days=1) + end_dt = day_start + timedelta(days=1) if direction == 1 else day_start dep_stops_ids = dep_stops_ids.split(',') arr_stops_ids = arr_stops_ids.split(',') @@ -174,17 +190,20 @@ def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start .select_from(d_stop_times) \ .join(a_stop_times, and_(d_stop_times.number == a_stop_times.number, d_stop_times.orig_dep_date == a_stop_times.orig_dep_date, - d_stop_times.source == a_stop_times.source)) \ - .filter( - and_( - d_stop_times.orig_dep_date.between(day_minus_one, day), - d_stop_times.stop_id.in_(dep_stops_ids), - d_stop_times.sched_dep_dt >= start_dt, - d_stop_times.sched_dep_dt < end_dt, - d_stop_times.sched_dep_dt < a_stop_times.sched_arr_dt, - a_stop_times.stop_id.in_(arr_stops_ids) - ) - ) + d_stop_times.source == a_stop_times.source)) + + stmt = stmt.filter(d_stop_times.orig_dep_date.between(day_minus_one, day), + d_stop_times.stop_id.in_(dep_stops_ids), a_stop_times.stop_id.in_(arr_stops_ids), + d_stop_times.sched_dep_dt < a_stop_times.sched_arr_dt) + + if direction == 1: + stmt = stmt.filter(d_stop_times.sched_dep_dt >= start_dt, d_stop_times.sched_dep_dt < end_dt) + else: + stmt = stmt.filter(d_stop_times.sched_dep_dt <= start_dt, d_stop_times.sched_dep_dt >= end_dt) + + # if we are offsetting by ids of stop times (tuple[int]) + if isinstance(offset, tuple): + stmt = stmt.filter(d_stop_times.id.notin_(offset)) if line != '': stmt = stmt.filter(d_stop_times.route_name == line) @@ -193,12 +212,21 @@ def get_stop_times_between_stops(self, dep_stops_ids, arr_stops_ids, line, start stmt = stmt.group_by(d_stop_times.route_name).order_by( func.count(d_stop_times.route_name).desc()) else: - stmt = stmt.order_by( - d_stop_times.sched_dep_dt - ).limit(self.LIMIT).offset(offset_times) + if direction == 1: + stmt = stmt.order_by(d_stop_times.sched_dep_dt.asc()) + else: + stmt = stmt.order_by(d_stop_times.sched_arr_dt.desc()) + + if isinstance(offset, int): + stmt = stmt.offset(offset) + + stmt = stmt.limit(limit) raw_stop_times = self.session.execute(stmt).all() + if direction == -1: + raw_stop_times.reverse() + if count: return [train.route_name for train in raw_stop_times] diff --git a/server/routes.py b/server/routes.py index 38d7760..6c338de 100644 --- a/server/routes.py +++ b/server/routes.py @@ -42,24 +42,44 @@ async def get_stop_times(request: Request) -> Response: dep_stops_ids = request.query_params.get('dep_stops_ids') if not dep_stops_ids: return Response(status_code=400, content='Missing dep_stops_ids') + arr_stops_ids = request.query_params.get('arr_stops_ids') + direction = int(request.query_params.get('direction', 1)) source_name = request.query_params.get('source') if not source_name: return Response(status_code=400, content='Missing source') day = request.query_params.get('day') if not day: return Response(status_code=400, content='Missing day') - offset = int(request.query_params.get('offset', 0)) + start_time = request.query_params.get('start_time', '') + if start_time != '': + start_time = datetime.strptime(start_time, '%H:%M').time() + + str_offset = request.query_params.get('offset_by_ids', '') + + if str_offset == '': + offset: int = 0 + else: + offset: tuple[int] = tuple(map(int, str_offset.split(','))) + limit = int(request.query_params.get('limit', 10)) - day = date.fromisoformat(day) + if limit > 15: + limit = 15 - # start time can only be either now, if today, or empty (start of the day) for next days - start_time = datetime.now().time() if day == date.today() else '' + day = date.fromisoformat(day) source: Source = sources[source_name] - stop_times: list[StopTime] = source.get_stop_times(dep_stops_ids, '', start_time, day, offset, limit=limit) - return JSONResponse([stop_time.as_dict() for stop_time in stop_times]) + if arr_stops_ids: + stop_times: list[tuple[StopTime, StopTime]] = source.get_stop_times_between_stops(dep_stops_ids, arr_stops_ids, + '', start_time, + offset, day, limit=limit, + direction=direction) + return JSONResponse([[stop_time[0].as_dict(), stop_time[1].as_dict()] for stop_time in stop_times]) + else: + stop_times: list[StopTime] = source.get_stop_times(dep_stops_ids, '', start_time, day, offset, limit=limit, + direction=direction) + return JSONResponse([[stop_time.as_dict()] for stop_time in stop_times]) diff --git a/tgbot/stop_times_filter.py b/tgbot/stop_times_filter.py index 9caae98..add9b9b 100644 --- a/tgbot/stop_times_filter.py +++ b/tgbot/stop_times_filter.py @@ -76,19 +76,22 @@ def inline_button(self, text: str, **new_params): return InlineKeyboardButton(text, callback_data=self.query_data(**new_params)) def get_times(self, db_file: Source) -> list[Liner]: - day, dep_stop_ids, line, start_time = self.day, self.dep_stop_ids, self.line, \ - self.start_time + dep_stop = Station(name=self.dep_cluster_name, ids=self.dep_stop_ids) - dep_stop = Station(name=self.dep_cluster_name, ids=dep_stop_ids) + start_time = self.start_time + + if start_time != '' and self.first_time: + start_time = datetime.combine(self.day, start_time) - timedelta(minutes=self.source.MINUTES_TOLERANCE) + start_time = start_time.time() if self.arr_stop_ids: arr_stop = Station(name=self.arr_cluster_name, ids=self.arr_stop_ids) stop_times_tuples: list[tuple[StopTime, StopTime]] = db_file.get_stop_times_between_stops(dep_stop.ids, arr_stop.ids, - line, start_time, + self.line, + start_time, self.offset_times, - day, - context=self.context) + self.day) results: list[Direction] = [] for stop_time_tuple in stop_times_tuples: dep_stop_time, arr_stop_time = stop_time_tuple @@ -96,16 +99,17 @@ def get_times(self, db_file: Source) -> list[Liner]: arr_named_stop_time = NamedStopTime(arr_stop_time, self.arr_cluster_name) results.append(Direction([Route(dep_named_stop_time, arr_named_stop_time)])) if self.lines is None: - self.lines: list[str] = db_file.get_stop_times_between_stops(dep_stop.ids, arr_stop.ids, - line, start_time, self.offset_times, day, - context=self.context, count=True) + self.lines: list[str] = db_file.get_stop_times_between_stops(dep_stop.ids, arr_stop.ids, self.line, + start_time, self.offset_times, + self.day, count=True) return results - stop_times: list[StopTime] = db_file.get_stop_times(dep_stop.ids, line, start_time, day, self.offset_times) + stop_times: list[StopTime] = db_file.get_stop_times(dep_stop.ids, self.line, start_time, self.day, + self.offset_times) results: list[NamedStopTime] = [NamedStopTime(stop_time, self.dep_cluster_name) for stop_time in stop_times] if self.lines is None: - self.lines: list[str] = db_file.get_stop_times(dep_stop.ids, line, start_time, day, self.offset_times, - count=True) + self.lines: list[str] = db_file.get_stop_times(dep_stop.ids, self.line, start_time, self.day, + self.offset_times, count=True) return results