From aa29657a6aaf0740e1f5096a239fd692ebf51a2d Mon Sep 17 00:00:00 2001 From: Giacomo Sarrocco Date: Sun, 29 Oct 2023 17:12:53 +0100 Subject: [PATCH 1/6] Rename "MuoVErsi" package to "server" --- alembic/env.py | 4 ++-- .../2e8b9b6298f0_create_typesense_table_for_stations.py | 2 +- alembic/versions/6c9ef3a680e3_create_stops_table.py | 2 +- run.py | 2 +- save_data.py | 8 ++++---- {MuoVErsi => server}/GTFS/__init__.py | 0 {MuoVErsi => server}/GTFS/clustering.py | 0 {MuoVErsi => server}/GTFS/models.py | 0 {MuoVErsi => server}/GTFS/source.py | 7 +++---- {MuoVErsi => server}/__init__.py | 0 {MuoVErsi => server}/base/__init__.py | 0 {MuoVErsi => server}/base/models.py | 0 {MuoVErsi => server}/base/source.py | 2 +- {MuoVErsi => server}/data/trenitalia_stations.json | 0 {MuoVErsi => server}/handlers.py | 4 ++-- {MuoVErsi => server}/helpers.py | 0 {MuoVErsi => server}/persistence.py | 0 {MuoVErsi => server}/stop_times_filter.py | 2 +- {MuoVErsi => server}/trenitalia/__init__.py | 0 {MuoVErsi => server}/trenitalia/source.py | 4 ++-- {MuoVErsi => server}/typesense.py | 0 tests/test_db.py | 3 ++- tests/test_gtfs_clustering.py | 2 +- 23 files changed, 21 insertions(+), 21 deletions(-) rename {MuoVErsi => server}/GTFS/__init__.py (100%) rename {MuoVErsi => server}/GTFS/clustering.py (100%) rename {MuoVErsi => server}/GTFS/models.py (100%) rename {MuoVErsi => server}/GTFS/source.py (99%) rename {MuoVErsi => server}/__init__.py (100%) rename {MuoVErsi => server}/base/__init__.py (100%) rename {MuoVErsi => server}/base/models.py (100%) rename {MuoVErsi => server}/base/source.py (99%) rename {MuoVErsi => server}/data/trenitalia_stations.json (100%) rename {MuoVErsi => server}/handlers.py (100%) rename {MuoVErsi => server}/helpers.py (100%) rename {MuoVErsi => server}/persistence.py (100%) rename {MuoVErsi => server}/stop_times_filter.py (99%) rename {MuoVErsi => server}/trenitalia/__init__.py (100%) rename {MuoVErsi => server}/trenitalia/source.py (99%) rename {MuoVErsi => server}/typesense.py (100%) diff --git a/alembic/env.py b/alembic/env.py index 2c4805a..4d60f29 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,8 +4,8 @@ from sqlalchemy import engine_from_config from sqlalchemy import pool -from MuoVErsi.base.models import Base -from MuoVErsi.handlers import engine_url +from server.base.models import Base +from server.handlers import engine_url # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/alembic/versions/2e8b9b6298f0_create_typesense_table_for_stations.py b/alembic/versions/2e8b9b6298f0_create_typesense_table_for_stations.py index 515ba73..5797f94 100644 --- a/alembic/versions/2e8b9b6298f0_create_typesense_table_for_stations.py +++ b/alembic/versions/2e8b9b6298f0_create_typesense_table_for_stations.py @@ -7,7 +7,7 @@ """ from typesense.exceptions import ObjectNotFound -from MuoVErsi.typesense import connect_to_typesense +from server.typesense import connect_to_typesense # revision identifiers, used by Alembic. revision = '2e8b9b6298f0' diff --git a/alembic/versions/6c9ef3a680e3_create_stops_table.py b/alembic/versions/6c9ef3a680e3_create_stops_table.py index 3fec4a7..de5dcf9 100644 --- a/alembic/versions/6c9ef3a680e3_create_stops_table.py +++ b/alembic/versions/6c9ef3a680e3_create_stops_table.py @@ -9,7 +9,7 @@ from alembic import op from sqlalchemy.orm import sessionmaker -from MuoVErsi.base import Station +from server.base import Station # revision identifiers, used by Alembic. revision = '6c9ef3a680e3' diff --git a/run.py b/run.py index 671bca1..1d6281d 100644 --- a/run.py +++ b/run.py @@ -1,6 +1,6 @@ import asyncio -from MuoVErsi.handlers import main +from server.handlers import main if __name__ == "__main__": asyncio.run(main()) diff --git a/save_data.py b/save_data.py index 9886c9c..e86cc18 100644 --- a/save_data.py +++ b/save_data.py @@ -2,10 +2,10 @@ from sqlalchemy.orm import sessionmaker -from MuoVErsi.handlers import engine -from MuoVErsi.trenitalia import Trenitalia -from MuoVErsi.GTFS import GTFS -from MuoVErsi.typesense import connect_to_typesense +from server.GTFS import GTFS +from server.handlers import engine +from server.trenitalia import Trenitalia +from server.typesense import connect_to_typesense logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO diff --git a/MuoVErsi/GTFS/__init__.py b/server/GTFS/__init__.py similarity index 100% rename from MuoVErsi/GTFS/__init__.py rename to server/GTFS/__init__.py diff --git a/MuoVErsi/GTFS/clustering.py b/server/GTFS/clustering.py similarity index 100% rename from MuoVErsi/GTFS/clustering.py rename to server/GTFS/clustering.py diff --git a/MuoVErsi/GTFS/models.py b/server/GTFS/models.py similarity index 100% rename from MuoVErsi/GTFS/models.py rename to server/GTFS/models.py diff --git a/MuoVErsi/GTFS/source.py b/server/GTFS/source.py similarity index 99% rename from MuoVErsi/GTFS/source.py rename to server/GTFS/source.py index 328f346..a3cfb6f 100644 --- a/MuoVErsi/GTFS/source.py +++ b/server/GTFS/source.py @@ -12,14 +12,13 @@ import requests from bs4 import BeautifulSoup +from sqlalchemy import select, func from tqdm import tqdm -from MuoVErsi.base import Source, Station, Stop, TripStopTime +from server.base import Source, Station, Stop, TripStopTime from .clustering import get_clusters_of_stops, get_loc_from_stop_and_cluster from .models import CStop -from sqlalchemy import select, func - logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -310,7 +309,7 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim def search_lines(self, name): today = date.today() - from MuoVErsi.base import Trip + from server.base import Trip trips = self.session.execute( select(func.max(Trip.number), Trip.dest_text)\ .filter(Trip.orig_dep_date == today)\ diff --git a/MuoVErsi/__init__.py b/server/__init__.py similarity index 100% rename from MuoVErsi/__init__.py rename to server/__init__.py diff --git a/MuoVErsi/base/__init__.py b/server/base/__init__.py similarity index 100% rename from MuoVErsi/base/__init__.py rename to server/base/__init__.py diff --git a/MuoVErsi/base/models.py b/server/base/models.py similarity index 100% rename from MuoVErsi/base/models.py rename to server/base/models.py diff --git a/MuoVErsi/base/source.py b/server/base/source.py similarity index 99% rename from MuoVErsi/base/source.py rename to server/base/source.py index 725d29c..1204e3a 100644 --- a/MuoVErsi/base/source.py +++ b/server/base/source.py @@ -366,7 +366,7 @@ def get_stop_times_between_stops(self, dep_station: Station, arr_station: Statio raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name, arr_time=a_arr_time, orig_dep_date=raw_stop_time.orig_dep_date) - from MuoVErsi.trenitalia import TrenitaliaRoute + from server.trenitalia import TrenitaliaRoute route = TrenitaliaRoute(d_stop_time, a_stop_time) directions.append(Direction([route])) diff --git a/MuoVErsi/data/trenitalia_stations.json b/server/data/trenitalia_stations.json similarity index 100% rename from MuoVErsi/data/trenitalia_stations.json rename to server/data/trenitalia_stations.json diff --git a/MuoVErsi/handlers.py b/server/handlers.py similarity index 100% rename from MuoVErsi/handlers.py rename to server/handlers.py index d66f7d8..bdb9ba5 100644 --- a/MuoVErsi/handlers.py +++ b/server/handlers.py @@ -24,11 +24,11 @@ MessageHandler, filters, CallbackQueryHandler, ) -from .persistence import SQLitePersistence from .GTFS import GTFS from .base import Source -from .trenitalia import Trenitalia +from .persistence import SQLitePersistence from .stop_times_filter import StopTimesFilter +from .trenitalia import Trenitalia from .typesense import connect_to_typesense logging.basicConfig( diff --git a/MuoVErsi/helpers.py b/server/helpers.py similarity index 100% rename from MuoVErsi/helpers.py rename to server/helpers.py diff --git a/MuoVErsi/persistence.py b/server/persistence.py similarity index 100% rename from MuoVErsi/persistence.py rename to server/persistence.py diff --git a/MuoVErsi/stop_times_filter.py b/server/stop_times_filter.py similarity index 99% rename from MuoVErsi/stop_times_filter.py rename to server/stop_times_filter.py index c192f55..73da213 100644 --- a/MuoVErsi/stop_times_filter.py +++ b/server/stop_times_filter.py @@ -5,7 +5,7 @@ from telegram import InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes -from MuoVErsi.base import Source, Liner, Station +from server.base import Source, Liner, Station logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO diff --git a/MuoVErsi/trenitalia/__init__.py b/server/trenitalia/__init__.py similarity index 100% rename from MuoVErsi/trenitalia/__init__.py rename to server/trenitalia/__init__.py diff --git a/MuoVErsi/trenitalia/source.py b/server/trenitalia/source.py similarity index 99% rename from MuoVErsi/trenitalia/source.py rename to server/trenitalia/source.py index cfb7c32..859666e 100644 --- a/MuoVErsi/trenitalia/source.py +++ b/server/trenitalia/source.py @@ -1,12 +1,12 @@ import json import math import os -from pytz import timezone import requests +from pytz import timezone from tqdm import tqdm -from MuoVErsi.base import * +from server.base import * logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO diff --git a/MuoVErsi/typesense.py b/server/typesense.py similarity index 100% rename from MuoVErsi/typesense.py rename to server/typesense.py diff --git a/tests/test_db.py b/tests/test_db.py index 95e622d..5105171 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,7 +1,8 @@ from datetime import date, datetime, time + import pytest -from MuoVErsi.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop +from server.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop @pytest.fixture diff --git a/tests/test_gtfs_clustering.py b/tests/test_gtfs_clustering.py index 343d9c5..9fb62b9 100644 --- a/tests/test_gtfs_clustering.py +++ b/tests/test_gtfs_clustering.py @@ -1,6 +1,6 @@ import pytest -from MuoVErsi.GTFS.clustering import get_root_from_stop_name, get_loc_from_stop_and_cluster +from server.GTFS.clustering import get_root_from_stop_name, get_loc_from_stop_and_cluster @pytest.fixture From d7766021aed83bd98784577e566baad5780d4ae7 Mon Sep 17 00:00:00 2001 From: Giacomo Sarrocco Date: Sun, 29 Oct 2023 19:04:41 +0100 Subject: [PATCH 2/6] Separate bot related stuff into "tgbot" --- config.py | 19 + server/handlers.py | 586 +------------------------ tgbot/__init__.py | 0 tgbot/handlers.py | 576 ++++++++++++++++++++++++ {server => tgbot}/persistence.py | 0 tgbot/routes.py | 22 + {server => tgbot}/stop_times_filter.py | 0 7 files changed, 625 insertions(+), 578 deletions(-) create mode 100644 config.py create mode 100644 tgbot/__init__.py create mode 100644 tgbot/handlers.py rename {server => tgbot}/persistence.py (100%) create mode 100644 tgbot/routes.py rename {server => tgbot}/stop_times_filter.py (100%) diff --git a/config.py b/config.py new file mode 100644 index 0000000..dae42cf --- /dev/null +++ b/config.py @@ -0,0 +1,19 @@ +import logging +import os + +import yaml + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + +current_dir = os.path.abspath(os.path.dirname(__file__)) + +config_path = os.path.join(current_dir, 'config.yaml') +with open(config_path, 'r') as config_file: + try: + config = yaml.safe_load(config_file) + logger.info(config) + except yaml.YAMLError as err: + logger.error(err) diff --git a/server/handlers.py b/server/handlers.py index bdb9ba5..dbba95f 100644 --- a/server/handlers.py +++ b/server/handlers.py @@ -1,33 +1,16 @@ -import gettext import logging -import os -import sys -from datetime import datetime, timedelta, date -import requests import uvicorn -import yaml -from babel.dates import format_date from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route -from telegram import ReplyKeyboardMarkup, ReplyKeyboardRemove, Update, KeyboardButton, InlineKeyboardMarkup, \ - InlineKeyboardButton, Bot -from telegram.ext import ( - Application, - CommandHandler, - ContextTypes, - ConversationHandler, - MessageHandler, - filters, CallbackQueryHandler, ) +from config import config +from tgbot.handlers import setup as tgbot_setup from .GTFS import GTFS -from .base import Source -from .persistence import SQLitePersistence -from .stop_times_filter import StopTimesFilter from .trenitalia import Trenitalia from .typesense import connect_to_typesense @@ -36,578 +19,29 @@ ) logger = logging.getLogger(__name__) -current_dir = os.path.abspath(os.path.dirname(__file__)) -parent_dir = os.path.abspath(current_dir + "/../") -thismodule = sys.modules[__name__] -thismodule.sources = {} -thismodule.persistence = SQLitePersistence() - - -SEARCH_STOP, SPECIFY_LINE, SEARCH_LINE, SHOW_LINE, SHOW_STOP = range(5) - - -localedir = os.path.join(parent_dir, 'locales') - -config_path = os.path.join(parent_dir, 'config.yaml') -with open(config_path, 'r') as config_file: - try: - config = yaml.safe_load(config_file) - logger.info(config) - except yaml.YAMLError as err: - logger.error(err) - engine_url = f"postgresql://{config['PGUSER']}:{config['PGPASSWORD']}@{config['PGHOST']}:{config['PGPORT']}/" \ f"{config['PGDATABASE']}" engine = create_engine(engine_url) -def clean_user_data(context, keep_transport_type=True): - context.user_data.pop('query_data', None) - context.user_data.pop('lines', None) - context.user_data.pop('dep_stop_ids', None) - context.user_data.pop('arr_stop_ids', None) - context.user_data.pop('dep_cluster_name', None) - context.user_data.pop('arr_cluster_name', None) - if not keep_transport_type: - context.user_data.pop('transport_type', None) - - -async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - clean_user_data(context, False) - await update.message.reply_text(_('welcome') + "\n\n" + _('home') % (_('stop'), _('line')), disable_notification=True) - - -async def announce(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - if config.get('ADMIN_TG_ID') != update.effective_user.id: - return - - persistence: SQLitePersistence = thismodule.persistence - user_ids = persistence.get_all_users() - text = update.message.text[10:] - for user_id in user_ids: - try: - await context.bot.send_message(user_id, text, parse_mode='HTML', disable_notification=True) - except Exception as e: - logger.error(e) - - -async def choose_service(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - command_text = update.message.text[1:] - - if command_text == 'fermata' or command_text == 'stop': - command = 'fermata' - elif command_text == 'linea' or command_text == 'line': - command = 'linea' - else: - return ConversationHandler.END - - clean_user_data(context) - - if command == 'fermata': - reply_keyboard = [[KeyboardButton(_('send_location'), request_location=True)]] - reply_keyboard_markup = ReplyKeyboardMarkup( - reply_keyboard, resize_keyboard=True, is_persistent=True - ) - await update.message.reply_text(_('insert_stop'), reply_markup=reply_keyboard_markup, parse_mode='HTML', disable_notification=True) - return SEARCH_STOP - - if context.user_data.get('transport_type'): - return await specify_line(update, context) - - inline_keyboard = [[]] - - for source in thismodule.sources: - inline_keyboard[0].append(InlineKeyboardButton(_(source), callback_data="T0" + source)) - - await update.message.reply_text( - _('choose_service'), - reply_markup=InlineKeyboardMarkup(inline_keyboard), - disable_notification=True - ) - - return SPECIFY_LINE - - -async def specify_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - send_second_message = True - if update.callback_query: - query = update.callback_query - if query.data[1] == '1': - send_second_message = False - chat_id = query.message.chat_id - short_transport_type = query.data[2:] - context.user_data['transport_type'] = short_transport_type - bot = query.get_bot() - await query.answer('') - else: - short_transport_type = context.user_data['transport_type'] - bot = update.message.get_bot() - chat_id = update.message.chat_id - - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - others_sources = [source for source in thismodule.sources if source != short_transport_type] - - inline_keyboard = [[]] - - for source in others_sources: - inline_keyboard[0].append(InlineKeyboardButton(_('change_service') % _(source), callback_data="T1" + source)) - - transport_type = _(short_transport_type) - - keyboard = InlineKeyboardMarkup(inline_keyboard) - - if update.callback_query: - await update.callback_query.answer() - await update.callback_query.edit_message_text(_('service_selected') % transport_type, reply_markup=keyboard) - else: - await bot.send_message(chat_id, _('service_selected') % transport_type, reply_markup=keyboard, disable_notification=True) - - if send_second_message: - await bot.send_message(chat_id, _('insert_line'), reply_markup=ReplyKeyboardRemove(), disable_notification=True) - - return SEARCH_LINE - - -async def search_stop(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - db_file: Source = thismodule.sources[context.user_data.get('transport_type', 'aut')] - - limit = 4 - - saved_dep_stop_ids = 'dep_stop_ids' not in context.user_data - - if update.callback_query: - text, lat, lon, page = update.callback_query.data[1:].split('/') - page = int(page) - else: - text, lat, lon, page = '', '', '', 1 - message = update.message - if message.location: - lat = message.location.latitude - lon = message.location.longitude - else: - text = message.text - - if lat == '' and lon == '': - stops_clusters, count = db_file.search_stops(name=text, all_sources=saved_dep_stop_ids, page=page, limit=limit) - else: - stops_clusters, count = db_file.search_stops(lat=lat, lon=lon, all_sources=saved_dep_stop_ids, page=page, limit=limit) - - if not stops_clusters: - await update.message.reply_text(_('stop_not_found'), disable_notification=True) - return SEARCH_STOP - - buttons = [[InlineKeyboardButton(f'{cluster.name} {thismodule.sources[cluster.source].emoji}', callback_data=f'S{cluster.id}-{cluster.source}')] - for cluster in stops_clusters] - - paging_buttons = [] - if page > 1: - paging_buttons.append(InlineKeyboardButton('<', callback_data=f'F{text}/{lat}/{lon}/{page - 1}')) - if page * limit < count: - paging_buttons.append(InlineKeyboardButton('>', callback_data=f'F{text}/{lat}/{lon}/{page + 1}')) - - if paging_buttons: - buttons.append(paging_buttons) - - if update.callback_query: - await update.callback_query.answer() - await update.callback_query.edit_message_text( - _('choose_stop'), - reply_markup=InlineKeyboardMarkup(buttons) - ) - else: - await update.message.reply_text( - _('choose_stop'), - reply_markup=InlineKeyboardMarkup(buttons), - disable_notification=True - ) - - return SHOW_STOP - - -async def send_stop_times(_, lang, db_file: Source, stop_times_filter: StopTimesFilter, chat_id, message_id, bot: Bot, - context: ContextTypes.DEFAULT_TYPE) -> int: - context.user_data['query_data'] = stop_times_filter.query_data() - - if stop_times_filter.first_time: - context.user_data.pop('lines', None) - - stop_times_filter.lines = context.user_data.get('lines') - - if context.user_data.get('day') != stop_times_filter.day.isoformat(): - context.user_data['day'] = stop_times_filter.day.isoformat() - - # add service_ids to Source instance, this way it can be accessed from get_stop_times - db_file.service_ids = context.bot_data.setdefault('service_ids', {}).setdefault(db_file.name, {}) - - results = stop_times_filter.get_times(db_file) - - context.bot_data['service_ids'][db_file.name] = db_file.service_ids - - context.user_data['lines'] = stop_times_filter.lines - - text, reply_markup = stop_times_filter.format_times_text(results, _, lang) - - if message_id: - await bot.edit_message_text(text, chat_id, message_id, reply_markup=reply_markup, parse_mode='HTML') - else: - await bot.send_message(chat_id, text=text, reply_markup=reply_markup, parse_mode='HTML', disable_notification=True) - - if stop_times_filter.first_time: - if stop_times_filter.arr_stop_ids: - text = '' + _('send_new_arr_stop') + '' - else: - text = '' + _('send_arr_stop') + '' - - reply_keyboard = [[KeyboardButton(_('send_location'), request_location=True)]] - reply_keyboard_markup = ReplyKeyboardMarkup( - reply_keyboard, resize_keyboard=True, is_persistent=True - ) - - await bot.send_message(chat_id, text, disable_notification=True, - reply_markup=reply_keyboard_markup, parse_mode='HTML') - - return SHOW_STOP - - -async def change_day_show_stop(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - db_file = thismodule.sources[context.user_data['transport_type']] - - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - del context.user_data['lines'] - dep_stop_ids = context.user_data.get('dep_stop_ids') - arr_stop_ids = context.user_data.get('arr_stop_ids') - dep_cluster_name = context.user_data.get('dep_cluster_name') - arr_cluster_name = context.user_data.get('arr_cluster_name') - stop_times_filter = StopTimesFilter(context, db_file, dep_stop_ids=dep_stop_ids, - query_data=context.user_data['query_data'], - arr_stop_ids=arr_stop_ids, dep_cluster_name=dep_cluster_name, - arr_cluster_name=arr_cluster_name) - if update.message.text == _('minus_day'): - stop_times_filter.day -= timedelta(days=1) - else: - stop_times_filter.day += timedelta(days=1) - stop_times_filter.start_time = '' - stop_times_filter.offset_times = 0 - - return await send_stop_times(_, lang, db_file, stop_times_filter, update.effective_chat.id, None, update.get_bot(), - context) - - -async def show_stop_from_id(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - now = datetime.now() - - text = update.message.text if update.message else update.callback_query.data - - message_id = None - - if update.callback_query: - message_id = update.callback_query.message.message_id - - stop_ref, line = text[1:].split('/') if '/' in text else (text[1:], '') - if '-' in stop_ref: - stop_ref, source_name = stop_ref.split('-') - db_file: Source = thismodule.sources[source_name] - context.user_data['transport_type'] = source_name - else: - db_file = thismodule.sources[context.user_data['transport_type']] - - station = db_file.get_stop_from_ref(stop_ref) - cluster_name = station.name - stop_ids = ','.join([stop.id for stop in station.stops]) - saved_dep_stop_ids = context.user_data.get('dep_stop_ids') - saved_dep_cluster_name = context.user_data.get('dep_cluster_name') - - if saved_dep_stop_ids: - stop_times_filter = StopTimesFilter(context, db_file, saved_dep_stop_ids, now.date(), line, now.time(), - arr_stop_ids=stop_ids, - arr_cluster_name=cluster_name, dep_cluster_name=saved_dep_cluster_name, - first_time=True) - context.user_data['arr_stop_ids'] = stop_ids - context.user_data['arr_cluster_name'] = cluster_name - else: - stop_times_filter = StopTimesFilter(context, db_file, stop_ids, now.date(), line, now.time(), - dep_cluster_name=cluster_name, - first_time=True) - context.user_data['dep_stop_ids'] = stop_ids - context.user_data['dep_cluster_name'] = cluster_name - - new_state = await send_stop_times(_, lang, db_file, stop_times_filter, update.effective_chat.id, - message_id, update.get_bot(), context) - - if update.callback_query: - await update.callback_query.answer() - - return new_state - - -async def filter_show_stop(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - db_file = thismodule.sources[context.user_data['transport_type']] - - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - query = update.callback_query - logger.info("Query data %s", query.data) - dep_stop_ids = context.user_data.get('dep_stop_ids') - arr_stop_ids = context.user_data.get('arr_stop_ids') - dep_cluster_name = context.user_data.get('dep_cluster_name') - arr_cluster_name = context.user_data.get('arr_cluster_name') - stop_times_filter = StopTimesFilter(context, db_file, dep_stop_ids=dep_stop_ids, query_data=query.data, - arr_stop_ids=arr_stop_ids, - dep_cluster_name=dep_cluster_name, arr_cluster_name=arr_cluster_name) - message_id = query.message.message_id - - chat_id = update.callback_query.message.chat_id - bot = update.get_bot() - - new_state = await send_stop_times(_, lang, db_file, stop_times_filter, chat_id, message_id, bot, context) - - await query.answer('') - - return new_state - - -async def trip_view(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - source: Source = thismodule.sources[context.user_data['transport_type']] - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - query_data = context.user_data['query_data'] - dep_stop_ids = context.user_data['dep_stop_ids'] - dep_cluster_name = context.user_data['dep_cluster_name'] - arr_stop_ids = context.user_data.get('arr_stop_ids') - arr_cluster_name = context.user_data.get('arr_cluster_name') - - stop_times_filter = StopTimesFilter(context, source, query_data=query_data, dep_stop_ids=dep_stop_ids, - dep_cluster_name=dep_cluster_name, arr_stop_ids=arr_stop_ids, - arr_cluster_name=arr_cluster_name) - if update.message: - text, all_stops = update.message.text, False - else: - text, all_stops = update.callback_query.data, True - trip_id = text[1:] - results = source.get_stops_from_trip_id(trip_id, stop_times_filter.day) - - line = results[0].route_name - text = '' + format_date(stop_times_filter.day, 'EEEE d MMMM', locale=lang) + ' - ' + _( - 'line') + ' ' + line + f' {trip_id}' - - dep_stop_index = 0 - arr_stop_index = len(results) - 1 - if not all_stops: - dep_stop_ids = stop_times_filter.dep_stop_ids.split(',') - try: - dep_stop_index = next(i for i, v in enumerate(results) if v.station.id in dep_stop_ids) - except StopIteration: - logger.warning('No departure stop found') - if arr_cluster_name: - arr_stop_ids = stop_times_filter.arr_stop_ids.split(',') - try: - arr_stop_index = dep_stop_index + next( - i for i, v in enumerate(results[dep_stop_index:]) if str(v.station.id) in arr_stop_ids) - except StopIteration: - logger.warning('No arrival stop found') - - platform_text = _(f'{source.name}_platform') - - are_dep_and_arr_times_equal = all( - result.arr_time == result.dep_time for result in results[dep_stop_index:arr_stop_index + 1]) - - for i, result in enumerate(results[dep_stop_index:arr_stop_index + 1]): - arr_time = result.arr_time.strftime('%H:%M') if result.arr_time else '' - dep_time = result.dep_time.strftime('%H:%M') if result.dep_time else '' - - if are_dep_and_arr_times_equal: - text += f'\n{arr_time} {result.station.station.name}' - else: - if i == 0: - text += f'\n{result.station.station.name} {dep_time}' - elif i == arr_stop_index: - text += f'\n{arr_time} {result.station.station.name}' - else: - text += f'\n{arr_time} {result.station.station.name} {dep_time}' - - if result.platform: - text += f' ({platform_text} {result.platform})' - - buttons = [InlineKeyboardButton(_('back'), callback_data=context.user_data['query_data'])] - - if not all_stops: - buttons.append(InlineKeyboardButton(_('all_stops'), callback_data=f'M{trip_id}')) - - reply_markup = InlineKeyboardMarkup([buttons]) - if update.message: - await update.message.reply_text(text=text, reply_markup=reply_markup, parse_mode='HTML', disable_notification=True) - else: - await update.callback_query.edit_message_text(text=text, reply_markup=reply_markup, parse_mode='HTML') - return SHOW_STOP - - -async def search_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - db_file: Source = thismodule.sources[context.user_data['transport_type']] - - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - try: - lines = db_file.search_lines(update.message.text) - except NotImplementedError: - await update.message.reply_text(_('not_implemented'), disable_notification=True) - return ConversationHandler.END - - keyboard = [[InlineKeyboardButton(line[2], callback_data=f'L{line[0]}/{line[1]}')] for line in lines] - inline_markup = InlineKeyboardMarkup(keyboard) - - await update.message.reply_text(_('choose_line'), reply_markup=inline_markup, disable_notification=True) - - return SHOW_LINE - - -async def show_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - source: Source = thismodule.sources[context.user_data['transport_type']] - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - query = update.callback_query - - trip_id, line = query.data[1:].split('/') - - day = date.today() - stops = source.get_stops_from_trip_id(trip_id, day) - - text = _('stops') + ':\n' - - inline_buttons = [] - - for stop in stops: - station = stop.station.station - stop_id = station.id - stop_name = station.name - inline_buttons.append([InlineKeyboardButton(stop_name, callback_data=f'S{stop_id}/{line}')]) - - await query.edit_message_text(text=text, reply_markup=InlineKeyboardMarkup(inline_buttons)) - await query.answer('') - - return SHOW_STOP - - -async def cancel(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - clean_user_data(context) - - lang = 'it' if update.effective_user.language_code == 'it' else 'en' - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - - await update.message.reply_text(_('cancel') + "\n\n" + _('home') % (_('stop'), _('line')), reply_markup=ReplyKeyboardRemove(), disable_notification=True) - - return ConversationHandler.END - - async def main() -> None: DEV = config.get('DEV', False) session = sessionmaker(bind=engine)() typesense = connect_to_typesense() - thismodule.sources = { + sources = { 'aut': GTFS('automobilistico', '🚌', session, typesense, dev=DEV), 'nav': GTFS('navigazione', '⛴️', session, typesense, dev=DEV), 'treni': Trenitalia(session, typesense) } - for source in thismodule.sources.values(): + for source in sources.values(): source.sync_stations_typesense(source.get_source_stations()) - application = Application.builder().token(config['TOKEN']).persistence(persistence=thismodule.persistence).build() + application, tgbot_routes = await tgbot_setup(config, sources) - langs = [f for f in os.listdir(localedir) if os.path.isdir(os.path.join(localedir, f))] - default_lang = 'en' - - for lang in langs: - trans = gettext.translation('messages', localedir, languages=[lang]) - _ = trans.gettext - language_code = lang if lang != default_lang else '' - r = requests.post(f'https://api.telegram.org/bot{config["TOKEN"]}/setMyCommands', json={ - 'commands': [ - {'command': _('stop'), 'description': _('search_by_stop')}, - {'command': _('line'), 'description': _('search_by_line')} - ], - 'language_code': language_code - }) - - conv_handler = ConversationHandler( - name='orari', - entry_points=[MessageHandler(filters.Regex(r'^\/[a-z]+$'), choose_service)], - states={ - SEARCH_STOP: [ - MessageHandler((filters.TEXT | filters.LOCATION) & (~filters.COMMAND), search_stop) - ], - SPECIFY_LINE: [CallbackQueryHandler(specify_line, r'^T')], - SEARCH_LINE: [ - MessageHandler(filters.TEXT & (~filters.COMMAND), search_line), - CallbackQueryHandler(specify_line, r'^T') - ], - SHOW_LINE: [CallbackQueryHandler(show_line, r'^L')], - SHOW_STOP: [ - CallbackQueryHandler(filter_show_stop, r'^Q'), - MessageHandler(filters.Regex(r'^\/[0-9]+$'), trip_view), - CallbackQueryHandler(trip_view, r'^M'), - CallbackQueryHandler(show_stop_from_id, r'^S'), - MessageHandler(filters.Regex(r'^\-|\+1[a-z]$'), change_day_show_stop), - MessageHandler((filters.TEXT | filters.LOCATION) & (~filters.COMMAND), search_stop), - CallbackQueryHandler(search_stop, r'^F') - ] - }, - fallbacks=[CommandHandler("cancel", cancel), MessageHandler(filters.Regex(r'^\/[a-z]+$'), choose_service)], - persistent=True - ) - - application.add_handler(CommandHandler("start", start)) - application.add_handler(MessageHandler(filters.Regex(r'^\/announce '), announce)) - application.add_handler(conv_handler) - - webhook_url = config['WEBHOOK_URL'] + '/tg_bot_webhook' - bot: Bot = application.bot - - if DEV: - await bot.set_webhook(webhook_url, os.path.join(parent_dir, 'cert.pem'), secret_token=config['SECRET_TOKEN']) - else: - await bot.set_webhook(webhook_url, secret_token=config['SECRET_TOKEN']) - - async def telegram(request: Request) -> Response: - if request.headers['X-Telegram-Bot-Api-Secret-Token'] != config['SECRET_TOKEN']: - return Response(status_code=403) - await application.update_queue.put( - Update.de_json(data=await request.json(), bot=application.bot) - ) - return Response() - async def home(request: Request) -> Response: - sources = thismodule.sources text_response = '' try: @@ -616,7 +50,6 @@ async def home(request: Request) -> Response: return Response(status_code=500) else: text_response += '

