From c7f54b3e7ddb4bd61e386a07b32af725d4be415f Mon Sep 17 00:00:00 2001 From: Giacomo Sarrocco Date: Wed, 11 Oct 2023 00:30:55 +0200 Subject: [PATCH] Update show line to new system --- MuoVErsi/handlers.py | 6 +++--- MuoVErsi/sources/GTFS/source.py | 26 +++++++++++++------------- MuoVErsi/sources/base.py | 5 +++-- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/MuoVErsi/handlers.py b/MuoVErsi/handlers.py index df18561..c3d10f4 100644 --- a/MuoVErsi/handlers.py +++ b/MuoVErsi/handlers.py @@ -473,12 +473,12 @@ async def search_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int _ = trans.gettext try: - lines = db_file.search_lines(update.message.text, context=context) + lines = db_file.search_lines(update.message.text) except NotImplementedError: await update.message.reply_text(_('not_implemented'), disable_notification=True) return ConversationHandler.END - keyboard = [[InlineKeyboardButton(line[2], callback_data=f'L{line[0]}/{line[1]}-{line[3]}')] for line in lines] + keyboard = [[InlineKeyboardButton(line[2], callback_data=f'L{line[0]}/{line[1]}')] for line in lines] inline_markup = InlineKeyboardMarkup(keyboard) await update.message.reply_text(_('choose_line'), reply_markup=inline_markup, disable_notification=True) @@ -504,7 +504,7 @@ async def show_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: inline_buttons = [] for stop in stops: - station = stop.station + station = stop.station.station stop_id = station.id stop_name = station.name inline_buttons.append([InlineKeyboardButton(stop_name, callback_data=f'S{stop_id}/{line}')]) diff --git a/MuoVErsi/sources/GTFS/source.py b/MuoVErsi/sources/GTFS/source.py index c7a7411..4d3682d 100644 --- a/MuoVErsi/sources/GTFS/source.py +++ b/MuoVErsi/sources/GTFS/source.py @@ -19,6 +19,8 @@ from .clustering import get_clusters_of_stops, get_loc_from_stop_and_cluster from .models import CStop +from sqlalchemy import or_, select, func + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -282,21 +284,19 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim return stop_times - def search_lines(self, name, context: ContextTypes.DEFAULT_TYPE | None = None): + def search_lines(self, name): today = date.today() - service_ids = self.get_active_service_ids(today) + from MuoVErsi.sources.base import Trip + trips = self.session.execute( + select(func.max(Trip.number), Trip.dest_text)\ + .filter(Trip.orig_dep_date == today)\ + .filter(Trip.route_name == name)\ + .group_by(Trip.dest_text)\ + .order_by(func.count(Trip.id).desc()))\ + .all() + + results = [(trip[0], name, trip[1]) for trip in trips] - cur = self.con.cursor() - query = """SELECT trips.trip_id, route_short_name, route_long_name, routes.route_id - FROM stop_times - INNER JOIN trips ON stop_times.trip_id = trips.trip_id - INNER JOIN routes ON trips.route_id = routes.route_id - WHERE route_short_name = ? - AND trips.service_id in ({seq}) - GROUP BY routes.route_id ORDER BY count(stop_times.id) DESC;""".format( - seq=','.join(['?'] * len(service_ids))) - - results = cur.execute(query, (name, *service_ids)).fetchall() return results def get_active_service_ids(self, day: date) -> tuple: diff --git a/MuoVErsi/sources/base.py b/MuoVErsi/sources/base.py index bb4dea6..0f06397 100644 --- a/MuoVErsi/sources/base.py +++ b/MuoVErsi/sources/base.py @@ -478,7 +478,7 @@ def get_stop_from_ref(self, ref) -> Station | None: else: return None - def search_lines(self, name, context: ContextTypes.DEFAULT_TYPE | None = None): + def search_lines(self, name): raise NotImplementedError def get_source_stations(self) -> list[Station]: @@ -512,6 +512,7 @@ def upload_trip_stop_time_to_postgres(self, stop_time: TripStopTime): self.session.commit() def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]: + trip_id = int(trip_id) query = select(StopTime, Trip, Stop) \ .join(StopTime.trip) \ .join(StopTime.stop) \ @@ -522,7 +523,7 @@ def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]: )) \ .order_by(StopTime.sched_dep_dt) - results = self.session.execute(query) + results = self.session.execute(query).all() stop_times = [] for result in results: