diff --git a/MuoVErsi/handlers.py b/MuoVErsi/handlers.py index f076a6b..15a8a16 100644 --- a/MuoVErsi/handlers.py +++ b/MuoVErsi/handlers.py @@ -250,8 +250,13 @@ async def send_stop_times(_, lang, db_file: Source, stop_times_filter: StopTimes if context.user_data.get('day') != stop_times_filter.day.isoformat(): context.user_data['day'] = stop_times_filter.day.isoformat() + # add service_ids to Source instance, this way it can be accessed from get_stop_times + db_file.service_ids = context.bot_data.setdefault('service_ids', {}).setdefault(db_file.name, {}) + results = stop_times_filter.get_times(db_file) + context.bot_data['service_ids'][db_file.name] = db_file.service_ids + context.user_data['lines'] = stop_times_filter.lines text, reply_markup = stop_times_filter.format_times_text(results, _, lang) diff --git a/MuoVErsi/sources/GTFS/source.py b/MuoVErsi/sources/GTFS/source.py index c386eed..2d46f60 100644 --- a/MuoVErsi/sources/GTFS/source.py +++ b/MuoVErsi/sources/GTFS/source.py @@ -45,6 +45,7 @@ def __init__(self, transport_type, emoji, session, typesense, gtfs_version=None, super().__init__(transport_type[:3], emoji, session, typesense) self.transport_type = transport_type self.location = location + self.service_ids = {} if gtfs_version: self.gtfs_version = gtfs_version @@ -172,7 +173,7 @@ def upload_stops_clusters_to_db(self, force=False) -> bool: return True def get_stop_times(self, stop: Station, line, start_time, day, - offset_times, context: ContextTypes.DEFAULT_TYPE | None = None, count=False): + offset_times, count=False): cur = self.con.cursor() route_name, route_id = line.split('-') if '-' in line else (line, '') @@ -183,7 +184,7 @@ def get_stop_times(self, stop: Station, line, start_time, day, line = route_id route = 'AND r.route_id = ?' - today_service_ids = self.get_active_service_ids(day, context) + today_service_ids = self.get_active_service_ids(day) stop_ids = stop.ids.split(',') stop_ids = list(map(int, stop_ids)) @@ -198,7 +199,7 @@ def get_stop_times(self, stop: Station, line, start_time, day, or_other_service = '' yesterday_service_ids = [] if start_dt.hour < 6: - yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1), context) + yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1)) if yesterday_service_ids: or_other_service_ids = ','.join(['?'] * len(yesterday_service_ids)) or_other_service = f'OR (dep.departure_time >= ? AND t.service_id in ({or_other_service_ids}))' @@ -289,7 +290,7 @@ def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, lin line = route_id route = 'AND r.route_id = ?' - today_service_ids = self.get_active_service_ids(day, context) + today_service_ids = self.get_active_service_ids(day) dep_stop_ids = dep_stop.ids.split(',') dep_stop_ids = list(map(int, dep_stop_ids)) @@ -306,7 +307,7 @@ def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, lin or_other_service = '' yesterday_service_ids = [] if start_dt.hour < 6: - yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1), context) + yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1)) if yesterday_service_ids: or_other_service_ids = ','.join(['?'] * len(yesterday_service_ids)) or_other_service = f'OR (dep.departure_time >= ? AND t.service_id in ({or_other_service_ids}))' @@ -411,7 +412,7 @@ def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, lin def search_lines(self, name, context: ContextTypes.DEFAULT_TYPE | None = None): today = date.today() - service_ids = self.get_active_service_ids(today, context) + service_ids = self.get_active_service_ids(today) cur = self.con.cursor() query = """SELECT trips.trip_id, route_short_name, route_long_name, routes.route_id @@ -426,16 +427,14 @@ def search_lines(self, name, context: ContextTypes.DEFAULT_TYPE | None = None): results = cur.execute(query, (name, *service_ids)).fetchall() return results - def get_active_service_ids(self, day: date, context: ContextTypes.DEFAULT_TYPE | None = None) -> tuple: + def get_active_service_ids(self, day: date) -> tuple: today_ymd = day.strftime('%Y%m%d') - if context: - # access safely context.bot_data['service_ids'][self.name][today_ymd] - service_ids = context.bot_data.setdefault('service_ids', {}).setdefault(self.name, {}).setdefault(today_ymd, - None) - if service_ids: - logger.info(f'Using cached service_ids for {today_ymd}') - return service_ids + # access safely context.bot_data['service_ids'][self.name][today_ymd] + service_ids = self.service_ids.setdefault(today_ymd, None) + if service_ids: + logger.info(f'Using cached service_ids for {today_ymd}') + return service_ids weekday = day.strftime('%A').lower() @@ -461,9 +460,8 @@ def get_active_service_ids(self, day: date, context: ContextTypes.DEFAULT_TYPE | service_ids = tuple(service_ids) - if context: - context.bot_data.setdefault('service_ids', {}).setdefault(self.name, {})[today_ymd] = service_ids - logger.info(f'Cached service_ids for {today_ymd}') + self.service_ids[today_ymd] = service_ids + logger.info(f'Cached service_ids for {today_ymd}') return service_ids diff --git a/MuoVErsi/sources/base.py b/MuoVErsi/sources/base.py index 4446c98..b6c7ba9 100644 --- a/MuoVErsi/sources/base.py +++ b/MuoVErsi/sources/base.py @@ -221,7 +221,7 @@ def search_stops(self, name=None, lat=None, lon=None, page=1, limit=4, all_sourc return stations, found def get_stop_times(self, stop: Station, line, start_time, day, - offset_times, context: ContextTypes.DEFAULT_TYPE | None = None, count=False): + offset_times, count=False): raise NotImplementedError def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, line, start_time, diff --git a/MuoVErsi/sources/trenitalia.py b/MuoVErsi/sources/trenitalia.py index ec974ec..9204827 100644 --- a/MuoVErsi/sources/trenitalia.py +++ b/MuoVErsi/sources/trenitalia.py @@ -138,7 +138,7 @@ def file_path(self): return os.path.join(parent_dir, 'trenitalia.db') def get_stop_times(self, stop: Station, line, start_time, day, - offset_times, context: ContextTypes.DEFAULT_TYPE | None = None, count=False): + offset_times, count=False): day_start = datetime.combine(day, time(0)) if start_time == '': diff --git a/MuoVErsi/stop_times_filter.py b/MuoVErsi/stop_times_filter.py index bbf78b1..3ba00b3 100644 --- a/MuoVErsi/stop_times_filter.py +++ b/MuoVErsi/stop_times_filter.py @@ -90,11 +90,10 @@ def get_times(self, db_file: Source) -> list[Liner]: context=self.context, count=True) return results - results = db_file.get_stop_times(dep_stop, line, start_time, day, self.offset_times, context=self.context) + results = db_file.get_stop_times(dep_stop, line, start_time, day, self.offset_times) if self.lines is None: - self.lines = db_file.get_stop_times(dep_stop, line, start_time, day, self.offset_times, - context=self.context, count=True) + self.lines = db_file.get_stop_times(dep_stop, line, start_time, day, self.offset_times, count=True) return results