Skip to content

Commit

Permalink
Merge v0.5.4 into production (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarrco authored Oct 15, 2023
2 parents 92ec352 + 7c34f55 commit b814dbf
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 38 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.db filter=lfs diff=lfs merge=lfs -text
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
with:
lfs: true

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
97 changes: 61 additions & 36 deletions MuoVErsi/sources/GTFS/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,61 +44,74 @@ def get_latest_gtfs_version(transport_type):


class GTFS(Source):
def __init__(self, transport_type, emoji, session, typesense, gtfs_version=None, location='', dev=False):
def __init__(self, transport_type, emoji, session, typesense, gtfs_versions_range: tuple[int] = None, location='', dev=False, ref_dt: datetime = None):
super().__init__(transport_type[:3], emoji, session, typesense)
self.transport_type = transport_type
self.location = location
self.service_ids = {}

if gtfs_version:
self.gtfs_version = gtfs_version
self.download_and_convert_file()
if gtfs_versions_range:
init_version = gtfs_versions_range[0]
else:
gtfs_version = get_latest_gtfs_version(transport_type)
init_version = get_latest_gtfs_version(transport_type)

for try_version in range(gtfs_version, 0, -1):
fin_version = gtfs_versions_range[1] if gtfs_versions_range else 0

if not ref_dt:
ref_dt = datetime.today()

for try_version in range(init_version, fin_version-1, -1):
self.download_and_convert_file(try_version)
service_start_date = self.get_service_start_date(ref_dt, try_version)
if service_start_date and service_start_date <= ref_dt.date():
self.gtfs_version = try_version
self.download_and_convert_file()
if self.get_calendar_services():
break
break

self.con = self.connect_to_database()
self.next_service_start_date = service_start_date

if not hasattr(self, 'gtfs_version'):
raise Exception(f'No valid GTFS version found for {transport_type}')

self.con = self.connect_to_database(self.gtfs_version)

stops_clusters_uploaded = self.upload_stops_clusters_to_db()
logger.info('%s stops clusters uploaded: %s', self.name, stops_clusters_uploaded)

def file_path(self, ext):
def file_path(self, ext, gtfs_version):
current_dir = os.path.abspath(os.path.dirname(__file__))
parent_dir = os.path.abspath(current_dir + f"/../../../{self.location}")

return os.path.join(parent_dir, f'{self.transport_type}_{self.gtfs_version}.{ext}')
return os.path.join(parent_dir, f'{self.transport_type}_{gtfs_version}.{ext}')

def download_and_convert_file(self, force=False):
if os.path.isfile(self.file_path('db')) and not force:
def download_and_convert_file(self, gtfs_version, force=False):
if os.path.isfile(self.file_path('db', gtfs_version)) and not force:
return

url = f'https://actv.avmspa.it/sites/default/files/attachments/opendata/' \
f'{self.transport_type}/actv_{self.transport_type[:3]}_{self.gtfs_version}.zip'
f'{self.transport_type}/actv_{self.transport_type[:3]}_{gtfs_version}.zip'
ssl._create_default_https_context = ssl._create_unverified_context
file_path = self.file_path('zip')
file_path = self.file_path('zip', gtfs_version)
logger.info('Downloading %s to %s', url, file_path)
urllib.request.urlretrieve(url, file_path)

subprocess.run(["gtfs-import", "--gtfsPath", self.file_path('zip'), '--sqlitePath', self.file_path('db')])
subprocess.run(["gtfs-import", "--gtfsPath", self.file_path('zip', gtfs_version), '--sqlitePath', self.file_path('db', gtfs_version)])

def get_calendar_services(self) -> list[str]:
today_ymd = datetime.today().strftime('%Y%m%d')
weekday = datetime.today().strftime('%A').lower()
with self.connect_to_database() as con:
def get_service_start_date(self, ref_dt, gtfs_version) -> date:
weekday = ref_dt.strftime('%A').lower()
with self.connect_to_database(gtfs_version) as con:
con.row_factory = sqlite3.Row
cur = con.cursor()
services = cur.execute(
f'SELECT service_id FROM calendar WHERE {weekday} = 1 AND start_date <= ? AND end_date >= ?',
(today_ymd, today_ymd))

return list(set([service[0] for service in services.fetchall()]))

def connect_to_database(self) -> Connection:
return sqlite3.connect(self.file_path('db'))
service = cur.execute(
f'SELECT start_date FROM calendar WHERE {weekday} = 1 ORDER BY start_date ASC LIMIT 1'
).fetchone()

if not service:
return None

return datetime.strptime(str(service['start_date']), '%Y%m%d').date()

def connect_to_database(self, gtfs_version) -> Connection:
return sqlite3.connect(self.file_path('db', gtfs_version))

def get_all_stops(self) -> list[CStop]:
cur = self.con.cursor()
Expand Down Expand Up @@ -210,16 +223,28 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim
start_dt = datetime.combine(day, start_time)
end_dt = datetime.combine(day, end_time)

or_other_service = ''
today_service = f"(t.service_id in ({','.join(['?'] * len(today_service_ids))}) AND dep.departure_time >= ? AND dep.departure_time <= ?)"

if hasattr(self, 'next_service_start_date'):
if day >= self.next_service_start_date:
today_service = ''

yesterday_service = ''
yesterday_service_ids = []
if start_dt.hour < 6:
yesterday_service_ids = self.get_active_service_ids(day - timedelta(days=1))
if yesterday_service_ids:
or_other_service_ids = ','.join(['?'] * len(yesterday_service_ids))
or_other_service = f'OR (dep.departure_time >= ? AND t.service_id in ({or_other_service_ids}))'
yesterday_service = f'(dep.departure_time >= ? AND t.service_id in ({or_other_service_ids}))'
else:
start_dt = datetime.combine(day, time(6))

if yesterday_service == '' and today_service == '':
return []

if yesterday_service != '' and today_service != '':
today_service += ' OR '

select_elements = """
dep.departure_time as dep_time,
r.route_short_name as line,
Expand All @@ -246,15 +271,15 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim
INNER JOIN trips t ON dep.trip_id = t.trip_id
INNER JOIN routes r ON t.route_id = r.route_id
INNER JOIN stops s ON dep.stop_id = s.stop_id
WHERE ((t.service_id in ({','.join(['?'] * len(today_service_ids))}) AND dep.departure_time >= ?
AND dep.departure_time <= ?)
{or_other_service})
WHERE ({today_service} {yesterday_service})
LIMIT ? OFFSET ?
"""
params = ()

params = (*today_service_ids, start_dt.strftime('%H:%M'), end_dt.strftime('%H:%M'))
if today_service != '':
params += (*today_service_ids, start_dt.strftime('%H:%M'), end_dt.strftime('%H:%M'))

if or_other_service != '':
if yesterday_service != '':
# in the string add 24 hours to start_dt time
start_time_25 = f'{start_dt.hour + 24:02}:{start_dt.minute:02}'
params += (start_time_25, *yesterday_service_ids)
Expand Down
Binary file removed tests/data/navigazione_541.db
Binary file not shown.
3 changes: 3 additions & 0 deletions tests/data/navigazione_557.db
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/navigazione_558.db
Git LFS file not shown
55 changes: 54 additions & 1 deletion tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,64 @@
from datetime import date, datetime, time
import pytest

from MuoVErsi.sources.GTFS import GTFS, get_clusters_of_stops, CCluster, CStop


@pytest.fixture
def db_file():
return GTFS('navigazione', '⛴️', None, None, 541, 'tests/data')
ref_dt = datetime(2023, 10, 15)
return GTFS('navigazione', '⛴️', None, None, (558, 557), 'tests/data', ref_dt=ref_dt)


def test_valid_gtfs():
_558_ref_df = datetime(2023, 10, 7)
_558_gtfs = GTFS('navigazione', '⛴️', None, None, (558, 557), 'tests/data', ref_dt=_558_ref_df)
assert _558_gtfs.gtfs_version == 558, 'invalid gtfs version'

_557_ref_dt = datetime(2023, 10, 6)
_557_gtfs = GTFS('navigazione', '⛴️', None, None, (558, 557), 'tests/data', ref_dt=_557_ref_dt)
assert _557_gtfs.gtfs_version == 557, 'invalid gtfs version'


def test_invalid_gtfs():
invalid_ref_df = datetime(2023, 9, 30)
with pytest.raises(Exception):
GTFS('navigazione', '⛴️', None, None, (558, 557), 'tests/data', ref_dt=invalid_ref_df)


def test_zero_stop_times_for_next_service():
db_file = GTFS('navigazione', '⛴️', None, None, (558, 557), 'tests/data', ref_dt=datetime(2023, 10, 6))
next_service_date = date(2023, 10, 7)

# On the 2023-10-06 we already know that there will a new service starting on 2023-10-07
assert db_file.next_service_start_date == next_service_date

# We should get no stop times for the 2023-10-07 while using the 2023-10-06 service
end_time = time(23, 59, 59)

stop_times = db_file.get_sqlite_stop_times(next_service_date, time(1), end_time, 570, 0)
assert len(stop_times) == 569, 'there should be only night routes serviced from 2023-10-06'

stop_times = db_file.get_sqlite_stop_times(next_service_date, time(8), end_time, 1, 0)
assert len(stop_times) == 0, 'there should be no stop times for the 2023-10-07 while using the 2023-10-06 service'


def test_normal_stop_times_for_current_service():
ref_dt = datetime(2023, 10, 7)
db_file = GTFS('navigazione', '⛴️', None, None, (558, 557), 'tests/data', ref_dt=ref_dt)

# On the 2023-10-06 we already know that there will a new service starting on 2023-10-07
assert not hasattr(db_file, 'next_service_start_date')

# We should get no stop times for the 2023-10-07 while using the 2023-10-06 service
end_time = time(23, 59, 59)

stop_times = db_file.get_sqlite_stop_times(ref_dt.date(), time(1), end_time, 570, 0)
len_stop_times = len(stop_times)
assert len(stop_times) > 569, 'there should be only night routes serviced from 2023-10-06'

stop_times = db_file.get_sqlite_stop_times(ref_dt.date(), time(8), end_time, 1, 0)
assert len(stop_times) > 0, 'there should be normal stop times for the 2023-10-07'


@pytest.fixture
Expand Down

0 comments on commit b814dbf

Please sign in to comment.