Skip to content

Commit

Permalink
Use Source upload_trip_stop_times_to_postgres and TripStopTime for Tr…
Browse files Browse the repository at this point in the history
…enitalia
  • Loading branch information
gsarrco committed Oct 1, 2023
1 parent bdf11e9 commit 94b5ad4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 51 deletions.
46 changes: 46 additions & 0 deletions MuoVErsi/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
class Liner:
def format(self, number, _, source_name):
raise NotImplementedError





class BaseStopTime(Liner):
Expand Down Expand Up @@ -177,6 +180,24 @@ class Stop(Base):
stop_times = relationship('StopTime', back_populates='stop', cascade='all, delete-orphan')


class TripStopTime(BaseStopTime):
def __init__(self, stop: Station, origin_id, dep_time: datetime | None, stop_sequence, delay: int, platform,
headsign,
trip_id,
route_name,
arr_time: datetime = None,
origin_dep_time: int = None, destination: str = None):
if arr_time is None:
arr_time = dep_time
super().__init__(stop, dep_time, arr_time, stop_sequence, delay, platform, headsign, trip_id, route_name)
self.origin_dep_time = origin_dep_time
self.destination = destination
self.origin_id = origin_id

def merge(self, arr_stop_time: 'TripStopTime'):
self.arr_time = arr_stop_time.arr_time


class Source:
LIMIT = 7
MINUTES_TOLERANCE = 3
Expand Down Expand Up @@ -316,6 +337,31 @@ def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]:

def get_source_stations(self) -> list[Station]:
return self.session.scalars(select(Station).filter_by(source=self.name)).all()

def upload_trip_stop_times_to_postgres(self, stop_times: list[TripStopTime]):
for stop_time in stop_times:
train = self.session.query(Trip).filter_by(orig_id=stop_time.origin_id,
number=stop_time.trip_id,
orig_dep_date=stop_time.origin_dep_time).first()

if not train:
train = Trip(orig_id=stop_time.origin_id, dest_text=stop_time.destination,
number=stop_time.trip_id, orig_dep_date=stop_time.origin_dep_time,
route_name=stop_time.route_name)
self.session.add(train)
self.session.commit()

stop_time_db = self.session.query(StopTime).filter_by(trip_id=train.id, stop_id=stop_time.stop.id).first()

if stop_time_db:
if stop_time_db.platform != stop_time.platform:
stop_time_db.platform = stop_time.platform
self.session.commit()
else:
new_stop_time = StopTime(trip_id=train.id, stop_id=stop_time.stop.id, sched_arr_dt=stop_time.arr_time,
sched_dep_dt=stop_time.dep_time, platform=stop_time.platform)
self.session.add(new_stop_time)
self.session.commit()


class Trip(Base):
Expand Down
62 changes: 11 additions & 51 deletions MuoVErsi/sources/trenitalia.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,14 @@
from telegram.ext import ContextTypes
from tqdm import tqdm

from MuoVErsi.sources.base import Source, BaseStopTime, Route, Direction, Station, Stop, StopTime, Trip
from MuoVErsi.sources.base import *

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


class TrenitaliaStopTime(BaseStopTime):
def __init__(self, stop: Station, origin_id, dep_time: datetime | None, stop_sequence, delay: int, platform,
headsign,
trip_id,
route_name,
arr_time: datetime = None,
origin_dep_time: int = None, destination: str = None):
if arr_time is None:
arr_time = dep_time
super().__init__(stop, dep_time, arr_time, stop_sequence, delay, platform, headsign, trip_id, route_name)
self.origin_dep_time = origin_dep_time
self.destination = destination
self.origin_id = origin_id

def merge(self, arr_stop_time: 'TrenitaliaStopTime'):
self.arr_time = arr_stop_time.arr_time


class TrenitaliaRoute(Route):
pass

Expand Down Expand Up @@ -76,35 +58,13 @@ def save_trains(self):
stop_times = self.get_stop_times_from_station(station)
total_times_count += len(stop_times)
times_count.append(len(stop_times))
for stop_time in stop_times:
train = self.session.query(Trip).filter_by(orig_id=stop_time.origin_id,
number=stop_time.trip_id,
orig_dep_date=stop_time.origin_dep_time).first()

if not train:
train = Trip(orig_id=stop_time.origin_id, dest_text=stop_time.destination,
number=stop_time.trip_id, orig_dep_date=stop_time.origin_dep_time,
route_name=stop_time.route_name)
self.session.add(train)
self.session.commit()

stop_time_db = self.session.query(StopTime).filter_by(trip_id=train.id, stop_id=station.id).first()

