diff --git a/alembic/env.py b/alembic/env.py
index ba927c1..e57e979 100644
--- a/alembic/env.py
+++ b/alembic/env.py
@@ -5,12 +5,14 @@
 from sqlalchemy import pool
 
 from server.base.models import Base
-from server.sources import engine_url
+from config import config as engine_config
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
 config = context.config
+engine_url = f"postgresql://{engine_config['PGUSER']}:{engine_config['PGPASSWORD']}@{engine_config['PGHOST']}:{engine_config['PGPORT']}/" \
+             f"{engine_config['PGDATABASE']}"
 config.set_main_option('sqlalchemy.url', engine_url)
 
 # Interpret the config file for Python logging.
diff --git a/alembic/versions/3e63dbd74ceb_add_active_fields_to_stations_and_stops.py b/alembic/versions/3e63dbd74ceb_add_active_fields_to_stations_and_stops.py
new file mode 100644
index 0000000..b58cdc3
--- /dev/null
+++ b/alembic/versions/3e63dbd74ceb_add_active_fields_to_stations_and_stops.py
@@ -0,0 +1,26 @@
+"""Add active fields to stations and stops
+
+Revision ID: 3e63dbd74ceb
+Revises: 1f2c7b1eec8b
+Create Date: 2023-11-09 16:37:09.843639
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '3e63dbd74ceb'
+down_revision = '1f2c7b1eec8b'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column('stations', sa.Column('active', sa.Boolean(), server_default='true', nullable=False))
+    op.add_column('stops', sa.Column('active', sa.Boolean(), server_default='true', nullable=False))
+
+
+def downgrade() -> None:
+    op.drop_column('stops', 'active')
+    op.drop_column('stations', 'active')
diff --git a/alembic/versions/7c12f6bfe3c6_merge_trip_into_stoptime.py b/alembic/versions/7c12f6bfe3c6_merge_trip_into_stoptime.py
new file mode 100644
index 0000000..41a3a1d
--- /dev/null
+++ b/alembic/versions/7c12f6bfe3c6_merge_trip_into_stoptime.py
@@ -0,0 +1,80 @@
+"""Merge Trip into StopTime
+
+Revision ID: 7c12f6bfe3c6
+Revises: 3e63dbd74ceb
+Create Date: 2023-11-05 09:16:46.640362
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = '7c12f6bfe3c6'
+down_revision = '3e63dbd74ceb'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.drop_constraint('stop_times_trip_id_fkey', 'stop_times', type_='foreignkey')
+
+    # Create new fields as nullable true temporarily
+    op.add_column('stop_times', sa.Column('orig_id', sa.String(), nullable=True))
+    op.add_column('stop_times', sa.Column('dest_text', sa.String(), nullable=True))
+    op.add_column('stop_times', sa.Column('number', sa.Integer(), nullable=True))
+    op.add_column('stop_times', sa.Column('orig_dep_date', sa.Date(), nullable=True))
+    op.add_column('stop_times', sa.Column('route_name', sa.String(), nullable=True))
+    op.add_column('stop_times', sa.Column('source', sa.String(), server_default='treni', nullable=True))
+
+    # populate new fields with data from trips through stop_times.trip_id
+    op.execute('''
+        UPDATE stop_times
+        SET
+            orig_id = trips.orig_id,
+            dest_text = trips.dest_text,
+            number = trips.number,
+            orig_dep_date = trips.orig_dep_date,
+            route_name = trips.route_name,
+            source = trips.source
+        FROM trips
+        WHERE stop_times.trip_id = trips.id
+    ''')
+
+    # convert new fields to not nullable
+    op.alter_column('stop_times', 'orig_id', nullable=False)
+    op.alter_column('stop_times', 'dest_text', nullable=False)
+    op.alter_column('stop_times', 'number', nullable=False)
+    op.alter_column('stop_times', 'orig_dep_date', nullable=False)
+    op.alter_column('stop_times', 'route_name', nullable=False)
+    op.alter_column('stop_times', 'source', nullable=False)
+
+    # drop trip_id column
+    op.drop_column('stop_times', 'trip_id')
+
+    # drop trips table
+    op.drop_table('trips')
+
+
+def downgrade() -> None:
+    op.create_table('trips',
+                    sa.Column('id', sa.Integer(), nullable=False),
+                    sa.Column('orig_id', sa.String(), autoincrement=False, nullable=False),
+                    sa.Column('dest_text', sa.String(), autoincrement=False, nullable=False),
+                    sa.Column('number', sa.Integer(), autoincrement=False, nullable=False),
+                    sa.Column('orig_dep_date', sa.Date(), autoincrement=False, nullable=False),
+                    sa.Column('route_name', sa.String(), autoincrement=False, nullable=False),
+                    sa.Column('source', sa.String(), server_default='treni', autoincrement=False, nullable=False),
+                    sa.PrimaryKeyConstraint('id'),
+                    sa.UniqueConstraint('source', 'number', 'orig_dep_date',
+                                        name='trips_source_number_orig_dep_date_key')
+                    )
+
+    op.add_column('stop_times', sa.Column('trip_id', sa.INTEGER(), autoincrement=False, nullable=False))
+    op.create_foreign_key('stop_times_trip_id_fkey', 'stop_times', 'trips', ['trip_id'], ['id'], ondelete='CASCADE')
+
+    op.drop_column('stop_times', 'source')
+    op.drop_column('stop_times', 'route_name')
+    op.drop_column('stop_times', 'orig_dep_date')
+    op.drop_column('stop_times', 'number')
+    op.drop_column('stop_times', 'dest_text')
+    op.drop_column('stop_times', 'orig_id')
diff --git a/alembic/versions/d55702afa188_convert_stop_times_to_partitioned_table.py b/alembic/versions/d55702afa188_convert_stop_times_to_partitioned_table.py
new file mode 100644
index 0000000..2cb3289
--- /dev/null
+++ b/alembic/versions/d55702afa188_convert_stop_times_to_partitioned_table.py
@@ -0,0 +1,54 @@
+"""Convert stop_times to partitioned table
+
+Revision ID: d55702afa188
+Revises: 7c12f6bfe3c6
+Create Date: 2023-10-29 16:09:44.815425
+
+"""
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = 'd55702afa188'
+down_revision = '7c12f6bfe3c6'
+branch_labels = None
+depends_on = None
+
+
+# Define the migration
+def upgrade():
+    # remove foreign key "stop_times_stop_id_fkey"
+    op.drop_constraint('stop_times_stop_id_fkey', 'stop_times', type_='foreignkey')
+
+    # rename table "stop_times" to "stop_times_reg"
+    op.rename_table('stop_times', 'stop_times_reg')
+
+    # create the partitioned table "stop_times" for field "orig_dep_date"
+    op.execute("""
+        CREATE TABLE stop_times (
+            id SERIAL NOT NULL,
+            stop_id character varying NOT NULL,
+            sched_arr_dt timestamp without time zone,
+            sched_dep_dt timestamp without time zone,
+            platform character varying,
+            orig_dep_date date NOT NULL,
+            orig_id character varying NOT NULL,
+            dest_text character varying NOT NULL,
+            number integer NOT NULL,
+            route_name character varying NOT NULL,
+            source character varying,
+            CONSTRAINT stop_times_stop_id_fkey FOREIGN KEY (stop_id) REFERENCES stops(id)
+        ) PARTITION BY RANGE (orig_dep_date);
+        CREATE UNIQUE INDEX stop_times_unique_idx ON stop_times(stop_id, number, source, orig_dep_date);
+    """)
+
+
+def downgrade():
+    # drop the partitioned table "stop_times"
+    op.drop_table('stop_times')
+
+    # rename table "stop_times_reg" to "stop_times"
+    op.rename_table('stop_times_reg', 'stop_times')
+
+    # add foreign key "stop_times_stop_id_fkey"
+    op.create_foreign_key('stop_times_stop_id_fkey', 'stop_times', 'stops', ['stop_id'], ['id'])
diff --git a/save_data.py b/save_data.py
index d471b54..d42ed05 100644
--- a/save_data.py
+++ b/save_data.py
@@ -1,9 +1,11 @@
 import logging
+from datetime import date, timedelta
 
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy import inspect
 
 from server.GTFS import GTFS
-from server.sources import engine
+from server.base.models import StopTime
+from server.sources import engine, session
 from server.trenitalia import Trenitalia
 from server.typesense import connect_to_typesense
 
@@ -20,7 +22,6 @@ def run():
     args = parser.parse_args()
     force_update_stations = args.force_update_stations
 
-    session = sessionmaker(bind=engine)()
    typesense = connect_to_typesense()
 
     sources = [
@@ -29,6 +30,24 @@ def run():
         Trenitalia(session, typesense, force_update_stations=force_update_stations),
     ]
 
+    session.commit()
+
+    today = date.today()
+
+    for i in range(3):
+        day: date = today + timedelta(days=i)
+        partition = StopTime.create_partition(day)
+        if not inspect(engine).has_table(partition.__table__.name):
+            partition.__table__.create(bind=engine)
+
+    # detach partitions older than two days, stopping at the first day that has none
+    i = -2
+    while True:
+        day = today + timedelta(days=i)
+        try:
+            session.execute(StopTime.detach_partition(day))
+            session.commit()
+        except Exception:
+            session.rollback()
+            break
+        i -= 1
+
     for source in sources:
         try:
             source.save_data()
diff --git a/server/GTFS/source.py b/server/GTFS/source.py
index a3cfb6f..9e1d0f4 100644
--- a/server/GTFS/source.py
+++ b/server/GTFS/source.py
@@ -15,7 +15,7 @@
 from sqlalchemy import select, func
 from tqdm import tqdm
 
-from server.base import Source, Station, Stop, TripStopTime
+from server.base import Source, Station, Stop, TripStopTime, StopTime
 
 from .clustering import get_clusters_of_stops, get_loc_from_stop_and_cluster
 from .models import CStop
@@ -309,13 +309,12 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim
 
     def search_lines(self, name):
         today = date.today()
-        from server.base import Trip
         trips = self.session.execute(
-            select(func.max(Trip.number), Trip.dest_text)\
-            .filter(Trip.orig_dep_date == today)\
-            .filter(Trip.route_name == name)\
-            .group_by(Trip.dest_text)\
-            .order_by(func.count(Trip.id).desc()))\
+            select(func.max(StopTime.number), StopTime.dest_text) \
+            .filter(StopTime.orig_dep_date == today) \
+            .filter(StopTime.route_name == name) \
+            .group_by(StopTime.dest_text) \
+            .order_by(func.count(StopTime.number).desc())) \
             .all()
 
         results = [(trip[0], name, trip[1]) for trip in trips]
diff --git a/server/base/models.py b/server/base/models.py
index 6798573..3bae60c 100644
--- a/server/base/models.py
+++ b/server/base/models.py
@@ -1,8 +1,10 @@
-from datetime import date, datetime
+from datetime import date, datetime, timedelta
 from typing import Optional
 
-from sqlalchemy import ForeignKey, UniqueConstraint
-from sqlalchemy.orm import declarative_base, Mapped, mapped_column, relationship
+from sqlalchemy import ForeignKey, UniqueConstraint, event
+from sqlalchemy.orm import declarative_base, Mapped, mapped_column, relationship, declared_attr
+from sqlalchemy.ext.declarative import DeclarativeMeta
+from sqlalchemy.sql.ddl import DDL
 
 Base = declarative_base()
 
@@ -18,6 +20,7 @@ class Station(Base):
     times_count: Mapped[float] = mapped_column(server_default='0')
     source: Mapped[str] = mapped_column(server_default='treni')
     stops = relationship('Stop', back_populates='station', cascade='all, delete-orphan')
+    active: Mapped[bool] = mapped_column(server_default='true')
 
     def as_dict(self):
         return {
@@ -28,6 +31,7 @@ def as_dict(self):
             'source': self.source
         }
 
+
 class Stop(Base):
     __tablename__ = 'stops'
 
@@ -38,34 +42,102 @@
     station_id: Mapped[str] = mapped_column(ForeignKey('stations.id'))
     station: Mapped[Station] = relationship('Station', back_populates='stops')
     source: Mapped[Optional[str]]
-    stop_times = relationship('StopTime', back_populates='stop', cascade='all, delete-orphan')
-
-
-class Trip(Base):
-    __tablename__ = 'trips'
-
+    active: Mapped[bool] = mapped_column(server_default='true')
+
+
+class PartitionByOrigDepDateMeta(DeclarativeMeta):
+    def __new__(cls, clsname, bases, attrs, *, partition_by):
+        @classmethod
+        def get_partition_name(cls_, key):
+            return f'{cls_.__tablename__}_{key}'
+
+        @classmethod
+        def create_partition(cls_, day: date):
+            key = day.strftime('%Y%m%d')
+            if key not in cls_.partitions:
+                Partition = type(
+                    f'{clsname}{key}',
+                    bases,
+                    {'__tablename__': cls_.get_partition_name(key)}
+                )
+
+                Partition.__table__.add_is_dependent_on(cls_.__table__)
+
+                day_plus_one = day + timedelta(days=1)
+                event.listen(
+                    Partition.__table__,
+                    'after_create',
+                    DDL(
+                        f"""
+                        ALTER TABLE {cls_.__tablename__}
+                        ATTACH PARTITION {Partition.__tablename__}
+                        FOR VALUES FROM ('{day}') TO ('{day_plus_one}')
+                        """
+                    )
+                )
+
+                cls_.partitions[key] = Partition
+
+            return cls_.partitions[key]
+
+        @classmethod
+        def detach_partition(cls_, day: date):
+            # return the DETACH statement for the caller to execute, so old
+            # partitions can be detached without re-declaring their classes
+            key = day.strftime('%Y%m%d')
+            return DDL(
+                f"""
+                ALTER TABLE {cls_.__tablename__}
+                DETACH PARTITION {cls_.get_partition_name(key)}
+                """
+            )
+
+        attrs.update(
+            {
+                '__table_args__': attrs.get('__table_args__', ()) +
+                                  (dict(postgresql_partition_by=f'RANGE({partition_by})'),),
+                'partitions': {},
+                'partitioned_by': partition_by,
+                'get_partition_name': get_partition_name,
+                'create_partition': create_partition,
+                'detach_partition': detach_partition
+            }
+        )
+
+        return super().__new__(cls, clsname, bases, attrs)
+
+
+class StopTimeMixin:
     id: Mapped[int] = mapped_column(primary_key=True)
+    sched_arr_dt: Mapped[Optional[datetime]]
+    sched_dep_dt: Mapped[Optional[datetime]]
+    orig_dep_date: Mapped[date]
+    platform: Mapped[Optional[str]]
     orig_id: Mapped[str]
     dest_text: Mapped[str]
     number: Mapped[int]
-    orig_dep_date: Mapped[date]
     route_name: Mapped[str]
     source: Mapped[str] = mapped_column(server_default='treni')
-    stop_times = relationship('StopTime', back_populates='trip', cascade='all, delete-orphan', passive_deletes=True)
 
-    __table_args__ = (UniqueConstraint('source', 'number', 'orig_dep_date'),)
+    @declared_attr
+    def stop_id(self) -> Mapped[str]:
+        return mapped_column(ForeignKey('stops.id'))
 
+    @declared_attr
+    def stop(self) -> Mapped[Stop]:
+        return relationship('Stop', foreign_keys=self.stop_id)
 
-class StopTime(Base):
-    __tablename__ = 'stop_times'
 
-    id: Mapped[int] = mapped_column(primary_key=True)
-    trip_id: Mapped[int] = mapped_column(ForeignKey('trips.id', ondelete='CASCADE'))
-    trip: Mapped[Trip] = relationship('Trip', back_populates='stop_times')
-    stop_id: Mapped[str] = mapped_column(ForeignKey('stops.id'))
-    stop: Mapped[Stop] = relationship('Stop', back_populates='stop_times')
-    sched_arr_dt: Mapped[Optional[datetime]]
-    sched_dep_dt: Mapped[Optional[datetime]]
-    platform: Mapped[Optional[str]]
+class StopTime(StopTimeMixin, Base, metaclass=PartitionByOrigDepDateMeta, partition_by='orig_dep_date'):
+    __tablename__ = 'stop_times'
 
-    __table_args__ = (UniqueConstraint('trip_id', 'stop_id'),)
+    __table_args__ = (UniqueConstraint("stop_id", "number", "source", "orig_dep_date"),)
diff --git a/server/base/source.py b/server/base/source.py
index 1204e3a..7fa8eb1 100644
--- a/server/base/source.py
+++ b/server/base/source.py
@@ -6,7 +6,7 @@
 from sqlalchemy.orm import aliased
 from telegram.ext import ContextTypes
 
-from .models import Station, Stop, Trip, StopTime
+from .models import Station, Stop, StopTime
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -17,13 +17,11 @@
 
 
 class Liner:
     def format(self, number, _, source_name):
         raise NotImplementedError
 
-
-
-
 
 class BaseStopTime(Liner):
-    def __init__(self, station: 'Station', dep_time: datetime | None, arr_time: datetime | None, stop_sequence, delay: int,
+    def __init__(self, station: 'Station', dep_time: datetime | None, arr_time: datetime | None, stop_sequence,
+                 delay: int,
                  platform, headsign, trip_id, route_name):
@@ -181,7 +179,8 @@ def __init__(self, name, emoji, session, typesense):
         self.session = session
         self.typesense = typesense
 
-    def search_stops(self, name=None, lat=None, lon=None, page=1, limit=4, all_sources=False) -> tuple[list[Station], int]:
+    def search_stops(self, name=None, lat=None, lon=None, page=1, limit=4, all_sources=False) -> tuple[
+        list[Station], int]:
         search_config = {'per_page': limit, 'query_by': 'name', 'page': page}
 
         limit_hits = None
@@ -213,7 +212,7 @@ def search_stops(self, name=None, lat=None, lon=None, page=1, limit=4, all_sourc
         found = limit_hits if limit_hits else results['found']
 
         return stations, found
-    
+
     def get_stop_times(self, station: Station, line, start_time, day,
                        offset_times, count=False, limit=True):
         day_start = datetime.combine(day, time(0))
@@ -229,25 +228,27 @@ def get_stop_times(self, station: Station, line, start_time, day,
 
         if count:
             raw_stop_times = self.session.query(
-                Trip.route_name.label('route_name')
+                StopTime.route_name.label('route_name')
             )
         else:
             raw_stop_times = self.session.query(
                 StopTime.sched_arr_dt.label('arr_time'),
                 StopTime.sched_dep_dt.label('dep_time'),
-                Trip.orig_id.label('origin_id'),
-                Trip.dest_text.label('destination'),
-                Trip.number.label('trip_id'),
-                Trip.orig_dep_date.label('orig_dep_date'),
+                StopTime.orig_id.label('origin_id'),
+                StopTime.dest_text.label('destination'),
+                StopTime.number.label('trip_id'),
+                StopTime.orig_dep_date.label('orig_dep_date'),
                 StopTime.platform.label('platform'),
-                Trip.route_name.label('route_name')
+                StopTime.route_name.label('route_name')
             )
 
+        day_minus_one = day - timedelta(days=1)
+
         raw_stop_times = raw_stop_times \
             .select_from(StopTime) \
-            .join(Trip, StopTime.trip_id == Trip.id) \
             .filter(
                 and_(
+                    StopTime.orig_dep_date.between(day_minus_one, day),
                    StopTime.stop_id.in_(stops_ids),
                     StopTime.sched_dep_dt >= start_dt,
                     StopTime.sched_dep_dt < end_dt
@@ -255,12 +256,12 @@ def get_stop_times(self, station: Station, line, start_time, day,
             )
 
         if line != '':
-            raw_stop_times = raw_stop_times.filter(Trip.route_name == line)
+            raw_stop_times = raw_stop_times.filter(StopTime.route_name == line)
 
         if count:
             raw_stop_times = raw_stop_times \
-                .group_by(Trip.route_name) \
-                .order_by(func.count(Trip.route_name).desc())
+                .group_by(StopTime.route_name) \
+                .order_by(func.count(StopTime.route_name).desc())
         else:
             raw_stop_times = raw_stop_times.order_by(StopTime.sched_dep_dt).limit(self.LIMIT).offset(offset_times)
 
@@ -275,9 +276,9 @@ def get_stop_times(self, station: Station, line, start_time, day,
             dep_time = raw_stop_time.dep_time
             arr_time = raw_stop_time.arr_time
             stop_time = TripStopTime(station, raw_stop_time.origin_id, dep_time, None, 0, raw_stop_time.platform,
-                                         raw_stop_time.destination, raw_stop_time.trip_id,
-                                         raw_stop_time.route_name, arr_time=arr_time,
-                                         orig_dep_date=raw_stop_time.orig_dep_date)
+                                     raw_stop_time.destination, raw_stop_time.trip_id,
+                                     raw_stop_time.route_name, arr_time=arr_time,
+                                     orig_dep_date=raw_stop_time.orig_dep_date)
             stop_times.append(stop_time)
 
         return stop_times
@@ -303,29 +304,33 @@ def get_stop_times_between_stops(self, dep_station: Station, arr_station: Statio
 
         if count:
             raw_stop_times = self.session.query(
-                Trip.route_name.label('route_name'),
+                d_stop_times.route_name.label('route_name'),
             )
         else:
             raw_stop_times = self.session.query(
                 d_stop_times.sched_arr_dt.label('d_arr_time'),
                 d_stop_times.sched_dep_dt.label('d_dep_time'),
-                Trip.orig_id.label('origin_id'),
-                Trip.dest_text.label('destination'),
-                Trip.number.label('trip_id'),
-                Trip.orig_dep_date.label('orig_dep_date'),
-                Trip.route_name.label('route_name'),
+                d_stop_times.orig_id.label('origin_id'),
+                d_stop_times.dest_text.label('destination'),
+                d_stop_times.number.label('trip_id'),
+                d_stop_times.orig_dep_date.label('orig_dep_date'),
+                d_stop_times.route_name.label('route_name'),
                 d_stop_times.platform.label('d_platform'),
                 a_stop_times.sched_dep_dt.label('a_dep_time'),
                 a_stop_times.sched_arr_dt.label('a_arr_time'),
                 a_stop_times.platform.label('a_platform')
             )
 
+        day_minus_one = day - timedelta(days=1)
+
         raw_stop_times = raw_stop_times \
             .select_from(d_stop_times) \
-            .join(a_stop_times, d_stop_times.trip_id == a_stop_times.trip_id) \
-            .join(Trip, d_stop_times.trip_id == Trip.id) \
+            .join(a_stop_times, and_(d_stop_times.number == a_stop_times.number,
+                                     d_stop_times.orig_dep_date == a_stop_times.orig_dep_date,
+                                     d_stop_times.source == a_stop_times.source)) \
             .filter(
                 and_(
+                    d_stop_times.orig_dep_date.between(day_minus_one, day),
                    d_stop_times.stop_id.in_(dep_stops_ids),
                     d_stop_times.sched_dep_dt >= start_dt,
                     d_stop_times.sched_dep_dt < end_dt,
@@ -335,10 +340,11 @@ def get_stop_times_between_stops(self, dep_station: Station, arr_station: Statio
             )
 
         if line != '':
-            raw_stop_times = raw_stop_times.filter(Trip.route_name == line)
+            raw_stop_times = raw_stop_times.filter(d_stop_times.route_name == line)
 
         if count:
-            raw_stop_times = raw_stop_times.group_by(Trip.route_name).order_by(func.count(Trip.route_name).desc())
+            raw_stop_times = raw_stop_times.group_by(d_stop_times.route_name).order_by(
+                func.count(d_stop_times.route_name).desc())
         else:
             raw_stop_times = raw_stop_times.order_by(
                 d_stop_times.sched_dep_dt
@@ -377,17 +383,23 @@ def sync_stations_db(self, new_stations: list[Station], new_stops: list[Stop] =
 
         for station in new_stations:
             stmt = insert(Station).values(id=station.id, name=station.name, lat=station.lat, lon=station.lon,
-                                          ids=station.ids, source=self.name, times_count=station.times_count)
+                                          ids=station.ids, source=self.name, times_count=station.times_count,
+                                          active=True)
             stmt = stmt.on_conflict_do_update(
                 index_elements=['id'],
                 set_={'name': station.name, 'lat': station.lat, 'lon': station.lon, 'ids': station.ids,
-                      'source': self.name, 'times_count': station.times_count}
+                      'source': self.name, 'times_count': station.times_count, 'active': True}
             )
             self.session.execute(stmt)
 
-        for station in self.session.scalars(select(Station).filter_by(source=self.name)).all():
+        for station in self.session.scalars(select(Station).filter_by(source=self.name, active=True)).all():
             if station.id not in station_codes:
-                self.session.delete(station)
+                # set station as inactive and set all stops as inactive
+                station.active = False
+                for stop in station.stops:
+                    stop.active = False
+
         self.session.commit()
 
@@ -396,28 +408,28 @@ def sync_stations_db(self, new_stations: list[Station], new_stops: list[Stop] =
         if new_stops:
             for stop in new_stops:
                 stmt = insert(Stop).values(id=stop.id, platform=stop.platform, lat=stop.lat, lon=stop.lon,
-                                           station_id=stop.station_id, source=self.name)
+                                           station_id=stop.station_id, source=self.name, active=True)
                 stmt = stmt.on_conflict_do_update(
                     index_elements=['id'],
                     set_={'platform': stop.platform, 'lat': stop.lat, 'lon': stop.lon,
-                          'station_id': stop.station_id, 'source': self.name}
+                          'station_id': stop.station_id, 'source': self.name, 'active': True}
                 )
                 self.session.execute(stmt)
         else:
             for station in new_stations:
                 stmt = insert(Stop).values(id=station.id, platform=None, lat=station.lat, lon=station.lon,
-                                           station_id=station.id, source=self.name)
+                                           station_id=station.id, source=self.name, active=True)
                 stmt = stmt.on_conflict_do_update(
                     index_elements=['id'],
                     set_={'platform': None, 'lat': station.lat, 'lon': station.lon,
-                          'station_id': station.id, 'source': self.name}
+                          'station_id': station.id, 'source': self.name, 'active': True}
                 )
                 self.session.execute(stmt)
 
-        # Stops with stations not in station_codes are deleted through cascade
-        for stop in self.session.scalars(select(Stop).filter(Stop.station_id.in_(station_codes))).all():
+        # Stops of still-active stations that are no longer present are set as inactive
+        for stop in self.session.scalars(select(Stop).filter(Stop.station_id.in_(station_codes), Stop.active.is_(True))).all():
             if stop.id not in stop_ids:
-                self.session.delete(stop)
+                stop.active = False
 
         self.session.commit()
 
@@ -455,44 +467,33 @@ def search_lines(self, name):
         raise NotImplementedError
 
     def get_source_stations(self) -> list[Station]:
-        return self.session.scalars(select(Station).filter_by(source=self.name)).all()
+        return self.session.scalars(select(Station).filter_by(source=self.name, active=True)).all()
 
     def upload_trip_stop_time_to_postgres(self, stop_time: TripStopTime):
-        train = self.session.query(Trip).filter_by(
-            number=stop_time.trip_id,
-            orig_dep_date=stop_time.orig_dep_date,
-            source=self.name
-        ).first()
-
-        if not train:
-            train = Trip(orig_id=stop_time.origin_id, dest_text=stop_time.destination,
-                         number=stop_time.trip_id, orig_dep_date=stop_time.orig_dep_date,
-                         route_name=stop_time.route_name, source=self.name)
-            self.session.add(train)
-            self.session.commit()
-
         stop_id = self.name + '_' + stop_time.station.id if self.name != 'treni' else stop_time.station.id
-        stop_time_db = self.session.query(StopTime).filter_by(trip_id=train.id, stop_id=stop_id).first()
-        if stop_time_db:
-            if stop_time_db.platform != stop_time.platform:
-                stop_time_db.platform = stop_time.platform
-                self.session.commit()
-        else:
-            new_stop_time = StopTime(trip_id=train.id, stop_id=stop_id, sched_arr_dt=stop_time.arr_time,
-                                     sched_dep_dt=stop_time.dep_time, platform=stop_time.platform)
-            self.session.add(new_stop_time)
-            self.session.commit()
+        stmt = insert(StopTime).values(stop_id=stop_id, sched_arr_dt=stop_time.arr_time,
+                                       sched_dep_dt=stop_time.dep_time, platform=stop_time.platform,
+                                       orig_id=stop_time.origin_id, dest_text=stop_time.destination,
+                                       number=stop_time.trip_id, orig_dep_date=stop_time.orig_dep_date,
+                                       route_name=stop_time.route_name, source=self.name)
+
+        stmt = stmt.on_conflict_do_update(
+            index_elements=['stop_id', 'number', 'orig_dep_date', 'source'],
+            set_={'platform': stop_time.platform}
+        )
+
+        self.session.execute(stmt)
+        self.session.commit()
 
     def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]:
         trip_id = int(trip_id)
-        query = select(StopTime, Trip, Stop) \
-            .join(StopTime.trip) \
+        query = select(StopTime, Stop) \
             .join(StopTime.stop) \
             .filter(
                 and_(
-                    Trip.number == trip_id,
-                    Trip.orig_dep_date == day.isoformat()
+                    StopTime.number == trip_id,
+                    StopTime.orig_dep_date == day.isoformat()
                )) \
             .order_by(StopTime.sched_dep_dt)
@@ -500,13 +501,11 @@ def get_stops_from_trip_id(self, trip_id, day: date) -> list[BaseStopTime]:
         stop_times = []
 
         for result in results:
-            stop_time = TripStopTime(result.Stop, result.Trip.orig_id, result.StopTime.sched_dep_dt,
-                                     None, 0,
-                                     result.StopTime.platform, result.Trip.dest_text, trip_id,
-                                     result.Trip.route_name,
-                                     result.StopTime.sched_arr_dt, result.Trip.orig_dep_date)
+            stop_time = TripStopTime(result.Stop, result.StopTime.orig_id, result.StopTime.sched_dep_dt,
+                                     None, 0,
+                                     result.StopTime.platform, result.StopTime.dest_text, trip_id,
+                                     result.StopTime.route_name,
+                                     result.StopTime.sched_arr_dt, result.StopTime.orig_dep_date)
             stop_times.append(stop_time)
 
         return stop_times
-
-
diff --git a/server/trenitalia/source.py b/server/trenitalia/source.py
index 859666e..7aa1fd9 100644
--- a/server/trenitalia/source.py
+++ b/server/trenitalia/source.py
@@ -27,8 +27,8 @@ def __init__(self, session, typesense, location='', force_update_stations=False)
         self.location = location
         super().__init__('treni', '🚆', session, typesense)
 
-        if force_update_stations or self.session.query(Station).filter_by(source=self.name).count() == 0 or \
-                self.session.query(Stop).filter_by(source=self.name).count() == 0:
+        if force_update_stations or self.session.query(Station).filter_by(source=self.name, active=True).count() == 0 or \
+                self.session.query(Stop).filter_by(source=self.name, active=True).count() == 0:
             current_dir = os.path.abspath(os.path.dirname(__file__))
             datadir = os.path.abspath(current_dir + '/data')
 
@@ -43,7 +43,7 @@ def __init__(self, session, typesense, location='', force_update_stations=False)
             self.sync_stations_db(new_stations)
 
     def save_data(self):
-        stations = self.session.scalars(select(Station).filter_by(source=self.name)).all()
+        stations = self.session.scalars(select(Station).filter_by(source=self.name, active=True)).all()
 
         total_times_count = 0
         times_count = []
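
Note (not part of the patch): migration d55702afa188 renames the existing table to stop_times_reg and creates the new partitioned stop_times, but it never copies the old rows across, so existing history stays behind in stop_times_reg. If those rows are still wanted, a one-off backfill along the following lines could be run after partitions covering the relevant departure dates exist (for example created via StopTime.create_partition, as save_data.py does for the upcoming days). This is an illustrative sketch only; the column list mirrors the CREATE TABLE statement in the migration, and the engine is the one save_data.py imports from server.sources.

# backfill_stop_times.py -- illustrative sketch, not part of the patch
from sqlalchemy import text

from server.sources import engine

backfill = text("""
    INSERT INTO stop_times (stop_id, sched_arr_dt, sched_dep_dt, platform,
                            orig_dep_date, orig_id, dest_text, number, route_name, source)
    SELECT stop_id, sched_arr_dt, sched_dep_dt, platform,
           orig_dep_date, orig_id, dest_text, number, route_name, source
    FROM stop_times_reg
    ON CONFLICT (stop_id, number, source, orig_dep_date) DO NOTHING
""")

# run the copy in a single transaction; duplicates are skipped via the unique index
with engine.begin() as conn:
    conn.execute(backfill)

Rows already present in the partitioned table are skipped by the ON CONFLICT clause, which relies on the stop_times_unique_idx index created by the migration. Rows whose orig_dep_date has no matching partition will make the INSERT fail, which is why the partitions need to be created first.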