diff --git a/MuoVErsi/sources/GTFS/__init__.py b/MuoVErsi/GTFS/__init__.py similarity index 100% rename from MuoVErsi/sources/GTFS/__init__.py rename to MuoVErsi/GTFS/__init__.py diff --git a/MuoVErsi/sources/GTFS/clustering.py b/MuoVErsi/GTFS/clustering.py similarity index 100% rename from MuoVErsi/sources/GTFS/clustering.py rename to MuoVErsi/GTFS/clustering.py diff --git a/MuoVErsi/sources/GTFS/models.py b/MuoVErsi/GTFS/models.py similarity index 100% rename from MuoVErsi/sources/GTFS/models.py rename to MuoVErsi/GTFS/models.py diff --git a/MuoVErsi/sources/GTFS/source.py b/MuoVErsi/GTFS/source.py similarity index 97% rename from MuoVErsi/sources/GTFS/source.py rename to MuoVErsi/GTFS/source.py index 5a7946c..328f346 100644 --- a/MuoVErsi/sources/GTFS/source.py +++ b/MuoVErsi/GTFS/source.py @@ -12,14 +12,13 @@ import requests from bs4 import BeautifulSoup -from telegram.ext import ContextTypes from tqdm import tqdm -from MuoVErsi.sources.base import Source, BaseStopTime, Route, Direction, Station, Stop, TripStopTime +from MuoVErsi.base import Source, Station, Stop, TripStopTime from .clustering import get_clusters_of_stops, get_loc_from_stop_and_cluster from .models import CStop -from sqlalchemy import or_, select, func +from sqlalchemy import select, func logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -79,7 +78,7 @@ def __init__(self, transport_type, emoji, session, typesense, gtfs_versions_rang def file_path(self, ext, gtfs_version): current_dir = os.path.abspath(os.path.dirname(__file__)) - parent_dir = os.path.abspath(current_dir + f"/../../../{self.location}") + parent_dir = os.path.abspath(current_dir + f"/../../{self.location}") return os.path.join(parent_dir, f'{self.transport_type}_{gtfs_version}.{ext}') @@ -311,7 +310,7 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim def search_lines(self, name): today = date.today() - from MuoVErsi.sources.base import Trip + from MuoVErsi.base import Trip trips = self.session.execute( select(func.max(Trip.number), Trip.dest_text)\ .filter(Trip.orig_dep_date == today)\ diff --git a/MuoVErsi/base/__init__.py b/MuoVErsi/base/__init__.py new file mode 100644 index 0000000..3caf421 --- /dev/null +++ b/MuoVErsi/base/__init__.py @@ -0,0 +1 @@ +from .source import * diff --git a/MuoVErsi/base/models.py b/MuoVErsi/base/models.py new file mode 100644 index 0000000..475fa49 --- /dev/null +++ b/MuoVErsi/base/models.py @@ -0,0 +1,63 @@ +from datetime import date, datetime +from typing import Optional + +from sqlalchemy import ForeignKey, UniqueConstraint +from sqlalchemy.orm import declarative_base, Mapped, mapped_column, relationship + +Base = declarative_base() + + +class Station(Base): + __tablename__ = 'stations' + + id: Mapped[str] = mapped_column(primary_key=True) + name: Mapped[str] + lat: Mapped[Optional[float]] + lon: Mapped[Optional[float]] + ids: Mapped[str] = mapped_column(server_default='') + times_count: Mapped[float] = mapped_column(server_default='0') + source: Mapped[str] = mapped_column(server_default='treni') + stops = relationship('Stop', back_populates='station', cascade='all, delete-orphan') + + +class Stop(Base): + __tablename__ = 'stops' + + id: Mapped[str] = mapped_column(primary_key=True) + platform: Mapped[Optional[str]] + lat: Mapped[float] + lon: Mapped[float] + station_id: Mapped[str] = mapped_column(ForeignKey('stations.id')) + station: Mapped[Station] = relationship('Station', back_populates='stops') + source: Mapped[Optional[str]] + stop_times = relationship('StopTime', back_populates='stop', cascade='all, delete-orphan') + + +class Trip(Base): + __tablename__ = 'trips' + + id: Mapped[int] = mapped_column(primary_key=True) + orig_id: Mapped[str] + dest_text: Mapped[str] + number: Mapped[int] + orig_dep_date: Mapped[date] + route_name: Mapped[str] + source: Mapped[str] = mapped_column(server_default='treni') + stop_times = relationship('StopTime', back_populates='trip', cascade='all, delete-orphan', passive_deletes=True) + + __table_args__ = (UniqueConstraint('source', 'number', 'orig_dep_date'),) + + +class StopTime(Base): + __tablename__ = 'stop_times' + + id: Mapped[int] = mapped_column(primary_key=True) + trip_id: Mapped[int] = mapped_column(ForeignKey('trips.id', ondelete='CASCADE')) + trip: Mapped[Trip] = relationship('Trip', back_populates='stop_times') + stop_id: Mapped[str] = mapped_column(ForeignKey('stops.id')) + stop: Mapped[Stop] = relationship('Stop', back_populates='stop_times') + sched_arr_dt: Mapped[Optional[datetime]] + sched_dep_dt: Mapped[Optional[datetime]] + platform: Mapped[Optional[str]] + + __table_args__ = (UniqueConstraint('trip_id', 'stop_id'),) diff --git a/MuoVErsi/sources/base.py b/MuoVErsi/base/source.py similarity index 89% rename from MuoVErsi/sources/base.py rename to MuoVErsi/base/source.py index 0f06397..725d29c 100644 --- a/MuoVErsi/sources/base.py +++ b/MuoVErsi/base/source.py @@ -1,12 +1,13 @@ import logging from datetime import datetime, date, timedelta, time -from typing import Optional -from sqlalchemy import select, ForeignKey, UniqueConstraint, func, and_ +from sqlalchemy import select, func, and_ from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.orm import Mapped, mapped_column, relationship, declarative_base, aliased +from sqlalchemy.orm import aliased from telegram.ext import ContextTypes +from .models import Station, Stop, Trip, StopTime + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -151,34 +152,6 @@ def format(self, number, _, source_name): return text -Base = declarative_base() - - -class Station(Base): - __tablename__ = 'stations' - - id: Mapped[str] = mapped_column(primary_key=True) - name: Mapped[str] - lat: Mapped[Optional[float]] - lon: Mapped[Optional[float]] - ids: Mapped[str] = mapped_column(server_default='') - times_count: Mapped[float] = mapped_column(server_default='0') - source: Mapped[str] = mapped_column(server_default='treni') - stops = relationship('Stop', back_populates='station', cascade='all, delete-orphan') - - -class Stop(Base): - __tablename__ = 'stops' - - id: Mapped[str] = mapped_column(primary_key=True) - platform: Mapped[Optional[str]] - lat: Mapped[float] - lon: Mapped[float] - station_id: Mapped[str] = mapped_column(ForeignKey('stations.id')) - station: Mapped[Station] = relationship('Station', back_populates='stops') - source: Mapped[Optional[str]] - stop_times = relationship('StopTime', back_populates='stop', cascade='all, delete-orphan') - class TripStopTime(BaseStopTime): def __init__(self, station: Station, origin_id, dep_time: datetime | None, stop_sequence, delay: int, platform, @@ -393,7 +366,7 @@ def get_stop_times_between_stops(self, dep_station: Station, arr_station: Statio raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name, arr_time=a_arr_time, orig_dep_date=raw_stop_time.orig_dep_date) - from MuoVErsi.sources.trenitalia import TrenitaliaRoute + from MuoVErsi.trenitalia import TrenitaliaRoute route = TrenitaliaRoute(d_stop_time, a_stop_time) directions.append(Direction([route])) @@ -537,31 +510,3 @@ def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]: return stop_times -class Trip(Base): - __tablename__ = 'trips' - - id: Mapped[int] = mapped_column(primary_key=True) - orig_id: Mapped[str] - dest_text: Mapped[str] - number: Mapped[int] - orig_dep_date: Mapped[date] - route_name: Mapped[str] - source: Mapped[str] = mapped_column(server_default='treni') - stop_times = relationship('StopTime', back_populates='trip', cascade='all, delete-orphan', passive_deletes=True) - - __table_args__ = (UniqueConstraint('source', 'number', 'orig_dep_date'),) - - -class StopTime(Base): - __tablename__ = 'stop_times' - - id: Mapped[int] = mapped_column(primary_key=True) - trip_id: Mapped[int] = mapped_column(ForeignKey('trips.id', ondelete='CASCADE')) - trip: Mapped[Trip] = relationship('Trip', back_populates='stop_times') - stop_id: Mapped[str] = mapped_column(ForeignKey('stops.id')) - stop: Mapped[Stop] = relationship('Stop', back_populates='stop_times') - sched_arr_dt: Mapped[Optional[datetime]] - sched_dep_dt: Mapped[Optional[datetime]] - platform: Mapped[Optional[str]] - - __table_args__ = (UniqueConstraint('trip_id', 'stop_id'),) diff --git a/MuoVErsi/sources/data/trenitalia_stations.json b/MuoVErsi/data/trenitalia_stations.json similarity index 100% rename from MuoVErsi/sources/data/trenitalia_stations.json rename to MuoVErsi/data/trenitalia_stations.json diff --git a/MuoVErsi/handlers.py b/MuoVErsi/handlers.py index 41b4611..d66f7d8 100644 --- a/MuoVErsi/handlers.py +++ b/MuoVErsi/handlers.py @@ -25,9 +25,9 @@ filters, CallbackQueryHandler, ) from .persistence import SQLitePersistence -from .sources.GTFS import GTFS -from .sources.base import Source -from .sources.trenitalia import Trenitalia +from .GTFS import GTFS +from .base import Source +from .trenitalia import Trenitalia from .stop_times_filter import StopTimesFilter from .typesense import connect_to_typesense diff --git a/MuoVErsi/stop_times_filter.py b/MuoVErsi/stop_times_filter.py index 3ba00b3..c192f55 100644 --- a/MuoVErsi/stop_times_filter.py +++ b/MuoVErsi/stop_times_filter.py @@ -5,7 +5,7 @@ from telegram import InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes -from MuoVErsi.sources.base import Source, Liner, Station +from MuoVErsi.base import Source, Liner, Station logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO diff --git a/MuoVErsi/trenitalia/__init__.py b/MuoVErsi/trenitalia/__init__.py new file mode 100644 index 0000000..fccc2dc --- /dev/null +++ b/MuoVErsi/trenitalia/__init__.py @@ -0,0 +1 @@ +from .source import * \ No newline at end of file diff --git a/MuoVErsi/sources/trenitalia.py b/MuoVErsi/trenitalia/source.py similarity index 97% rename from MuoVErsi/sources/trenitalia.py rename to MuoVErsi/trenitalia/source.py index 2c16bb6..cfb7c32 100644 --- a/MuoVErsi/sources/trenitalia.py +++ b/MuoVErsi/trenitalia/source.py @@ -1,16 +1,12 @@ import json -import logging import math import os -from datetime import datetime, timedelta, date -from urllib.parse import quote from pytz import timezone import requests -from sqlalchemy import and_, select from tqdm import tqdm -from MuoVErsi.sources.base import * +from MuoVErsi.base import * logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO diff --git a/alembic/env.py b/alembic/env.py index f76b1d8..2c4805a 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,8 +4,8 @@ from sqlalchemy import engine_from_config from sqlalchemy import pool +from MuoVErsi.base.models import Base from MuoVErsi.handlers import engine_url -from MuoVErsi.sources.base import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/alembic/versions/6c9ef3a680e3_create_stops_table.py b/alembic/versions/6c9ef3a680e3_create_stops_table.py index 1dfd61a..3fec4a7 100644 --- a/alembic/versions/6c9ef3a680e3_create_stops_table.py +++ b/alembic/versions/6c9ef3a680e3_create_stops_table.py @@ -9,7 +9,7 @@ from alembic import op from sqlalchemy.orm import sessionmaker -from MuoVErsi.sources.base import Station +from MuoVErsi.base import Station # revision identifiers, used by Alembic. revision = '6c9ef3a680e3' diff --git a/save_data.py b/save_data.py index 19c7176..9886c9c 100644 --- a/save_data.py +++ b/save_data.py @@ -3,8 +3,8 @@ from sqlalchemy.orm import sessionmaker from MuoVErsi.handlers import engine -from MuoVErsi.sources.trenitalia import Trenitalia -from MuoVErsi.sources.GTFS import GTFS +from MuoVErsi.trenitalia import Trenitalia +from MuoVErsi.GTFS import GTFS from MuoVErsi.typesense import connect_to_typesense logging.basicConfig( diff --git a/tests/test_db.py b/tests/test_db.py index fbb521c..95e622d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,7 +1,7 @@ from datetime import date, datetime, time import pytest -from MuoVErsi.sources.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop +from MuoVErsi.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop @pytest.fixture diff --git a/tests/test_gtfs_clustering.py b/tests/test_gtfs_clustering.py index e14c108..343d9c5 100644 --- a/tests/test_gtfs_clustering.py +++ b/tests/test_gtfs_clustering.py @@ -1,6 +1,6 @@ import pytest -from MuoVErsi.sources.GTFS.clustering import get_root_from_stop_name, get_loc_from_stop_and_cluster +from MuoVErsi.GTFS.clustering import get_root_from_stop_name, get_loc_from_stop_and_cluster @pytest.fixture