diff --git a/MuoVErsi/sources/base.py b/MuoVErsi/sources/base.py index 0d03cdf..e50a88b 100644 --- a/MuoVErsi/sources/base.py +++ b/MuoVErsi/sources/base.py @@ -16,6 +16,9 @@ class Liner: def format(self, number, _, source_name): raise NotImplementedError + + + class BaseStopTime(Liner): @@ -177,6 +180,24 @@ class Stop(Base): stop_times = relationship('StopTime', back_populates='stop', cascade='all, delete-orphan') +class TripStopTime(BaseStopTime): + def __init__(self, stop: Station, origin_id, dep_time: datetime | None, stop_sequence, delay: int, platform, + headsign, + trip_id, + route_name, + arr_time: datetime = None, + origin_dep_time: int = None, destination: str = None): + if arr_time is None: + arr_time = dep_time + super().__init__(stop, dep_time, arr_time, stop_sequence, delay, platform, headsign, trip_id, route_name) + self.origin_dep_time = origin_dep_time + self.destination = destination + self.origin_id = origin_id + + def merge(self, arr_stop_time: 'TripStopTime'): + self.arr_time = arr_stop_time.arr_time + + class Source: LIMIT = 7 MINUTES_TOLERANCE = 3 @@ -316,6 +337,31 @@ def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]: def get_source_stations(self) -> list[Station]: return self.session.scalars(select(Station).filter_by(source=self.name)).all() + + def upload_trip_stop_times_to_postgres(self, stop_times: list[TripStopTime]): + for stop_time in stop_times: + train = self.session.query(Trip).filter_by(orig_id=stop_time.origin_id, + number=stop_time.trip_id, + orig_dep_date=stop_time.origin_dep_time).first() + + if not train: + train = Trip(orig_id=stop_time.origin_id, dest_text=stop_time.destination, + number=stop_time.trip_id, orig_dep_date=stop_time.origin_dep_time, + route_name=stop_time.route_name) + self.session.add(train) + self.session.commit() + + stop_time_db = self.session.query(StopTime).filter_by(trip_id=train.id, stop_id=stop_time.stop.id).first() + + if stop_time_db: + if stop_time_db.platform != stop_time.platform: + stop_time_db.platform = stop_time.platform + self.session.commit() + else: + new_stop_time = StopTime(trip_id=train.id, stop_id=stop_time.stop.id, sched_arr_dt=stop_time.arr_time, + sched_dep_dt=stop_time.dep_time, platform=stop_time.platform) + self.session.add(new_stop_time) + self.session.commit() class Trip(Base): diff --git a/MuoVErsi/sources/trenitalia.py b/MuoVErsi/sources/trenitalia.py index 8b93812..2e88769 100644 --- a/MuoVErsi/sources/trenitalia.py +++ b/MuoVErsi/sources/trenitalia.py @@ -11,7 +11,7 @@ from telegram.ext import ContextTypes from tqdm import tqdm -from MuoVErsi.sources.base import Source, BaseStopTime, Route, Direction, Station, Stop, StopTime, Trip +from MuoVErsi.sources.base import * logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -19,24 +19,6 @@ logger = logging.getLogger(__name__) -class TrenitaliaStopTime(BaseStopTime): - def __init__(self, stop: Station, origin_id, dep_time: datetime | None, stop_sequence, delay: int, platform, - headsign, - trip_id, - route_name, - arr_time: datetime = None, - origin_dep_time: int = None, destination: str = None): - if arr_time is None: - arr_time = dep_time - super().__init__(stop, dep_time, arr_time, stop_sequence, delay, platform, headsign, trip_id, route_name) - self.origin_dep_time = origin_dep_time - self.destination = destination - self.origin_id = origin_id - - def merge(self, arr_stop_time: 'TrenitaliaStopTime'): - self.arr_time = arr_stop_time.arr_time - - class TrenitaliaRoute(Route): pass @@ -76,35 +58,13 @@ def save_trains(self): stop_times = self.get_stop_times_from_station(station) total_times_count += len(stop_times) times_count.append(len(stop_times)) - for stop_time in stop_times: - train = self.session.query(Trip).filter_by(orig_id=stop_time.origin_id, - number=stop_time.trip_id, - orig_dep_date=stop_time.origin_dep_time).first() - - if not train: - train = Trip(orig_id=stop_time.origin_id, dest_text=stop_time.destination, - number=stop_time.trip_id, orig_dep_date=stop_time.origin_dep_time, - route_name=stop_time.route_name) - self.session.add(train) - self.session.commit() - - stop_time_db = self.session.query(StopTime).filter_by(trip_id=train.id, stop_id=station.id).first() - - if stop_time_db: - if stop_time_db.platform != stop_time.platform: - stop_time_db.platform = stop_time.platform - self.session.commit() - else: - new_stop_time = StopTime(trip_id=train.id, stop_id=station.id, sched_arr_dt=stop_time.arr_time, - sched_dep_dt=stop_time.dep_time, platform=stop_time.platform) - self.session.add(new_stop_time) - self.session.commit() + self.upload_trip_stop_times_to_postgres(stop_times) for i, station in enumerate(stations): station.times_count = round(times_count[i] / total_times_count, int(math.log10(total_times_count)) + 1) self.sync_stations_db(stations) - def get_stop_times_from_station(self, station) -> list[TrenitaliaStopTime]: + def get_stop_times_from_station(self, station) -> list[TripStopTime]: now = datetime.now() departures = self.loop_get_times(10000, station, now, type='partenze') arrivals = self.loop_get_times(10000, station, now, type='arrivi') @@ -197,7 +157,7 @@ def get_stop_times(self, stop: Station, line, start_time, day, for raw_stop_time in raw_stop_times: dep_time = raw_stop_time.dep_time arr_time = raw_stop_time.arr_time - stop_time = TrenitaliaStopTime(stop, raw_stop_time.origin_id, dep_time, None, 0, raw_stop_time.platform, + stop_time = TripStopTime(stop, raw_stop_time.origin_id, dep_time, None, 0, raw_stop_time.platform, raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name, arr_time=arr_time, origin_dep_time=raw_stop_time.origin_dep_time) @@ -205,8 +165,8 @@ def get_stop_times(self, stop: Station, line, start_time, day, return stop_times - def loop_get_times(self, limit, stop: Station, dt, train_ids=None, type='partenze') -> list[TrenitaliaStopTime]: - results: list[TrenitaliaStopTime] = [] + def loop_get_times(self, limit, stop: Station, dt, train_ids=None, type='partenze') -> list[TripStopTime]: + results: list[TripStopTime] = [] notimes = 0 @@ -242,7 +202,7 @@ def loop_get_times(self, limit, stop: Station, dt, train_ids=None, type='partenz return results[:limit] def get_stop_times_from_start_dt(self, type, stop: Station, start_dt: datetime, train_ids: list[int] | None) -> \ - list[TrenitaliaStopTime]: + list[TripStopTime]: is_dst = start_dt.astimezone().dst() != timedelta(0) date = (start_dt - timedelta(hours=(1 if is_dst else 0))).strftime("%a %b %d %Y %H:%M:%S GMT+0100") url = f'http://www.viaggiatreno.it/infomobilita/resteasy/viaggiatreno/{type}/{stop.id}/{quote(date)}' @@ -297,7 +257,7 @@ def get_stop_times_from_start_dt(self, type, stop: Station, start_dt: datetime, origin_id = departure['codOrigine'] destination = departure.get('destinazione') route_name = 'RV' if 3000 <= trip_id < 4000 else 'R' - stop_time = TrenitaliaStopTime(stop, origin_id, dep_time, stop_sequence, delay, platform, headsign, trip_id, + stop_time = TripStopTime(stop, origin_id, dep_time, stop_sequence, delay, platform, headsign, trip_id, route_name, arr_time=arr_time, origin_dep_time=origin_dep_time, destination=destination) stop_times.append(stop_time) @@ -378,12 +338,12 @@ def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, lin d_arr_time = raw_stop_time.d_arr_time a_dep_time = raw_stop_time.a_dep_time a_arr_time = raw_stop_time.a_arr_time - d_stop_time = TrenitaliaStopTime( + d_stop_time = TripStopTime( dep_stop, raw_stop_time.origin_id, d_dep_time, None, 0, raw_stop_time.d_platform, raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name, arr_time=d_arr_time, origin_dep_time=raw_stop_time.origin_dep_time) - a_stop_time = TrenitaliaStopTime( + a_stop_time = TripStopTime( arr_stop, raw_stop_time.origin_id, a_dep_time, None, 0, raw_stop_time.a_platform, raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name, arr_time=a_arr_time, origin_dep_time=raw_stop_time.origin_dep_time) @@ -408,7 +368,7 @@ def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]: stop_times = [] for result in results: - stop_time = TrenitaliaStopTime(result.Station, result.Trip.orig_id, result.StopTime.sched_dep_dt, + stop_time = TripStopTime(result.Station, result.Trip.orig_id, result.StopTime.sched_dep_dt, None, 0, result.StopTime.platform, result.Trip.dest_text, trip_id, result.Trip.route_name,