diff --git a/README.md b/README.md index cd63da6..7f9d725 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,36 @@ -MuoVErsi is a Telegram bot for advanced users of Venice, Italy's public transit. You can check it out -here: [@MuoVErsiBot](https://t.me/MuoVErsiBot). +MuoVErsi is a web service that parses and serves timetables of buses, trams, trains and waterbusses. As of now, it +supports Venice, Italy's public transit system (by using public GTFS files) and Trenitalia trains within 100km from +Venice (parsed from the Trenitalia api). However, since it can build on any GTFS file, it will be easily extended to +other cities in the future. -It allows you to get departure times for the next buses, trams and vaporetti (waterbusses) from a given stop or -location, or starting from a specific line. You can then use filters to get the right results and see all the -stops/times of that specific route. +Separated from the core code and optional to set up, a Telegram bot uses the web service to provide a more user-friendly +interface. You can check it out here: [@MuoVErsiBot](https://t.me/MuoVErsiBot). Also, a mobile app is in the works. -## Infrastructure +## Features -The bot is written in Python 3 and uses -the [python-telegram-bot](https://github.com/python-telegram-bot/python-telegram-bot) library, both for interacting with -Telegram bot API and for the http server. -The program downloads the data from Venice transit agency Actv's GTFS files and stores it in SQLite databases, thanks to -the -[gtfs](https://www.npmjs.com/package/gtfs) CLI. New data is checked every time the -server service restarts, or every night at 4:00 AM with a cronjob. +MuoVErsi allows you to get departure times from a given stop or location, or starting from a specific line. You can then +use filters to get the right results and see all the stops/times of that specific route. -When new data arrives, stops are not simply stored in the database, but they are clustered by name and location. This -way it is easier to search for bus stations with more than one bus stop. For example, "Piazzale Roma" has 15 -different bus stops from the GTFS file, but they are all clustered together. +When new data is parsed and saved to the database, stops are not simply stored as-is, but they are clustered +by name and location. This way it is easier to search for bus stations with more than one bus stop. For example, +"Piazzale Roma" has 15 different bus stops from the GTFS file, but they are all clustered together. -The code is not written specifically for Venice, so it can be easily adapted to other cities that use GTFS files. +## Installation +### Requirements + +- Python 3 +- PostgreSQL for the database +- [Typesense](https://typesense.org/) for the stop search engine +- [Telegram bot token](https://core.telegram.org/bots/features#botfather) if you also want to run the bot + +### Steps + +1. Download the repo and install the dependencies with `pip install -r requirements.txt`. +2. Fill out the config file `config.example.yaml` and rename it to `config.yaml`. If you don't want to run the Telegram, + bot, set `TG_BOT_ENABLED` to `False` and skip the all the variables starting with `TG_`. You won't need the `tgbot` + folder. +3. Run PostgreSQL migrations with `alembic upgrade head`. +4. Run the server by executing `run.py`. For saving data from the GTFS files and, more importantly, for the parsing and + saving of Trenitalia trains, make sure you schedule the execution of `save_data.py` once a day. As of now, also + a daily restart of `run.py` is required to set the service calendar to the current day. diff --git a/alembic/env.py b/alembic/env.py index 2c4805a..ba927c1 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,8 +4,8 @@ from sqlalchemy import engine_from_config from sqlalchemy import pool -from MuoVErsi.base.models import Base -from MuoVErsi.handlers import engine_url +from server.base.models import Base +from server.sources import engine_url # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/alembic/versions/2e8b9b6298f0_create_typesense_table_for_stations.py b/alembic/versions/2e8b9b6298f0_create_typesense_table_for_stations.py index 515ba73..5797f94 100644 --- a/alembic/versions/2e8b9b6298f0_create_typesense_table_for_stations.py +++ b/alembic/versions/2e8b9b6298f0_create_typesense_table_for_stations.py @@ -7,7 +7,7 @@ """ from typesense.exceptions import ObjectNotFound -from MuoVErsi.typesense import connect_to_typesense +from server.typesense import connect_to_typesense # revision identifiers, used by Alembic. revision = '2e8b9b6298f0' diff --git a/alembic/versions/6c9ef3a680e3_create_stops_table.py b/alembic/versions/6c9ef3a680e3_create_stops_table.py index 3fec4a7..de5dcf9 100644 --- a/alembic/versions/6c9ef3a680e3_create_stops_table.py +++ b/alembic/versions/6c9ef3a680e3_create_stops_table.py @@ -9,7 +9,7 @@ from alembic import op from sqlalchemy.orm import sessionmaker -from MuoVErsi.base import Station +from server.base import Station # revision identifiers, used by Alembic. revision = '6c9ef3a680e3' diff --git a/config.example.yaml b/config.example.yaml index ae28974..76f2a01 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,13 +1,14 @@ -TOKEN: -WEBHOOK_URL: -SECRET_TOKEN: +TG_BOT_ENABLED: # True or False (if True, the bot will be used) +TG_TOKEN: # required if TG_BOT_ENABLED is True +TG_WEBHOOK_URL: # required if TG_BOT_ENABLED is True +TG_SECRET_TOKEN: # required if TG_BOT_ENABLED is True DEV: # True or False PGUSER: PGPASSWORD: PGPORT: PGHOST: PGDATABASE: -ADMIN_TG_ID: # Telegram user ID of the admin +TG_ADMIN_ID: # Telegram user ID of the admin, required if TG_BOT_ENABLED is True SSL_KEYFILE: # Path to the SSL key file SSL_CERTFILE: # Path to the SSL certificate file TYPESENSE_API_KEY: diff --git a/config.py b/config.py new file mode 100644 index 0000000..dae42cf --- /dev/null +++ b/config.py @@ -0,0 +1,19 @@ +import logging +import os + +import yaml + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + +current_dir = os.path.abspath(os.path.dirname(__file__)) + +config_path = os.path.join(current_dir, 'config.yaml') +with open(config_path, 'r') as config_file: + try: + config = yaml.safe_load(config_file) + logger.info(config) + except yaml.YAMLError as err: + logger.error(err) diff --git a/run.py b/run.py index 671bca1..76f9802 100644 --- a/run.py +++ b/run.py @@ -1,6 +1,56 @@ import asyncio +import logging -from MuoVErsi.handlers import main +import uvicorn +from starlette.applications import Starlette + +from config import config +from server.routes import routes as server_routes + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + + +async def run() -> None: + routes = server_routes + + tgbot_application = None + if config['TG_BOT_ENABLED']: + from tgbot.handlers import set_up_application + tgbot_application = await set_up_application() + from tgbot.routes import get_routes as get_tgbot_routes + routes += get_tgbot_routes(tgbot_application) + + starlette_app = Starlette(routes=routes) + + if config.get('DEV', False): + webserver = uvicorn.Server( + config=uvicorn.Config( + app=starlette_app, + port=8000, + host="127.0.0.1", + ) + ) + else: + webserver = uvicorn.Server( + config=uvicorn.Config( + app=starlette_app, + port=443, + host="0.0.0.0", + ssl_keyfile=config['SSL_KEYFILE'], + ssl_certfile=config['SSL_CERTFILE'] + ) + ) + + if tgbot_application: + async with tgbot_application: + await tgbot_application.start() + await webserver.serve() + await tgbot_application.stop() + else: + await webserver.serve() if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(run()) diff --git a/save_data.py b/save_data.py index 9886c9c..d471b54 100644 --- a/save_data.py +++ b/save_data.py @@ -2,10 +2,10 @@ from sqlalchemy.orm import sessionmaker -from MuoVErsi.handlers import engine -from MuoVErsi.trenitalia import Trenitalia -from MuoVErsi.GTFS import GTFS -from MuoVErsi.typesense import connect_to_typesense +from server.GTFS import GTFS +from server.sources import engine +from server.trenitalia import Trenitalia +from server.typesense import connect_to_typesense logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO diff --git a/MuoVErsi/GTFS/__init__.py b/server/GTFS/__init__.py similarity index 100% rename from MuoVErsi/GTFS/__init__.py rename to server/GTFS/__init__.py diff --git a/MuoVErsi/GTFS/clustering.py b/server/GTFS/clustering.py similarity index 100% rename from MuoVErsi/GTFS/clustering.py rename to server/GTFS/clustering.py diff --git a/MuoVErsi/GTFS/models.py b/server/GTFS/models.py similarity index 100% rename from MuoVErsi/GTFS/models.py rename to server/GTFS/models.py diff --git a/MuoVErsi/GTFS/source.py b/server/GTFS/source.py similarity index 99% rename from MuoVErsi/GTFS/source.py rename to server/GTFS/source.py index 328f346..a3cfb6f 100644 --- a/MuoVErsi/GTFS/source.py +++ b/server/GTFS/source.py @@ -12,14 +12,13 @@ import requests from bs4 import BeautifulSoup +from sqlalchemy import select, func from tqdm import tqdm -from MuoVErsi.base import Source, Station, Stop, TripStopTime +from server.base import Source, Station, Stop, TripStopTime from .clustering import get_clusters_of_stops, get_loc_from_stop_and_cluster from .models import CStop -from sqlalchemy import select, func - logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -310,7 +309,7 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim def search_lines(self, name): today = date.today() - from MuoVErsi.base import Trip + from server.base import Trip trips = self.session.execute( select(func.max(Trip.number), Trip.dest_text)\ .filter(Trip.orig_dep_date == today)\ diff --git a/MuoVErsi/__init__.py b/server/__init__.py similarity index 100% rename from MuoVErsi/__init__.py rename to server/__init__.py diff --git a/MuoVErsi/base/__init__.py b/server/base/__init__.py similarity index 100% rename from MuoVErsi/base/__init__.py rename to server/base/__init__.py diff --git a/MuoVErsi/base/models.py b/server/base/models.py similarity index 100% rename from MuoVErsi/base/models.py rename to server/base/models.py diff --git a/MuoVErsi/base/source.py b/server/base/source.py similarity index 99% rename from MuoVErsi/base/source.py rename to server/base/source.py index 725d29c..1204e3a 100644 --- a/MuoVErsi/base/source.py +++ b/server/base/source.py @@ -366,7 +366,7 @@ def get_stop_times_between_stops(self, dep_station: Station, arr_station: Statio raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name, arr_time=a_arr_time, orig_dep_date=raw_stop_time.orig_dep_date) - from MuoVErsi.trenitalia import TrenitaliaRoute + from server.trenitalia import TrenitaliaRoute route = TrenitaliaRoute(d_stop_time, a_stop_time) directions.append(Direction([route])) diff --git a/MuoVErsi/data/trenitalia_stations.json b/server/data/trenitalia_stations.json similarity index 100% rename from MuoVErsi/data/trenitalia_stations.json rename to server/data/trenitalia_stations.json diff --git a/MuoVErsi/helpers.py b/server/helpers.py similarity index 100% rename from MuoVErsi/helpers.py rename to server/helpers.py diff --git a/server/routes.py b/server/routes.py new file mode 100644 index 0000000..77b552f --- /dev/null +++ b/server/routes.py @@ -0,0 +1,31 @@ +from sqlalchemy import text +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route + +from server.sources import sources + + +async def home(request: Request) -> Response: + text_response = '' + + try: + sources['treni'].session.execute(text('SELECT 1')) + except Exception: + return Response(status_code=500) + else: + text_response += '

Postgres connection OK

' + + text_response += '' + return Response(text_response) + + +routes = [ + Route("/", home) +] diff --git a/server/sources.py b/server/sources.py new file mode 100644 index 0000000..4b96293 --- /dev/null +++ b/server/sources.py @@ -0,0 +1,23 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from config import config +from server.GTFS import GTFS +from server.trenitalia import Trenitalia +from server.typesense import connect_to_typesense + +engine_url = f"postgresql://{config['PGUSER']}:{config['PGPASSWORD']}@{config['PGHOST']}:{config['PGPORT']}/" \ + f"{config['PGDATABASE']}" +engine = create_engine(engine_url) + +session = sessionmaker(bind=engine)() +typesense = connect_to_typesense() + +sources = { + 'aut': GTFS('automobilistico', '🚌', session, typesense, dev=config.get('DEV', False)), + 'nav': GTFS('navigazione', '⛴️', session, typesense, dev=config.get('DEV', False)), + 'treni': Trenitalia(session, typesense) +} + +for source in sources.values(): + source.sync_stations_typesense(source.get_source_stations()) diff --git a/MuoVErsi/trenitalia/__init__.py b/server/trenitalia/__init__.py similarity index 100% rename from MuoVErsi/trenitalia/__init__.py rename to server/trenitalia/__init__.py diff --git a/MuoVErsi/trenitalia/source.py b/server/trenitalia/source.py similarity index 99% rename from MuoVErsi/trenitalia/source.py rename to server/trenitalia/source.py index cfb7c32..859666e 100644 --- a/MuoVErsi/trenitalia/source.py +++ b/server/trenitalia/source.py @@ -1,12 +1,12 @@ import json import math import os -from pytz import timezone import requests +from pytz import timezone from tqdm import tqdm -from MuoVErsi.base import * +from server.base import * logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO diff --git a/MuoVErsi/typesense.py b/server/typesense.py similarity index 100% rename from MuoVErsi/typesense.py rename to server/typesense.py diff --git a/tests/test_db.py b/tests/test_db.py index 95e622d..5105171 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,7 +1,8 @@ from datetime import date, datetime, time + import pytest -from MuoVErsi.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop +from server.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop @pytest.fixture diff --git a/tests/test_gtfs_clustering.py b/tests/test_gtfs_clustering.py index 343d9c5..9fb62b9 100644 --- a/tests/test_gtfs_clustering.py +++ b/tests/test_gtfs_clustering.py @@ -1,6 +1,6 @@ import pytest -from MuoVErsi.GTFS.clustering import get_root_from_stop_name, get_loc_from_stop_and_cluster +from server.GTFS.clustering import get_root_from_stop_name, get_loc_from_stop_and_cluster @pytest.fixture diff --git a/tgbot/__init__.py b/tgbot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MuoVErsi/handlers.py b/tgbot/handlers.py similarity index 83% rename from MuoVErsi/handlers.py rename to tgbot/handlers.py index d66f7d8..cb20dbc 100644 --- a/MuoVErsi/handlers.py +++ b/tgbot/handlers.py @@ -2,34 +2,25 @@ import logging import os import sys -from datetime import datetime, timedelta, date +from datetime import timedelta, datetime, date import requests -import uvicorn -import yaml from babel.dates import format_date -from sqlalchemy import create_engine, text -from sqlalchemy.orm import sessionmaker -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response -from starlette.routing import Route -from telegram import ReplyKeyboardMarkup, ReplyKeyboardRemove, Update, KeyboardButton, InlineKeyboardMarkup, \ - InlineKeyboardButton, Bot +from telegram import Update, KeyboardButton, ReplyKeyboardMarkup, InlineKeyboardButton, InlineKeyboardMarkup, \ + ReplyKeyboardRemove, Bot from telegram.ext import ( Application, CommandHandler, - ContextTypes, ConversationHandler, MessageHandler, filters, CallbackQueryHandler, ) +from telegram.ext import ContextTypes +from config import config +from server.base import Source +from server.sources import sources as defined_sources from .persistence import SQLitePersistence -from .GTFS import GTFS -from .base import Source -from .trenitalia import Trenitalia from .stop_times_filter import StopTimesFilter -from .typesense import connect_to_typesense logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -40,26 +31,12 @@ parent_dir = os.path.abspath(current_dir + "/../") thismodule = sys.modules[__name__] thismodule.sources = {} -thismodule.persistence = SQLitePersistence() - +thismodule.persistence = None SEARCH_STOP, SPECIFY_LINE, SEARCH_LINE, SHOW_LINE, SHOW_STOP = range(5) - localedir = os.path.join(parent_dir, 'locales') -config_path = os.path.join(parent_dir, 'config.yaml') -with open(config_path, 'r') as config_file: - try: - config = yaml.safe_load(config_file) - logger.info(config) - except yaml.YAMLError as err: - logger.error(err) - -engine_url = f"postgresql://{config['PGUSER']}:{config['PGPASSWORD']}@{config['PGHOST']}:{config['PGPORT']}/" \ - f"{config['PGDATABASE']}" -engine = create_engine(engine_url) - def clean_user_data(context, keep_transport_type=True): context.user_data.pop('query_data', None) @@ -77,11 +54,12 @@ async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: trans = gettext.translation('messages', localedir, languages=[lang]) _ = trans.gettext clean_user_data(context, False) - await update.message.reply_text(_('welcome') + "\n\n" + _('home') % (_('stop'), _('line')), disable_notification=True) + await update.message.reply_text(_('welcome') + "\n\n" + _('home') % (_('stop'), _('line')), + disable_notification=True) async def announce(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - if config.get('ADMIN_TG_ID') != update.effective_user.id: + if config.get('TG_ADMIN_ID') != update.effective_user.id: return persistence: SQLitePersistence = thismodule.persistence @@ -115,7 +93,8 @@ async def choose_service(update: Update, context: ContextTypes.DEFAULT_TYPE) -> reply_keyboard_markup = ReplyKeyboardMarkup( reply_keyboard, resize_keyboard=True, is_persistent=True ) - await update.message.reply_text(_('insert_stop'), reply_markup=reply_keyboard_markup, parse_mode='HTML', disable_notification=True) + await update.message.reply_text(_('insert_stop'), reply_markup=reply_keyboard_markup, parse_mode='HTML', + disable_notification=True) return SEARCH_STOP if context.user_data.get('transport_type'): @@ -170,7 +149,8 @@ async def specify_line(update: Update, context: ContextTypes.DEFAULT_TYPE) -> in await update.callback_query.answer() await update.callback_query.edit_message_text(_('service_selected') % transport_type, reply_markup=keyboard) else: - await bot.send_message(chat_id, _('service_selected') % transport_type, reply_markup=keyboard, disable_notification=True) + await bot.send_message(chat_id, _('service_selected') % transport_type, reply_markup=keyboard, + disable_notification=True) if send_second_message: await bot.send_message(chat_id, _('insert_line'), reply_markup=ReplyKeyboardRemove(), disable_notification=True) @@ -204,15 +184,17 @@ async def search_stop(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int if lat == '' and lon == '': stops_clusters, count = db_file.search_stops(name=text, all_sources=saved_dep_stop_ids, page=page, limit=limit) else: - stops_clusters, count = db_file.search_stops(lat=lat, lon=lon, all_sources=saved_dep_stop_ids, page=page, limit=limit) + stops_clusters, count = db_file.search_stops(lat=lat, lon=lon, all_sources=saved_dep_stop_ids, page=page, + limit=limit) if not stops_clusters: await update.message.reply_text(_('stop_not_found'), disable_notification=True) return SEARCH_STOP - buttons = [[InlineKeyboardButton(f'{cluster.name} {thismodule.sources[cluster.source].emoji}', callback_data=f'S{cluster.id}-{cluster.source}')] + buttons = [[InlineKeyboardButton(f'{cluster.name} {thismodule.sources[cluster.source].emoji}', + callback_data=f'S{cluster.id}-{cluster.source}')] for cluster in stops_clusters] - + paging_buttons = [] if page > 1: paging_buttons.append(InlineKeyboardButton('<', callback_data=f'F{text}/{lat}/{lon}/{page - 1}')) @@ -264,7 +246,8 @@ async def send_stop_times(_, lang, db_file: Source, stop_times_filter: StopTimes if message_id: await bot.edit_message_text(text, chat_id, message_id, reply_markup=reply_markup, parse_mode='HTML') else: - await bot.send_message(chat_id, text=text, reply_markup=reply_markup, parse_mode='HTML', disable_notification=True) + await bot.send_message(chat_id, text=text, reply_markup=reply_markup, parse_mode='HTML', + disable_notification=True) if stop_times_filter.first_time: if stop_times_filter.arr_stop_ids: @@ -413,7 +396,7 @@ async def trip_view(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: line = results[0].route_name text = '' + format_date(stop_times_filter.day, 'EEEE d MMMM', locale=lang) + ' - ' + _( 'line') + ' ' + line + f' {trip_id}' - + dep_stop_index = 0 arr_stop_index = len(results) - 1 if not all_stops: @@ -459,7 +442,8 @@ async def trip_view(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: reply_markup = InlineKeyboardMarkup([buttons]) if update.message: - await update.message.reply_text(text=text, reply_markup=reply_markup, parse_mode='HTML', disable_notification=True) + await update.message.reply_text(text=text, reply_markup=reply_markup, parse_mode='HTML', + disable_notification=True) else: await update.callback_query.edit_message_text(text=text, reply_markup=reply_markup, parse_mode='HTML') return SHOW_STOP @@ -522,27 +506,17 @@ async def cancel(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: trans = gettext.translation('messages', localedir, languages=[lang]) _ = trans.gettext - await update.message.reply_text(_('cancel') + "\n\n" + _('home') % (_('stop'), _('line')), reply_markup=ReplyKeyboardRemove(), disable_notification=True) + await update.message.reply_text(_('cancel') + "\n\n" + _('home') % (_('stop'), _('line')), + reply_markup=ReplyKeyboardRemove(), disable_notification=True) return ConversationHandler.END -async def main() -> None: - DEV = config.get('DEV', False) - - session = sessionmaker(bind=engine)() - typesense = connect_to_typesense() - - thismodule.sources = { - 'aut': GTFS('automobilistico', '🚌', session, typesense, dev=DEV), - 'nav': GTFS('navigazione', '⛴️', session, typesense, dev=DEV), - 'treni': Trenitalia(session, typesense) - } - - for source in thismodule.sources.values(): - source.sync_stations_typesense(source.get_source_stations()) - - application = Application.builder().token(config['TOKEN']).persistence(persistence=thismodule.persistence).build() +async def set_up_application(): + persistence = SQLitePersistence() + application = Application.builder().token(config['TG_TOKEN']).persistence(persistence=persistence).build() + thismodule.sources = defined_sources + thismodule.persistence = persistence langs = [f for f in os.listdir(localedir) if os.path.isdir(os.path.join(localedir, f))] default_lang = 'en' @@ -551,7 +525,7 @@ async def main() -> None: trans = gettext.translation('messages', localedir, languages=[lang]) _ = trans.gettext language_code = lang if lang != default_lang else '' - r = requests.post(f'https://api.telegram.org/bot{config["TOKEN"]}/setMyCommands', json={ + r = requests.post(f'https://api.telegram.org/bot{config["TG_TOKEN"]}/setMyCommands', json={ 'commands': [ {'command': _('stop'), 'description': _('search_by_stop')}, {'command': _('line'), 'description': _('search_by_line')} @@ -589,71 +563,11 @@ async def main() -> None: application.add_handler(CommandHandler("start", start)) application.add_handler(MessageHandler(filters.Regex(r'^\/announce '), announce)) application.add_handler(conv_handler) - - webhook_url = config['WEBHOOK_URL'] + '/tg_bot_webhook' bot: Bot = application.bot - - if DEV: - await bot.set_webhook(webhook_url, os.path.join(parent_dir, 'cert.pem'), secret_token=config['SECRET_TOKEN']) + webhook_url = config['TG_WEBHOOK_URL'] + '/tg_bot_webhook' + if config.get('DEV', False): + await bot.set_webhook(webhook_url, os.path.join(parent_dir, 'cert.pem'), secret_token=config['TG_SECRET_TOKEN']) else: - await bot.set_webhook(webhook_url, secret_token=config['SECRET_TOKEN']) - - async def telegram(request: Request) -> Response: - if request.headers['X-Telegram-Bot-Api-Secret-Token'] != config['SECRET_TOKEN']: - return Response(status_code=403) - await application.update_queue.put( - Update.de_json(data=await request.json(), bot=application.bot) - ) - return Response() - - async def home(request: Request) -> Response: - sources = thismodule.sources - text_response = '' - - try: - sources['treni'].session.execute(text('SELECT 1')) - except Exception: - return Response(status_code=500) - else: - text_response += '

Postgres connection OK

' - - - text_response += '' - return Response(text_response) - - starlette_app = Starlette( - routes=[ - Route("/", home), - Route("/tg_bot_webhook", telegram, methods=["POST"]) - ] - ) - - if DEV: - webserver = uvicorn.Server( - config=uvicorn.Config( - app=starlette_app, - port=8000, - host="127.0.0.1", - ) - ) - else: - webserver = uvicorn.Server( - config=uvicorn.Config( - app=starlette_app, - port=443, - host="0.0.0.0", - ssl_keyfile=config['SSL_KEYFILE'], - ssl_certfile=config['SSL_CERTFILE'] - ) - ) + await bot.set_webhook(webhook_url, secret_token=config['TG_SECRET_TOKEN']) - async with application: - await application.start() - await webserver.serve() - await application.stop() + return application diff --git a/MuoVErsi/persistence.py b/tgbot/persistence.py similarity index 100% rename from MuoVErsi/persistence.py rename to tgbot/persistence.py diff --git a/tgbot/routes.py b/tgbot/routes.py new file mode 100644 index 0000000..3d02409 --- /dev/null +++ b/tgbot/routes.py @@ -0,0 +1,22 @@ +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route +from telegram import Update + +from config import config + + +def get_routes(application): + async def telegram(request: Request) -> Response: + if request.headers['X-Telegram-Bot-Api-Secret-Token'] != config['TG_SECRET_TOKEN']: + return Response(status_code=403) + await application.update_queue.put( + Update.de_json(data=await request.json(), bot=application.bot) + ) + return Response() + + routes = [ + Route("/tg_bot_webhook", telegram, methods=["POST"]) + ] + + return routes diff --git a/MuoVErsi/stop_times_filter.py b/tgbot/stop_times_filter.py similarity index 99% rename from MuoVErsi/stop_times_filter.py rename to tgbot/stop_times_filter.py index c192f55..73da213 100644 --- a/MuoVErsi/stop_times_filter.py +++ b/tgbot/stop_times_filter.py @@ -5,7 +5,7 @@ from telegram import InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes -from MuoVErsi.base import Source, Liner, Station +from server.base import Source, Liner, Station logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO