Skip to content

Commit

Permalink
refactor: better organize files (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarrco authored Oct 29, 2023
2 parents 86484a5 + 6959694 commit 8f8f404
Show file tree
Hide file tree
Showing 17 changed files with 85 additions and 80 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
9 changes: 4 additions & 5 deletions MuoVErsi/sources/GTFS/source.py → MuoVErsi/GTFS/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@

import requests
from bs4 import BeautifulSoup
from telegram.ext import ContextTypes
from tqdm import tqdm

from MuoVErsi.sources.base import Source, BaseStopTime, Route, Direction, Station, Stop, TripStopTime
from MuoVErsi.base import Source, Station, Stop, TripStopTime
from .clustering import get_clusters_of_stops, get_loc_from_stop_and_cluster
from .models import CStop

from sqlalchemy import or_, select, func
from sqlalchemy import select, func

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
Expand Down Expand Up @@ -79,7 +78,7 @@ def __init__(self, transport_type, emoji, session, typesense, gtfs_versions_rang

def file_path(self, ext, gtfs_version):
current_dir = os.path.abspath(os.path.dirname(__file__))
parent_dir = os.path.abspath(current_dir + f"/../../../{self.location}")
parent_dir = os.path.abspath(current_dir + f"/../../{self.location}")

return os.path.join(parent_dir, f'{self.transport_type}_{gtfs_version}.{ext}')

Expand Down Expand Up @@ -311,7 +310,7 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim

def search_lines(self, name):
today = date.today()
from MuoVErsi.sources.base import Trip
from MuoVErsi.base import Trip
trips = self.session.execute(
select(func.max(Trip.number), Trip.dest_text)\
.filter(Trip.orig_dep_date == today)\
Expand Down
1 change: 1 addition & 0 deletions MuoVErsi/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .source import *
63 changes: 63 additions & 0 deletions MuoVErsi/base/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from datetime import date, datetime
from typing import Optional

from sqlalchemy import ForeignKey, UniqueConstraint
from sqlalchemy.orm import declarative_base, Mapped, mapped_column, relationship

Base = declarative_base()


class Station(Base):
__tablename__ = 'stations'

id: Mapped[str] = mapped_column(primary_key=True)
name: Mapped[str]
lat: Mapped[Optional[float]]
lon: Mapped[Optional[float]]
ids: Mapped[str] = mapped_column(server_default='')
times_count: Mapped[float] = mapped_column(server_default='0')
source: Mapped[str] = mapped_column(server_default='treni')
stops = relationship('Stop', back_populates='station', cascade='all, delete-orphan')


class Stop(Base):
__tablename__ = 'stops'

id: Mapped[str] = mapped_column(primary_key=True)
platform: Mapped[Optional[str]]
lat: Mapped[float]
lon: Mapped[float]
station_id: Mapped[str] = mapped_column(ForeignKey('stations.id'))
station: Mapped[Station] = relationship('Station', back_populates='stops')
source: Mapped[Optional[str]]
stop_times = relationship('StopTime', back_populates='stop', cascade='all, delete-orphan')


class Trip(Base):
__tablename__ = 'trips'

id: Mapped[int] = mapped_column(primary_key=True)
orig_id: Mapped[str]
dest_text: Mapped[str]
number: Mapped[int]
orig_dep_date: Mapped[date]
route_name: Mapped[str]
source: Mapped[str] = mapped_column(server_default='treni')
stop_times = relationship('StopTime', back_populates='trip', cascade='all, delete-orphan', passive_deletes=True)

__table_args__ = (UniqueConstraint('source', 'number', 'orig_dep_date'),)


class StopTime(Base):
__tablename__ = 'stop_times'

id: Mapped[int] = mapped_column(primary_key=True)
trip_id: Mapped[int] = mapped_column(ForeignKey('trips.id', ondelete='CASCADE'))
trip: Mapped[Trip] = relationship('Trip', back_populates='stop_times')
stop_id: Mapped[str] = mapped_column(ForeignKey('stops.id'))
stop: Mapped[Stop] = relationship('Stop', back_populates='stop_times')
sched_arr_dt: Mapped[Optional[datetime]]
sched_dep_dt: Mapped[Optional[datetime]]
platform: Mapped[Optional[str]]

__table_args__ = (UniqueConstraint('trip_id', 'stop_id'),)
65 changes: 5 additions & 60 deletions MuoVErsi/sources/base.py → MuoVErsi/base/source.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from datetime import datetime, date, timedelta, time
from typing import Optional

from sqlalchemy import select, ForeignKey, UniqueConstraint, func, and_
from sqlalchemy import select, func, and_
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Mapped, mapped_column, relationship, declarative_base, aliased
from sqlalchemy.orm import aliased
from telegram.ext import ContextTypes

from .models import Station, Stop, Trip, StopTime

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
Expand Down Expand Up @@ -151,34 +152,6 @@ def format(self, number, _, source_name):

return text

Base = declarative_base()


class Station(Base):
__tablename__ = 'stations'

id: Mapped[str] = mapped_column(primary_key=True)
name: Mapped[str]
lat: Mapped[Optional[float]]
lon: Mapped[Optional[float]]
ids: Mapped[str] = mapped_column(server_default='')
times_count: Mapped[float] = mapped_column(server_default='0')
source: Mapped[str] = mapped_column(server_default='treni')
stops = relationship('Stop', back_populates='station', cascade='all, delete-orphan')


class Stop(Base):
__tablename__ = 'stops'

id: Mapped[str] = mapped_column(primary_key=True)
platform: Mapped[Optional[str]]
lat: Mapped[float]
lon: Mapped[float]
station_id: Mapped[str] = mapped_column(ForeignKey('stations.id'))
station: Mapped[Station] = relationship('Station', back_populates='stops')
source: Mapped[Optional[str]]
stop_times = relationship('StopTime', back_populates='stop', cascade='all, delete-orphan')


class TripStopTime(BaseStopTime):
def __init__(self, station: Station, origin_id, dep_time: datetime | None, stop_sequence, delay: int, platform,
Expand Down Expand Up @@ -393,7 +366,7 @@ def get_stop_times_between_stops(self, dep_station: Station, arr_station: Statio
raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name,
arr_time=a_arr_time, orig_dep_date=raw_stop_time.orig_dep_date)

from MuoVErsi.sources.trenitalia import TrenitaliaRoute
from MuoVErsi.trenitalia import TrenitaliaRoute
route = TrenitaliaRoute(d_stop_time, a_stop_time)
directions.append(Direction([route]))

Expand Down Expand Up @@ -537,31 +510,3 @@ def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]:
return stop_times


class Trip(Base):
__tablename__ = 'trips'

id: Mapped[int] = mapped_column(primary_key=True)
orig_id: Mapped[str]
dest_text: Mapped[str]
number: Mapped[int]
orig_dep_date: Mapped[date]
route_name: Mapped[str]
source: Mapped[str] = mapped_column(server_default='treni')
stop_times = relationship('StopTime', back_populates='trip', cascade='all, delete-orphan', passive_deletes=True)

__table_args__ = (UniqueConstraint('source', 'number', 'orig_dep_date'),)


class StopTime(Base):
__tablename__ = 'stop_times'

id: Mapped[int] = mapped_column(primary_key=True)
trip_id: Mapped[int] = mapped_column(ForeignKey('trips.id', ondelete='CASCADE'))
trip: Mapped[Trip] = relationship('Trip', back_populates='stop_times')
stop_id: Mapped[str] = mapped_column(ForeignKey('stops.id'))
stop: Mapped[Stop] = relationship('Stop', back_populates='stop_times')
sched_arr_dt: Mapped[Optional[datetime]]
sched_dep_dt: Mapped[Optional[datetime]]
platform: Mapped[Optional[str]]

__table_args__ = (UniqueConstraint('trip_id', 'stop_id'),)
File renamed without changes.
6 changes: 3 additions & 3 deletions MuoVErsi/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
filters, CallbackQueryHandler, )

from .persistence import SQLitePersistence
from .sources.GTFS import GTFS
from .sources.base import Source
from .sources.trenitalia import Trenitalia
from .GTFS import GTFS
from .base import Source
from .trenitalia import Trenitalia
from .stop_times_filter import StopTimesFilter
from .typesense import connect_to_typesense

Expand Down
2 changes: 1 addition & 1 deletion MuoVErsi/stop_times_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from telegram import InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import ContextTypes

from MuoVErsi.sources.base import Source, Liner, Station
from MuoVErsi.base import Source, Liner, Station

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
Expand Down
1 change: 1 addition & 0 deletions MuoVErsi/trenitalia/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .source import *
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import json
import logging
import math
import os
from datetime import datetime, timedelta, date
from urllib.parse import quote
from pytz import timezone

import requests
from sqlalchemy import and_, select
from tqdm import tqdm

from MuoVErsi.sources.base import *
from MuoVErsi.base import *

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
Expand Down
2 changes: 1 addition & 1 deletion alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from sqlalchemy import engine_from_config
from sqlalchemy import pool

from MuoVErsi.base.models import Base
from MuoVErsi.handlers import engine_url
from MuoVErsi.sources.base import Base

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down
2 changes: 1 addition & 1 deletion alembic/versions/6c9ef3a680e3_create_stops_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from alembic import op
from sqlalchemy.orm import sessionmaker

from MuoVErsi.sources.base import Station
from MuoVErsi.base import Station

# revision identifiers, used by Alembic.
revision = '6c9ef3a680e3'
Expand Down
4 changes: 2 additions & 2 deletions save_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from sqlalchemy.orm import sessionmaker

from MuoVErsi.handlers import engine
from MuoVErsi.sources.trenitalia import Trenitalia
from MuoVErsi.sources.GTFS import GTFS
from MuoVErsi.trenitalia import Trenitalia
from MuoVErsi.GTFS import GTFS
from MuoVErsi.typesense import connect_to_typesense

logging.basicConfig(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import date, datetime, time
import pytest

from MuoVErsi.sources.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop
from MuoVErsi.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gtfs_clustering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from MuoVErsi.sources.GTFS.clustering import get_root_from_stop_name, get_loc_from_stop_and_cluster
from MuoVErsi.GTFS.clustering import get_root_from_stop_name, get_loc_from_stop_and_cluster


@pytest.fixture
Expand Down

0 comments on commit 8f8f404

Please sign in to comment.