Postgres connection OK

' - text_response += '' return Response(text_response) - starlette_app = Starlette( - routes=[ - Route("/", home), - Route("/tg_bot_webhook", telegram, methods=["POST"]) - ] - ) + routes = [Route("/", home)] + routes += tgbot_routes + starlette_app = Starlette(routes=routes) if DEV: webserver = uvicorn.Server( diff --git a/tgbot/__init__.py b/tgbot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tgbot/handlers.py b/tgbot/handlers.py new file mode 100644 index 0000000..ed5c91b --- /dev/null +++ b/tgbot/handlers.py @@ -0,0 +1,576 @@ +import gettext +import logging +import os +import sys +from datetime import timedelta, datetime, date + +import requests +from babel.dates import format_date +from telegram import Update, KeyboardButton, ReplyKeyboardMarkup, InlineKeyboardButton, InlineKeyboardMarkup, \ + ReplyKeyboardRemove, Bot +from telegram.ext import ( + Application, + CommandHandler, + ConversationHandler, + MessageHandler, + filters, CallbackQueryHandler, ) +from telegram.ext import ContextTypes + +from config import config +from server.base import Source +from .persistence import SQLitePersistence +from .routes import get_routes +from .stop_times_filter import StopTimesFilter + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + +current_dir = os.path.abspath(os.path.dirname(__file__)) +parent_dir = os.path.abspath(current_dir + "/../") +thismodule = sys.modules[__name__] +thismodule.sources = {} +thismodule.persistence = SQLitePersistence() + +SEARCH_STOP, SPECIFY_LINE, SEARCH_LINE, SHOW_LINE, SHOW_STOP = range(5) + +localedir = os.path.join(parent_dir, 'locales') + + +def clean_user_data(context, keep_transport_type=True): + context.user_data.pop('query_data', None) + context.user_data.pop('lines', None) + context.user_data.pop('dep_stop_ids', None) + context.user_data.pop('arr_stop_ids', None) + context.user_data.pop('dep_cluster_name', None) + context.user_data.pop('arr_cluster_name', None) + if not keep_transport_type: + context.user_data.pop('transport_type', None) + + +async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + clean_user_data(context, False) + await update.message.reply_text(_('welcome') + "\n\n" + _('home') % (_('stop'), _('line')), + disable_notification=True) + + +async def announce(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + if config.get('ADMIN_TG_ID') != update.effective_user.id: + return + + persistence: SQLitePersistence = thismodule.persistence + user_ids = persistence.get_all_users() + text = update.message.text[10:] + for user_id in user_ids: + try: + await context.bot.send_message(user_id, text, parse_mode='HTML', disable_notification=True) + except Exception as e: + logger.error(e) + + +async def choose_service(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + command_text = update.message.text[1:] + + if command_text == 'fermata' or command_text == 'stop': + command = 'fermata' + elif command_text == 'linea' or command_text == 'line': + command = 'linea' + else: + return ConversationHandler.END + + clean_user_data(context) + + if command == 'fermata': + reply_keyboard = [[KeyboardButton(_('send_location'), request_location=True)]] + reply_keyboard_markup = ReplyKeyboardMarkup( + reply_keyboard, resize_keyboard=True, is_persistent=True + ) + await update.message.reply_text(_('insert_stop'), reply_markup=reply_keyboard_markup, parse_mode='HTML', + disable_notification=True) + return SEARCH_STOP + + if context.user_data.get('transport_type'): + return await specify_line(update, context) + + inline_keyboard = [[]] + + for source in thismodule.sources: + inline_keyboard[0].append(InlineKeyboardButton(_(source), callback_data="T0" + source)) + + await update.message.reply_text( + _('choose_service'), + reply_markup=InlineKeyboardMarkup(inline_keyboard), + disable_notification=True + ) + + return SPECIFY_LINE + + +async def specify_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + send_second_message = True + if update.callback_query: + query = update.callback_query + if query.data[1] == '1': + send_second_message = False + chat_id = query.message.chat_id + short_transport_type = query.data[2:] + context.user_data['transport_type'] = short_transport_type + bot = query.get_bot() + await query.answer('') + else: + short_transport_type = context.user_data['transport_type'] + bot = update.message.get_bot() + chat_id = update.message.chat_id + + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + others_sources = [source for source in thismodule.sources if source != short_transport_type] + + inline_keyboard = [[]] + + for source in others_sources: + inline_keyboard[0].append(InlineKeyboardButton(_('change_service') % _(source), callback_data="T1" + source)) + + transport_type = _(short_transport_type) + + keyboard = InlineKeyboardMarkup(inline_keyboard) + + if update.callback_query: + await update.callback_query.answer() + await update.callback_query.edit_message_text(_('service_selected') % transport_type, reply_markup=keyboard) + else: + await bot.send_message(chat_id, _('service_selected') % transport_type, reply_markup=keyboard, + disable_notification=True) + + if send_second_message: + await bot.send_message(chat_id, _('insert_line'), reply_markup=ReplyKeyboardRemove(), disable_notification=True) + + return SEARCH_LINE + + +async def search_stop(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + db_file: Source = thismodule.sources[context.user_data.get('transport_type', 'aut')] + + limit = 4 + + saved_dep_stop_ids = 'dep_stop_ids' not in context.user_data + + if update.callback_query: + text, lat, lon, page = update.callback_query.data[1:].split('/') + page = int(page) + else: + text, lat, lon, page = '', '', '', 1 + message = update.message + if message.location: + lat = message.location.latitude + lon = message.location.longitude + else: + text = message.text + + if lat == '' and lon == '': + stops_clusters, count = db_file.search_stops(name=text, all_sources=saved_dep_stop_ids, page=page, limit=limit) + else: + stops_clusters, count = db_file.search_stops(lat=lat, lon=lon, all_sources=saved_dep_stop_ids, page=page, + limit=limit) + + if not stops_clusters: + await update.message.reply_text(_('stop_not_found'), disable_notification=True) + return SEARCH_STOP + + buttons = [[InlineKeyboardButton(f'{cluster.name} {thismodule.sources[cluster.source].emoji}', + callback_data=f'S{cluster.id}-{cluster.source}')] + for cluster in stops_clusters] + + paging_buttons = [] + if page > 1: + paging_buttons.append(InlineKeyboardButton('<', callback_data=f'F{text}/{lat}/{lon}/{page - 1}')) + if page * limit < count: + paging_buttons.append(InlineKeyboardButton('>', callback_data=f'F{text}/{lat}/{lon}/{page + 1}')) + + if paging_buttons: + buttons.append(paging_buttons) + + if update.callback_query: + await update.callback_query.answer() + await update.callback_query.edit_message_text( + _('choose_stop'), + reply_markup=InlineKeyboardMarkup(buttons) + ) + else: + await update.message.reply_text( + _('choose_stop'), + reply_markup=InlineKeyboardMarkup(buttons), + disable_notification=True + ) + + return SHOW_STOP + + +async def send_stop_times(_, lang, db_file: Source, stop_times_filter: StopTimesFilter, chat_id, message_id, bot: Bot, + context: ContextTypes.DEFAULT_TYPE) -> int: + context.user_data['query_data'] = stop_times_filter.query_data() + + if stop_times_filter.first_time: + context.user_data.pop('lines', None) + + stop_times_filter.lines = context.user_data.get('lines') + + if context.user_data.get('day') != stop_times_filter.day.isoformat(): + context.user_data['day'] = stop_times_filter.day.isoformat() + + # add service_ids to Source instance, this way it can be accessed from get_stop_times + db_file.service_ids = context.bot_data.setdefault('service_ids', {}).setdefault(db_file.name, {}) + + results = stop_times_filter.get_times(db_file) + + context.bot_data['service_ids'][db_file.name] = db_file.service_ids + + context.user_data['lines'] = stop_times_filter.lines + + text, reply_markup = stop_times_filter.format_times_text(results, _, lang) + + if message_id: + await bot.edit_message_text(text, chat_id, message_id, reply_markup=reply_markup, parse_mode='HTML') + else: + await bot.send_message(chat_id, text=text, reply_markup=reply_markup, parse_mode='HTML', + disable_notification=True) + + if stop_times_filter.first_time: + if stop_times_filter.arr_stop_ids: + text = '' + _('send_new_arr_stop') + '' + else: + text = '' + _('send_arr_stop') + '' + + reply_keyboard = [[KeyboardButton(_('send_location'), request_location=True)]] + reply_keyboard_markup = ReplyKeyboardMarkup( + reply_keyboard, resize_keyboard=True, is_persistent=True + ) + + await bot.send_message(chat_id, text, disable_notification=True, + reply_markup=reply_keyboard_markup, parse_mode='HTML') + + return SHOW_STOP + + +async def change_day_show_stop(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + db_file = thismodule.sources[context.user_data['transport_type']] + + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + del context.user_data['lines'] + dep_stop_ids = context.user_data.get('dep_stop_ids') + arr_stop_ids = context.user_data.get('arr_stop_ids') + dep_cluster_name = context.user_data.get('dep_cluster_name') + arr_cluster_name = context.user_data.get('arr_cluster_name') + stop_times_filter = StopTimesFilter(context, db_file, dep_stop_ids=dep_stop_ids, + query_data=context.user_data['query_data'], + arr_stop_ids=arr_stop_ids, dep_cluster_name=dep_cluster_name, + arr_cluster_name=arr_cluster_name) + if update.message.text == _('minus_day'): + stop_times_filter.day -= timedelta(days=1) + else: + stop_times_filter.day += timedelta(days=1) + stop_times_filter.start_time = '' + stop_times_filter.offset_times = 0 + + return await send_stop_times(_, lang, db_file, stop_times_filter, update.effective_chat.id, None, update.get_bot(), + context) + + +async def show_stop_from_id(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + now = datetime.now() + + text = update.message.text if update.message else update.callback_query.data + + message_id = None + + if update.callback_query: + message_id = update.callback_query.message.message_id + + stop_ref, line = text[1:].split('/') if '/' in text else (text[1:], '') + if '-' in stop_ref: + stop_ref, source_name = stop_ref.split('-') + db_file: Source = thismodule.sources[source_name] + context.user_data['transport_type'] = source_name + else: + db_file = thismodule.sources[context.user_data['transport_type']] + + station = db_file.get_stop_from_ref(stop_ref) + cluster_name = station.name + stop_ids = ','.join([stop.id for stop in station.stops]) + saved_dep_stop_ids = context.user_data.get('dep_stop_ids') + saved_dep_cluster_name = context.user_data.get('dep_cluster_name') + + if saved_dep_stop_ids: + stop_times_filter = StopTimesFilter(context, db_file, saved_dep_stop_ids, now.date(), line, now.time(), + arr_stop_ids=stop_ids, + arr_cluster_name=cluster_name, dep_cluster_name=saved_dep_cluster_name, + first_time=True) + context.user_data['arr_stop_ids'] = stop_ids + context.user_data['arr_cluster_name'] = cluster_name + else: + stop_times_filter = StopTimesFilter(context, db_file, stop_ids, now.date(), line, now.time(), + dep_cluster_name=cluster_name, + first_time=True) + context.user_data['dep_stop_ids'] = stop_ids + context.user_data['dep_cluster_name'] = cluster_name + + new_state = await send_stop_times(_, lang, db_file, stop_times_filter, update.effective_chat.id, + message_id, update.get_bot(), context) + + if update.callback_query: + await update.callback_query.answer() + + return new_state + + +async def filter_show_stop(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + db_file = thismodule.sources[context.user_data['transport_type']] + + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + query = update.callback_query + logger.info("Query data %s", query.data) + dep_stop_ids = context.user_data.get('dep_stop_ids') + arr_stop_ids = context.user_data.get('arr_stop_ids') + dep_cluster_name = context.user_data.get('dep_cluster_name') + arr_cluster_name = context.user_data.get('arr_cluster_name') + stop_times_filter = StopTimesFilter(context, db_file, dep_stop_ids=dep_stop_ids, query_data=query.data, + arr_stop_ids=arr_stop_ids, + dep_cluster_name=dep_cluster_name, arr_cluster_name=arr_cluster_name) + message_id = query.message.message_id + + chat_id = update.callback_query.message.chat_id + bot = update.get_bot() + + new_state = await send_stop_times(_, lang, db_file, stop_times_filter, chat_id, message_id, bot, context) + + await query.answer('') + + return new_state + + +async def trip_view(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + source: Source = thismodule.sources[context.user_data['transport_type']] + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + query_data = context.user_data['query_data'] + dep_stop_ids = context.user_data['dep_stop_ids'] + dep_cluster_name = context.user_data['dep_cluster_name'] + arr_stop_ids = context.user_data.get('arr_stop_ids') + arr_cluster_name = context.user_data.get('arr_cluster_name') + + stop_times_filter = StopTimesFilter(context, source, query_data=query_data, dep_stop_ids=dep_stop_ids, + dep_cluster_name=dep_cluster_name, arr_stop_ids=arr_stop_ids, + arr_cluster_name=arr_cluster_name) + if update.message: + text, all_stops = update.message.text, False + else: + text, all_stops = update.callback_query.data, True + trip_id = text[1:] + results = source.get_stops_from_trip_id(trip_id, stop_times_filter.day) + + line = results[0].route_name + text = '' + format_date(stop_times_filter.day, 'EEEE d MMMM', locale=lang) + ' - ' + _( + 'line') + ' ' + line + f' {trip_id}' + + dep_stop_index = 0 + arr_stop_index = len(results) - 1 + if not all_stops: + dep_stop_ids = stop_times_filter.dep_stop_ids.split(',') + try: + dep_stop_index = next(i for i, v in enumerate(results) if v.station.id in dep_stop_ids) + except StopIteration: + logger.warning('No departure stop found') + if arr_cluster_name: + arr_stop_ids = stop_times_filter.arr_stop_ids.split(',') + try: + arr_stop_index = dep_stop_index + next( + i for i, v in enumerate(results[dep_stop_index:]) if str(v.station.id) in arr_stop_ids) + except StopIteration: + logger.warning('No arrival stop found') + + platform_text = _(f'{source.name}_platform') + + are_dep_and_arr_times_equal = all( + result.arr_time == result.dep_time for result in results[dep_stop_index:arr_stop_index + 1]) + + for i, result in enumerate(results[dep_stop_index:arr_stop_index + 1]): + arr_time = result.arr_time.strftime('%H:%M') if result.arr_time else '' + dep_time = result.dep_time.strftime('%H:%M') if result.dep_time else '' + + if are_dep_and_arr_times_equal: + text += f'\n{arr_time} {result.station.station.name}' + else: + if i == 0: + text += f'\n{result.station.station.name} {dep_time}' + elif i == arr_stop_index: + text += f'\n{arr_time} {result.station.station.name}' + else: + text += f'\n{arr_time} {result.station.station.name} {dep_time}' + + if result.platform: + text += f' ({platform_text} {result.platform})' + + buttons = [InlineKeyboardButton(_('back'), callback_data=context.user_data['query_data'])] + + if not all_stops: + buttons.append(InlineKeyboardButton(_('all_stops'), callback_data=f'M{trip_id}')) + + reply_markup = InlineKeyboardMarkup([buttons]) + if update.message: + await update.message.reply_text(text=text, reply_markup=reply_markup, parse_mode='HTML', + disable_notification=True) + else: + await update.callback_query.edit_message_text(text=text, reply_markup=reply_markup, parse_mode='HTML') + return SHOW_STOP + + +async def search_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + db_file: Source = thismodule.sources[context.user_data['transport_type']] + + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + try: + lines = db_file.search_lines(update.message.text) + except NotImplementedError: + await update.message.reply_text(_('not_implemented'), disable_notification=True) + return ConversationHandler.END + + keyboard = [[InlineKeyboardButton(line[2], callback_data=f'L{line[0]}/{line[1]}')] for line in lines] + inline_markup = InlineKeyboardMarkup(keyboard) + + await update.message.reply_text(_('choose_line'), reply_markup=inline_markup, disable_notification=True) + + return SHOW_LINE + + +async def show_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + source: Source = thismodule.sources[context.user_data['transport_type']] + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + query = update.callback_query + + trip_id, line = query.data[1:].split('/') + + day = date.today() + stops = source.get_stops_from_trip_id(trip_id, day) + + text = _('stops') + ':\n' + + inline_buttons = [] + + for stop in stops: + station = stop.station.station + stop_id = station.id + stop_name = station.name + inline_buttons.append([InlineKeyboardButton(stop_name, callback_data=f'S{stop_id}/{line}')]) + + await query.edit_message_text(text=text, reply_markup=InlineKeyboardMarkup(inline_buttons)) + await query.answer('') + + return SHOW_STOP + + +async def cancel(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + clean_user_data(context) + + lang = 'it' if update.effective_user.language_code == 'it' else 'en' + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + + await update.message.reply_text(_('cancel') + "\n\n" + _('home') % (_('stop'), _('line')), + reply_markup=ReplyKeyboardRemove(), disable_notification=True) + + return ConversationHandler.END + + +async def setup(config, sources): + DEV = config.get('DEV', False) + + thismodule.sources = sources + + application = Application.builder().token(config['TOKEN']).persistence(persistence=thismodule.persistence).build() + + langs = [f for f in os.listdir(localedir) if os.path.isdir(os.path.join(localedir, f))] + default_lang = 'en' + + for lang in langs: + trans = gettext.translation('messages', localedir, languages=[lang]) + _ = trans.gettext + language_code = lang if lang != default_lang else '' + r = requests.post(f'https://api.telegram.org/bot{config["TOKEN"]}/setMyCommands', json={ + 'commands': [ + {'command': _('stop'), 'description': _('search_by_stop')}, + {'command': _('line'), 'description': _('search_by_line')} + ], + 'language_code': language_code + }) + + conv_handler = ConversationHandler( + name='orari', + entry_points=[MessageHandler(filters.Regex(r'^\/[a-z]+$'), choose_service)], + states={ + SEARCH_STOP: [ + MessageHandler((filters.TEXT | filters.LOCATION) & (~filters.COMMAND), search_stop) + ], + SPECIFY_LINE: [CallbackQueryHandler(specify_line, r'^T')], + SEARCH_LINE: [ + MessageHandler(filters.TEXT & (~filters.COMMAND), search_line), + CallbackQueryHandler(specify_line, r'^T') + ], + SHOW_LINE: [CallbackQueryHandler(show_line, r'^L')], + SHOW_STOP: [ + CallbackQueryHandler(filter_show_stop, r'^Q'), + MessageHandler(filters.Regex(r'^\/[0-9]+$'), trip_view), + CallbackQueryHandler(trip_view, r'^M'), + CallbackQueryHandler(show_stop_from_id, r'^S'), + MessageHandler(filters.Regex(r'^\-|\+1[a-z]$'), change_day_show_stop), + MessageHandler((filters.TEXT | filters.LOCATION) & (~filters.COMMAND), search_stop), + CallbackQueryHandler(search_stop, r'^F') + ] + }, + fallbacks=[CommandHandler("cancel", cancel), MessageHandler(filters.Regex(r'^\/[a-z]+$'), choose_service)], + persistent=True + ) + + application.add_handler(CommandHandler("start", start)) + application.add_handler(MessageHandler(filters.Regex(r'^\/announce '), announce)) + application.add_handler(conv_handler) + + webhook_url = config['WEBHOOK_URL'] + '/tg_bot_webhook' + bot: Bot = application.bot + + if DEV: + await bot.set_webhook(webhook_url, os.path.join(parent_dir, 'cert.pem'), secret_token=config['SECRET_TOKEN']) + else: + await bot.set_webhook(webhook_url, secret_token=config['SECRET_TOKEN']) + + return application, get_routes(application) diff --git a/server/persistence.py b/tgbot/persistence.py similarity index 100% rename from server/persistence.py rename to tgbot/persistence.py diff --git a/tgbot/routes.py b/tgbot/routes.py new file mode 100644 index 0000000..fb46a8f --- /dev/null +++ b/tgbot/routes.py @@ -0,0 +1,22 @@ +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route +from telegram import Update + +from server.handlers import config + + +def get_routes(application): + async def telegram(request: Request) -> Response: + if request.headers['X-Telegram-Bot-Api-Secret-Token'] != config['SECRET_TOKEN']: + return Response(status_code=403) + await application.update_queue.put( + Update.de_json(data=await request.json(), bot=application.bot) + ) + return Response() + + routes = [ + Route("/tg_bot_webhook", telegram, methods=["POST"]) + ] + + return routes diff --git a/server/stop_times_filter.py b/tgbot/stop_times_filter.py similarity index 100% rename from server/stop_times_filter.py rename to tgbot/stop_times_filter.py From 201f7e06aea30ca1248e8b1e01d02290a69875f8 Mon Sep 17 00:00:00 2001 From: Giacomo Sarrocco Date: Mon, 30 Oct 2023 16:51:27 +0100 Subject: [PATCH 3/6] Complete separation --- alembic/env.py | 2 +- config.example.yaml | 9 +++-- run.py | 54 ++++++++++++++++++++++++++- save_data.py | 2 +- server/handlers.py | 89 --------------------------------------------- server/routes.py | 31 ++++++++++++++++ server/sources.py | 23 ++++++++++++ tgbot/handlers.py | 46 ++++++++++++++++------- tgbot/routes.py | 2 +- 9 files changed, 147 insertions(+), 111 deletions(-) delete mode 100644 server/handlers.py create mode 100644 server/routes.py create mode 100644 server/sources.py diff --git a/alembic/env.py b/alembic/env.py index 4d60f29..ba927c1 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -5,7 +5,7 @@ from sqlalchemy import pool from server.base.models import Base -from server.handlers import engine_url +from server.sources import engine_url # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/config.example.yaml b/config.example.yaml index ae28974..8a4b52f 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,13 +1,14 @@ -TOKEN: -WEBHOOK_URL: -SECRET_TOKEN: +TGBOT: # True or False (if True, the bot will be used) +TOKEN: # required if TGBOT is True +WEBHOOK_URL: # required if TGBOT is True +SECRET_TOKEN: # required if TGBOT is True DEV: # True or False PGUSER: PGPASSWORD: PGPORT: PGHOST: PGDATABASE: -ADMIN_TG_ID: # Telegram user ID of the admin +ADMIN_TG_ID: # Telegram user ID of the admin, required if TGBOT is True SSL_KEYFILE: # Path to the SSL key file SSL_CERTFILE: # Path to the SSL certificate file TYPESENSE_API_KEY: diff --git a/run.py b/run.py index 1d6281d..157476a 100644 --- a/run.py +++ b/run.py @@ -1,6 +1,56 @@ import asyncio +import logging -from server.handlers import main +import uvicorn +from starlette.applications import Starlette + +from config import config +from server.routes import routes as server_routes + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + + +async def run() -> None: + routes = server_routes + + tgbot_application = None + if config['TGBOT']: + from tgbot.handlers import set_up_application + tgbot_application = await set_up_application() + from tgbot.routes import get_routes as get_tgbot_routes + routes += get_tgbot_routes(tgbot_application) + + starlette_app = Starlette(routes=routes) + + if config.get('DEV', False): + webserver = uvicorn.Server( + config=uvicorn.Config( + app=starlette_app, + port=8000, + host="127.0.0.1", + ) + ) + else: + webserver = uvicorn.Server( + config=uvicorn.Config( + app=starlette_app, + port=443, + host="0.0.0.0", + ssl_keyfile=config['SSL_KEYFILE'], + ssl_certfile=config['SSL_CERTFILE'] + ) + ) + + if tgbot_application: + async with tgbot_application: + await tgbot_application.start() + await webserver.serve() + await tgbot_application.stop() + else: + await webserver.serve() if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(run()) diff --git a/save_data.py b/save_data.py index e86cc18..d471b54 100644 --- a/save_data.py +++ b/save_data.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import sessionmaker from server.GTFS import GTFS -from server.handlers import engine +from server.sources import engine from server.trenitalia import Trenitalia from server.typesense import connect_to_typesense diff --git a/server/handlers.py b/server/handlers.py deleted file mode 100644 index dbba95f..0000000 --- a/server/handlers.py +++ /dev/null @@ -1,89 +0,0 @@ -import logging - -import uvicorn -from sqlalchemy import create_engine, text -from sqlalchemy.orm import sessionmaker -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response -from starlette.routing import Route - -from config import config -from tgbot.handlers import setup as tgbot_setup -from .GTFS import GTFS -from .trenitalia import Trenitalia -from .typesense import connect_to_typesense - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO -) -logger = logging.getLogger(__name__) - -engine_url = f"postgresql://{config['PGUSER']}:{config['PGPASSWORD']}@{config['PGHOST']}:{config['PGPORT']}/" \ - f"{config['PGDATABASE']}" -engine = create_engine(engine_url) - - -async def main() -> None: - DEV = config.get('DEV', False) - - session = sessionmaker(bind=engine)() - typesense = connect_to_typesense() - - sources = { - 'aut': GTFS('automobilistico', '🚌', session, typesense, dev=DEV), - 'nav': GTFS('navigazione', '⛴️', session, typesense, dev=DEV), - 'treni': Trenitalia(session, typesense) - } - - for source in sources.values(): - source.sync_stations_typesense(source.get_source_stations()) - - application, tgbot_routes = await tgbot_setup(config, sources) - - async def home(request: Request) -> Response: - text_response = '' - - try: - sources['treni'].session.execute(text('SELECT 1')) - except Exception: - return Response(status_code=500) - else: - text_response += '

