Skip to content

Commit

Permalink
feat: remove context from get stop times (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarrco authored Sep 28, 2023
2 parents b10c26a + 0463720 commit fb50ea9
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 22 deletions.
5 changes: 5 additions & 0 deletions MuoVErsi/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,13 @@ async def send_stop_times(_, lang, db_file: Source, stop_times_filter: StopTimes
if context.user_data.get('day') != stop_times_filter.day.isoformat():
context.user_data['day'] = stop_times_filter.day.isoformat()

# add service_ids to Source instance, this way it can be accessed from get_stop_times
db_file.service_ids = context.bot_data.setdefault('service_ids', {}).setdefault(db_file.name, {})

results = stop_times_filter.get_times(db_file)

context.bot_data['service_ids'][db_file.name] = db_file.service_ids

context.user_data['lines'] = stop_times_filter.lines

text, reply_markup = stop_times_filter.format_times_text(results, _, lang)
Expand Down
32 changes: 15 additions & 17 deletions MuoVErsi/sources/GTFS/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, transport_type, emoji, session, typesense, gtfs_version=None,
super().__init__(transport_type[:3], emoji, session, typesense)
self.transport_type = transport_type
self.location = location
self.service_ids = {}

if gtfs_version:
self.gtfs_version = gtfs_version
Expand Down Expand Up @@ -172,7 +173,7 @@ def upload_stops_clusters_to_db(self, force=False) -> bool:
return True

def get_stop_times(self, stop: Station, line, start_time, day,
offset_times, context: ContextTypes.DEFAULT_TYPE | None = None, count=False):
offset_times, count=False):
cur = self.con.cursor()

route_name, route_id = line.split('-') if '-' in line else (line, '')
Expand All @@ -183,7 +184,7 @@ def get_stop_times(self, stop: Station, line, start_time, day,
line = route_id
route = 'AND r.route_id = ?'

today_service_ids = self.get_active_service_ids(day, context)
today_service_ids = self.get_active_service_ids(day)

stop_ids = stop.ids.split(',')
stop_ids = list(map(int, stop_ids))
Expand All @@ -198,7 +199,7 @@ def get_stop_times(self, stop: Station, line, start_time, day,
or_other_service = ''
yesterday_service_ids = []
if start_dt.hour < 6:
yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1), context)
yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1))
if yesterday_service_ids:
or_other_service_ids = ','.join(['?'] * len(yesterday_service_ids))
or_other_service = f'OR (dep.departure_time >= ? AND t.service_id in ({or_other_service_ids}))'
Expand Down Expand Up @@ -289,7 +290,7 @@ def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, lin
line = route_id
route = 'AND r.route_id = ?'

today_service_ids = self.get_active_service_ids(day, context)
today_service_ids = self.get_active_service_ids(day)
dep_stop_ids = dep_stop.ids.split(',')
dep_stop_ids = list(map(int, dep_stop_ids))

Expand All @@ -306,7 +307,7 @@ def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, lin
or_other_service = ''
yesterday_service_ids = []
if start_dt.hour < 6:
yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1), context)
yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1))
if yesterday_service_ids:
or_other_service_ids = ','.join(['?'] * len(yesterday_service_ids))
or_other_service = f'OR (dep.departure_time >= ? AND t.service_id in ({or_other_service_ids}))'
Expand Down Expand Up @@ -411,7 +412,7 @@ def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, lin

def search_lines(self, name, context: ContextTypes.DEFAULT_TYPE | None = None):
today = date.today()
service_ids = self.get_active_service_ids(today, context)
service_ids = self.get_active_service_ids(today)

cur = self.con.cursor()
query = """SELECT trips.trip_id, route_short_name, route_long_name, routes.route_id
Expand All @@ -426,16 +427,14 @@ def search_lines(self, name, context: ContextTypes.DEFAULT_TYPE | None = None):
results = cur.execute(query, (name, *service_ids)).fetchall()
return results

def get_active_service_ids(self, day: date, context: ContextTypes.DEFAULT_TYPE | None = None) -> tuple:
def get_active_service_ids(self, day: date) -> tuple:
today_ymd = day.strftime('%Y%m%d')

if context:
# access safely context.bot_data['service_ids'][self.name][today_ymd]
service_ids = context.bot_data.setdefault('service_ids', {}).setdefault(self.name, {}).setdefault(today_ymd,
None)
if service_ids:
logger.info(f'Using cached service_ids for {today_ymd}')
return service_ids
# access safely context.bot_data['service_ids'][self.name][today_ymd]
service_ids = self.service_ids.setdefault(today_ymd, None)
if service_ids:
logger.info(f'Using cached service_ids for {today_ymd}')
return service_ids

weekday = day.strftime('%A').lower()

Expand All @@ -461,9 +460,8 @@ def get_active_service_ids(self, day: date, context: ContextTypes.DEFAULT_TYPE |

service_ids = tuple(service_ids)

if context:
context.bot_data.setdefault('service_ids', {}).setdefault(self.name, {})[today_ymd] = service_ids
logger.info(f'Cached service_ids for {today_ymd}')
self.service_ids[today_ymd] = service_ids
logger.info(f'Cached service_ids for {today_ymd}')

return service_ids

Expand Down
2 changes: 1 addition & 1 deletion MuoVErsi/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def search_stops(self, name=None, lat=None, lon=None, page=1, limit=4, all_sourc
return stations, found

def get_stop_times(self, stop: Station, line, start_time, day,
offset_times, context: ContextTypes.DEFAULT_TYPE | None = None, count=False):
offset_times, count=False):
raise NotImplementedError

def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, line, start_time,
Expand Down
2 changes: 1 addition & 1 deletion MuoVErsi/sources/trenitalia.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def file_path(self):
return os.path.join(parent_dir, 'trenitalia.db')

def get_stop_times(self, stop: Station, line, start_time, day,
offset_times, context: ContextTypes.DEFAULT_TYPE | None = None, count=False):
offset_times, count=False):
day_start = datetime.combine(day, time(0))

if start_time == '':
Expand Down
5 changes: 2 additions & 3 deletions MuoVErsi/stop_times_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,10 @@ def get_times(self, db_file: Source) -> list[Liner]:
context=self.context, count=True)
return results

results = db_file.get_stop_times(dep_stop, line, start_time, day, self.offset_times, context=self.context)
results = db_file.get_stop_times(dep_stop, line, start_time, day, self.offset_times)

if self.lines is None:
self.lines = db_file.get_stop_times(dep_stop, line, start_time, day, self.offset_times,
context=self.context, count=True)
self.lines = db_file.get_stop_times(dep_stop, line, start_time, day, self.offset_times, count=True)

return results

Expand Down

0 comments on commit fb50ea9

Please sign in to comment.