if stop_time_db:
if stop_time_db.platform != stop_time.platform:
stop_time_db.platform = stop_time.platform
self.session.commit()
else:
new_stop_time = StopTime(trip_id=train.id, stop_id=station.id, sched_arr_dt=stop_time.arr_time,
sched_dep_dt=stop_time.dep_time, platform=stop_time.platform)
self.session.add(new_stop_time)
self.session.commit()
self.upload_trip_stop_times_to_postgres(stop_times)

for i, station in enumerate(stations):
station.times_count = round(times_count[i] / total_times_count, int(math.log10(total_times_count)) + 1)
self.sync_stations_db(stations)

def get_stop_times_from_station(self, station) -> list[TrenitaliaStopTime]:
def get_stop_times_from_station(self, station) -> list[TripStopTime]:
now = datetime.now()
departures = self.loop_get_times(10000, station, now, type='partenze')
arrivals = self.loop_get_times(10000, station, now, type='arrivi')
Expand Down Expand Up @@ -197,16 +157,16 @@ def get_stop_times(self, stop: Station, line, start_time, day,
for raw_stop_time in raw_stop_times:
dep_time = raw_stop_time.dep_time
arr_time = raw_stop_time.arr_time
stop_time = TrenitaliaStopTime(stop, raw_stop_time.origin_id, dep_time, None, 0, raw_stop_time.platform,
stop_time = TripStopTime(stop, raw_stop_time.origin_id, dep_time, None, 0, raw_stop_time.platform,
raw_stop_time.destination, raw_stop_time.trip_id,
raw_stop_time.route_name, arr_time=arr_time,
origin_dep_time=raw_stop_time.origin_dep_time)
stop_times.append(stop_time)

return stop_times

def loop_get_times(self, limit, stop: Station, dt, train_ids=None, type='partenze') -> list[TrenitaliaStopTime]:
results: list[TrenitaliaStopTime] = []
def loop_get_times(self, limit, stop: Station, dt, train_ids=None, type='partenze') -> list[TripStopTime]:
results: list[TripStopTime] = []

notimes = 0

Expand Down Expand Up @@ -242,7 +202,7 @@ def loop_get_times(self, limit, stop: Station, dt, train_ids=None, type='partenz
return results[:limit]

def get_stop_times_from_start_dt(self, type, stop: Station, start_dt: datetime, train_ids: list[int] | None) -> \
list[TrenitaliaStopTime]:
list[TripStopTime]:
is_dst = start_dt.astimezone().dst() != timedelta(0)
date = (start_dt - timedelta(hours=(1 if is_dst else 0))).strftime("%a %b %d %Y %H:%M:%S GMT+0100")
url = f'http://www.viaggiatreno.it/infomobilita/resteasy/viaggiatreno/{type}/{stop.id}/{quote(date)}'
Expand Down Expand Up @@ -297,7 +257,7 @@ def get_stop_times_from_start_dt(self, type, stop: Station, start_dt: datetime,
origin_id = departure['codOrigine']
destination = departure.get('destinazione')
route_name = 'RV' if 3000 <= trip_id < 4000 else 'R'
stop_time = TrenitaliaStopTime(stop, origin_id, dep_time, stop_sequence, delay, platform, headsign, trip_id,
stop_time = TripStopTime(stop, origin_id, dep_time, stop_sequence, delay, platform, headsign, trip_id,
route_name,
arr_time=arr_time, origin_dep_time=origin_dep_time, destination=destination)
stop_times.append(stop_time)
Expand Down Expand Up @@ -378,12 +338,12 @@ def get_stop_times_between_stops(self, dep_stop: Station, arr_stop: Station, lin
d_arr_time = raw_stop_time.d_arr_time
a_dep_time = raw_stop_time.a_dep_time
a_arr_time = raw_stop_time.a_arr_time
d_stop_time = TrenitaliaStopTime(
d_stop_time = TripStopTime(
dep_stop, raw_stop_time.origin_id, d_dep_time, None, 0, raw_stop_time.d_platform,
raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name,
arr_time=d_arr_time, origin_dep_time=raw_stop_time.origin_dep_time)

a_stop_time = TrenitaliaStopTime(
a_stop_time = TripStopTime(
arr_stop, raw_stop_time.origin_id, a_dep_time, None, 0, raw_stop_time.a_platform,
raw_stop_time.destination, raw_stop_time.trip_id, raw_stop_time.route_name,
arr_time=a_arr_time, origin_dep_time=raw_stop_time.origin_dep_time)
Expand All @@ -408,7 +368,7 @@ def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]:

stop_times = []
for result in results:
stop_time = TrenitaliaStopTime(result.Station, result.Trip.orig_id, result.StopTime.sched_dep_dt,
stop_time = TripStopTime(result.Station, result.Trip.orig_id, result.StopTime.sched_dep_dt,
None, 0,
result.StopTime.platform, result.Trip.dest_text, trip_id,
result.Trip.route_name,
Expand Down

0 comments on commit 94b5ad4

Please sign in to comment.