Postgres connection OK

' - - text_response += '
    ' - for source in sources.values(): - if hasattr(source, 'gtfs_version'): - text_response += f'
  • {source.name}: GTFS v.{source.gtfs_version}
  • ' - else: - text_response += f'
  • {source.name}
  • ' - text_response += '
' - return Response(text_response) - - routes = [Route("/", home)] - routes += tgbot_routes - starlette_app = Starlette(routes=routes) - - if DEV: - webserver = uvicorn.Server( - config=uvicorn.Config( - app=starlette_app, - port=8000, - host="127.0.0.1", - ) - ) - else: - webserver = uvicorn.Server( - config=uvicorn.Config( - app=starlette_app, - port=443, - host="0.0.0.0", - ssl_keyfile=config['SSL_KEYFILE'], - ssl_certfile=config['SSL_CERTFILE'] - ) - ) - - async with application: - await application.start() - await webserver.serve() - await application.stop() diff --git a/server/routes.py b/server/routes.py new file mode 100644 index 0000000..77b552f --- /dev/null +++ b/server/routes.py @@ -0,0 +1,31 @@ +from sqlalchemy import text +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route + +from server.sources import sources + + +async def home(request: Request) -> Response: + text_response = '' + + try: + sources['treni'].session.execute(text('SELECT 1')) + except Exception: + return Response(status_code=500) + else: + text_response += '

Postgres connection OK

' + + text_response += '
    ' + for source in sources.values(): + if hasattr(source, 'gtfs_version'): + text_response += f'
  • {source.name}: GTFS v.{source.gtfs_version}
  • ' + else: + text_response += f'
  • {source.name}
  • ' + text_response += '
' + return Response(text_response) + + +routes = [ + Route("/", home) +] diff --git a/server/sources.py b/server/sources.py new file mode 100644 index 0000000..4b96293 --- /dev/null +++ b/server/sources.py @@ -0,0 +1,23 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from config import config +from server.GTFS import GTFS +from server.trenitalia import Trenitalia +from server.typesense import connect_to_typesense + +engine_url = f"postgresql://{config['PGUSER']}:{config['PGPASSWORD']}@{config['PGHOST']}:{config['PGPORT']}/" \ + f"{config['PGDATABASE']}" +engine = create_engine(engine_url) + +session = sessionmaker(bind=engine)() +typesense = connect_to_typesense() + +sources = { + 'aut': GTFS('automobilistico', '🚌', session, typesense, dev=config.get('DEV', False)), + 'nav': GTFS('navigazione', '⛴️', session, typesense, dev=config.get('DEV', False)), + 'treni': Trenitalia(session, typesense) +} + +for source in sources.values(): + source.sync_stations_typesense(source.get_source_stations()) diff --git a/tgbot/handlers.py b/tgbot/handlers.py index ed5c91b..2fba580 100644 --- a/tgbot/handlers.py +++ b/tgbot/handlers.py @@ -18,8 +18,31 @@ from config import config from server.base import Source +from server.sources import sources as defined_sources +from .persistence import SQLitePersistence +from .stop_times_filter import StopTimesFilter +import gettext +import logging +import os +import sys +from datetime import timedelta, datetime, date + +import requests +from babel.dates import format_date +from telegram import Update, KeyboardButton, ReplyKeyboardMarkup, InlineKeyboardButton, InlineKeyboardMarkup, \ + ReplyKeyboardRemove, Bot +from telegram.ext import ( + Application, + CommandHandler, + ConversationHandler, + MessageHandler, + filters, CallbackQueryHandler, ) +from telegram.ext import ContextTypes + +from config import config +from server.base import Source +from server.sources import sources as defined_sources from .persistence import SQLitePersistence -from .routes import get_routes from .stop_times_filter import StopTimesFilter logging.basicConfig( @@ -31,7 +54,7 @@ parent_dir = os.path.abspath(current_dir + "/../") thismodule = sys.modules[__name__] thismodule.sources = {} -thismodule.persistence = SQLitePersistence() +thismodule.persistence = None SEARCH_STOP, SPECIFY_LINE, SEARCH_LINE, SHOW_LINE, SHOW_STOP = range(5) @@ -512,12 +535,11 @@ async def cancel(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: return ConversationHandler.END -async def setup(config, sources): - DEV = config.get('DEV', False) - - thismodule.sources = sources - - application = Application.builder().token(config['TOKEN']).persistence(persistence=thismodule.persistence).build() +async def set_up_application(): + persistence = SQLitePersistence() + application = Application.builder().token(config['TOKEN']).persistence(persistence=persistence).build() + thismodule.sources = defined_sources + thismodule.persistence = persistence langs = [f for f in os.listdir(localedir) if os.path.isdir(os.path.join(localedir, f))] default_lang = 'en' @@ -564,13 +586,11 @@ async def setup(config, sources): application.add_handler(CommandHandler("start", start)) application.add_handler(MessageHandler(filters.Regex(r'^\/announce '), announce)) application.add_handler(conv_handler) - - webhook_url = config['WEBHOOK_URL'] + '/tg_bot_webhook' bot: Bot = application.bot - - if DEV: + webhook_url = config['WEBHOOK_URL'] + '/tg_bot_webhook' + if config.get('DEV', False): await bot.set_webhook(webhook_url, os.path.join(parent_dir, 'cert.pem'), secret_token=config['SECRET_TOKEN']) else: await bot.set_webhook(webhook_url, secret_token=config['SECRET_TOKEN']) - return application, get_routes(application) + return application diff --git a/tgbot/routes.py b/tgbot/routes.py index fb46a8f..1d6f573 100644 --- a/tgbot/routes.py +++ b/tgbot/routes.py @@ -3,7 +3,7 @@ from starlette.routing import Route from telegram import Update -from server.handlers import config +from config import config def get_routes(application): From 412f5d0229b83eebc0a2d1da57848445076d98d2 Mon Sep 17 00:00:00 2001 From: Giacomo Sarrocco Date: Mon, 30 Oct 2023 17:16:14 +0100 Subject: [PATCH 4/6] Add prefix TG_ to all telegram-related config variables --- config.example.yaml | 10 +++++----- run.py | 2 +- tgbot/handlers.py | 35 ++++++----------------------------- tgbot/routes.py | 2 +- 4 files changed, 13 insertions(+), 36 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 8a4b52f..76f2a01 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,14 +1,14 @@ -TGBOT: # True or False (if True, the bot will be used) -TOKEN: # required if TGBOT is True -WEBHOOK_URL: # required if TGBOT is True -SECRET_TOKEN: # required if TGBOT is True +TG_BOT_ENABLED: # True or False (if True, the bot will be used) +TG_TOKEN: # required if TG_BOT_ENABLED is True +TG_WEBHOOK_URL: # required if TG_BOT_ENABLED is True +TG_SECRET_TOKEN: # required if TG_BOT_ENABLED is True DEV: # True or False PGUSER: PGPASSWORD: PGPORT: PGHOST: PGDATABASE: -ADMIN_TG_ID: # Telegram user ID of the admin, required if TGBOT is True +TG_ADMIN_ID: # Telegram user ID of the admin, required if TG_BOT_ENABLED is True SSL_KEYFILE: # Path to the SSL key file SSL_CERTFILE: # Path to the SSL certificate file TYPESENSE_API_KEY: diff --git a/run.py b/run.py index 157476a..76f9802 100644 --- a/run.py +++ b/run.py @@ -17,7 +17,7 @@ async def run() -> None: routes = server_routes tgbot_application = None - if config['TGBOT']: + if config['TG_BOT_ENABLED']: from tgbot.handlers import set_up_application tgbot_application = await set_up_application() from tgbot.routes import get_routes as get_tgbot_routes diff --git a/tgbot/handlers.py b/tgbot/handlers.py index 2fba580..cb20dbc 100644 --- a/tgbot/handlers.py +++ b/tgbot/handlers.py @@ -16,29 +16,6 @@ filters, CallbackQueryHandler, ) from telegram.ext import ContextTypes -from config import config -from server.base import Source -from server.sources import sources as defined_sources -from .persistence import SQLitePersistence -from .stop_times_filter import StopTimesFilter -import gettext -import logging -import os -import sys -from datetime import timedelta, datetime, date - -import requests -from babel.dates import format_date -from telegram import Update, KeyboardButton, ReplyKeyboardMarkup, InlineKeyboardButton, InlineKeyboardMarkup, \ - ReplyKeyboardRemove, Bot -from telegram.ext import ( - Application, - CommandHandler, - ConversationHandler, - MessageHandler, - filters, CallbackQueryHandler, ) -from telegram.ext import ContextTypes - from config import config from server.base import Source from server.sources import sources as defined_sources @@ -82,7 +59,7 @@ async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def announce(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - if config.get('ADMIN_TG_ID') != update.effective_user.id: + if config.get('TG_ADMIN_ID') != update.effective_user.id: return persistence: SQLitePersistence = thismodule.persistence @@ -537,7 +514,7 @@ async def cancel(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: async def set_up_application(): persistence = SQLitePersistence() - application = Application.builder().token(config['TOKEN']).persistence(persistence=persistence).build() + application = Application.builder().token(config['TG_TOKEN']).persistence(persistence=persistence).build() thismodule.sources = defined_sources thismodule.persistence = persistence @@ -548,7 +525,7 @@ async def set_up_application(): trans = gettext.translation('messages', localedir, languages=[lang]) _ = trans.gettext language_code = lang if lang != default_lang else '' - r = requests.post(f'https://api.telegram.org/bot{config["TOKEN"]}/setMyCommands', json={ + r = requests.post(f'https://api.telegram.org/bot{config["TG_TOKEN"]}/setMyCommands', json={ 'commands': [ {'command': _('stop'), 'description': _('search_by_stop')}, {'command': _('line'), 'description': _('search_by_line')} @@ -587,10 +564,10 @@ async def set_up_application(): application.add_handler(MessageHandler(filters.Regex(r'^\/announce '), announce)) application.add_handler(conv_handler) bot: Bot = application.bot - webhook_url = config['WEBHOOK_URL'] + '/tg_bot_webhook' + webhook_url = config['TG_WEBHOOK_URL'] + '/tg_bot_webhook' if config.get('DEV', False): - await bot.set_webhook(webhook_url, os.path.join(parent_dir, 'cert.pem'), secret_token=config['SECRET_TOKEN']) + await bot.set_webhook(webhook_url, os.path.join(parent_dir, 'cert.pem'), secret_token=config['TG_SECRET_TOKEN']) else: - await bot.set_webhook(webhook_url, secret_token=config['SECRET_TOKEN']) + await bot.set_webhook(webhook_url, secret_token=config['TG_SECRET_TOKEN']) return application diff --git a/tgbot/routes.py b/tgbot/routes.py index 1d6f573..3d02409 100644 --- a/tgbot/routes.py +++ b/tgbot/routes.py @@ -8,7 +8,7 @@ def get_routes(application): async def telegram(request: Request) -> Response: - if request.headers['X-Telegram-Bot-Api-Secret-Token'] != config['SECRET_TOKEN']: + if request.headers['X-Telegram-Bot-Api-Secret-Token'] != config['TG_SECRET_TOKEN']: return Response(status_code=403) await application.update_queue.put( Update.de_json(data=await request.json(), bot=application.bot) From 74c13dfd2bfe91ae05760fce4fc8bcd7d1ddc979 Mon Sep 17 00:00:00 2001 From: Giacomo Sarrocco Date: Mon, 30 Oct 2023 17:27:24 +0100 Subject: [PATCH 5/6] Update README.md --- README.md | 47 ++++++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index cd63da6..8668753 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,36 @@ -MuoVErsi is a Telegram bot for advanced users of Venice, Italy's public transit. You can check it out -here: [@MuoVErsiBot](https://t.me/MuoVErsiBot). +MuoVErsi is a web service that parses and serves timetables of buses, trams, trains and waterbusses. As of now, it +supports Venice, Italy's public transit system (by using public GTFS files) and Trenitalia trains within 100km from +Venice (parsed from the Trenitalia api). However, since it can build on any GTFS file, it will be easily extended to +other cities in the future. -It allows you to get departure times for the next buses, trams and vaporetti (waterbusses) from a given stop or -location, or starting from a specific line. You can then use filters to get the right results and see all the -stops/times of that specific route. +Separated from the core code, there is a Telegram bot that uses the web service to provide a more user-friendly +interface. You can check it out here: [@MuoVErsiBot](https://t.me/MuoVErsiBot). Also, a mobile app is in the works. -## Infrastructure +## Features -The bot is written in Python 3 and uses -the [python-telegram-bot](https://github.com/python-telegram-bot/python-telegram-bot) library, both for interacting with -Telegram bot API and for the http server. -The program downloads the data from Venice transit agency Actv's GTFS files and stores it in SQLite databases, thanks to -the -[gtfs](https://www.npmjs.com/package/gtfs) CLI. New data is checked every time the -server service restarts, or every night at 4:00 AM with a cronjob. +MuoVErsi allows you to get departure times from a given stop or location, or starting from a specific line. You can then +use filters to get the right results and see all the stops/times of that specific route. -When new data arrives, stops are not simply stored in the database, but they are clustered by name and location. This -way it is easier to search for bus stations with more than one bus stop. For example, "Piazzale Roma" has 15 -different bus stops from the GTFS file, but they are all clustered together. +When new data is parsed and saved to the database, stops are not simply stored as-is, but they are clustered +by name and location. This way it is easier to search for bus stations with more than one bus stop. For example, +"Piazzale Roma" has 15 different bus stops from the GTFS file, but they are all clustered together. -The code is not written specifically for Venice, so it can be easily adapted to other cities that use GTFS files. +## Installation +### Requirements + +- Python 3 +- PostgreSQL for the database +- [Typesense](https://typesense.org/) for the stop search engine +- [Telegram bot token](https://core.telegram.org/bots/features#botfather) if you also want to run the bot + +### Steps + +1. Download the repo and install the dependencies with `pip install -r requirements.txt`. +2. Fill out the config file `config.example.yaml` and rename it to `config.yaml`. If you don't want to run the Telegram, + bot, set `TG_BOT_ENABLED` to `False` and skip the all the variables starting with `TG_`. You won't need the `tgbot` + folder. +3. Run PostgreSQL migrations with `alembic upgrade head`. +4. Run the server by executing `run.py`. For saving data from the GTFS files and, more importantly, for the parsing and + saving of Trenitalia trains, make sure you schedule the execution of `save_data.py` once a day. As of now, also + a daily restart of `run.py` is required to set the service calendar to the current day. From 64198c73f23efbd148eea449360bc4e7476417bb Mon Sep 17 00:00:00 2001 From: Giacomo Sarrocco Date: Mon, 30 Oct 2023 17:30:08 +0100 Subject: [PATCH 6/6] Clarify bot separation from server --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8668753..7f9d725 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ supports Venice, Italy's public transit system (by using public GTFS files) and Venice (parsed from the Trenitalia api). However, since it can build on any GTFS file, it will be easily extended to other cities in the future. -Separated from the core code, there is a Telegram bot that uses the web service to provide a more user-friendly +Separated from the core code and optional to set up, a Telegram bot uses the web service to provide a more user-friendly interface. You can check it out here: [@MuoVErsiBot](https://t.me/MuoVErsiBot). Also, a mobile app is in the works. ## Features