Skip to content

Commit

Permalink
code refactoring: implemented a with_db_session decorator to streamline session management.
Browse files Browse the repository at this point in the history
  • Loading branch information
qcdyx committed Nov 20, 2024
1 parent 83268aa commit 8789ae3
Show file tree
Hide file tree
Showing 14 changed files with 420 additions and 449 deletions.
386 changes: 196 additions & 190 deletions api/src/database/database.py

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions api/src/feeds/impl/datasets_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from geoalchemy2 import WKTElement
from sqlalchemy import or_
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, Session

from database.database import Database
from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
Gtfsdataset,
Feed,
Expand Down Expand Up @@ -93,9 +93,10 @@ def apply_bounding_filtering(
raise_http_validation_error(invalid_bounding_method.format(bounding_filter_method))

@staticmethod
def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> List[GtfsDataset]:
def get_datasets_gtfs(query: Query, session: Session, limit: int = None, offset: int = None) -> List[GtfsDataset]:
# Results are sorted by stable_id because Database.select(group_by=) requires it
dataset_groups = Database().select(
session=session,
query=query.order_by(Gtfsdataset.stable_id),
limit=limit,
offset=offset,
Expand All @@ -109,15 +110,13 @@ def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> Li
gtfs_datasets.append(GtfsDatasetImpl.from_orm(dataset_objects[0]))
return gtfs_datasets

def get_dataset_gtfs(
self,
id: str,
) -> GtfsDataset:
@with_db_session
def get_dataset_gtfs(self, id: str, db_session: Session) -> GtfsDataset:
"""Get the specified dataset from the Mobility Database."""

query = DatasetsApiImpl.create_dataset_query().filter(Gtfsdataset.stable_id == id)

if (ret := DatasetsApiImpl.get_datasets_gtfs(query)) and len(ret) == 1:
if (ret := DatasetsApiImpl.get_datasets_gtfs(query, db_session)) and len(ret) == 1:
return ret[0]
else:
raise_http_error(404, dataset_not_found.format(id))
93 changes: 41 additions & 52 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import List, Union, TypeVar

from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, Session
from sqlalchemy.orm.query import Query

from database.database import Database
from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
Feed,
Gtfsdataset,
Expand Down Expand Up @@ -59,20 +59,19 @@ class FeedsApiImpl(BaseFeedsApi):

APIFeedType = Union[BasicFeed, GtfsFeed, GtfsRTFeed]

def get_feed(
self,
id: str,
) -> BasicFeed:
@with_db_session
def get_feed(self, id: str, db_session: Session) -> BasicFeed:
"""Get the specified feed from the Mobility Database."""
feed = (
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
.filter(Database().get_query_model(Feed))
.filter(Database().get_query_model(db_session, Feed))
.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
.filter(
or_(
Feed.operational_status == None, # noqa: E711
Feed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
# Allow all feeds to be returned if the user is not restricted
not is_user_email_restricted(),
)
)
.first()
Expand All @@ -82,19 +81,15 @@ def get_feed(
else:
raise_http_error(404, feed_not_found.format(id))

@with_db_session
def get_feeds(
self,
limit: int,
offset: int,
status: str,
provider: str,
producer_url: str,
self, limit: int, offset: int, status: str, provider: str, producer_url: str, db_session: Session
) -> List[BasicFeed]:
"""Get some (or all) feeds from the Mobility Database."""
feed_filter = FeedFilter(
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
)
feed_query = feed_filter.filter(Database().get_query_model(Feed))
feed_query = feed_filter.filter(Database().get_query_model(db_session, Feed))
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
feed_query = feed_query.filter(
or_(
Expand All @@ -114,27 +109,25 @@ def get_feeds(
results = feed_query.all()
return [BasicFeedImpl.from_orm(feed) for feed in results]

def get_gtfs_feed(
self,
id: str,
) -> GtfsFeed:
@with_db_session
def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed:
"""Get the specified gtfs feed from the Mobility Database."""
feed, translations = self._get_gtfs_feed(id)
feed, translations = self._get_gtfs_feed(id, db_session)
if feed:
return GtfsFeedImpl.from_orm(feed, translations)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))

@staticmethod
def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
results = (
FeedFilter(
stable_id=stable_id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
.filter(db_session.query(Gtfsfeed, t_location_with_translations_en))
.filter(
or_(
Gtfsfeed.operational_status == None, # noqa: E711
Expand All @@ -156,6 +149,7 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
return results[0].Gtfsfeed, translations
return None, {}

@with_db_session
def get_gtfs_feed_datasets(
self,
gtfs_feed_id: str,
Expand All @@ -164,6 +158,7 @@ def get_gtfs_feed_datasets(
offset: int,
downloaded_after: str,
downloaded_before: str,
db_session: Session,
) -> List[GtfsDataset]:
"""Get a list of datasets related to a feed."""
if downloaded_before and not valid_iso_date(downloaded_before):
Expand All @@ -179,7 +174,7 @@ def get_gtfs_feed_datasets(
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.filter(Database().get_query_model(db_session, Gtfsfeed))
.filter(
or_(
Feed.operational_status == None, # noqa: E711
Expand All @@ -196,19 +191,20 @@ def get_gtfs_feed_datasets(
# Replace Z with +00:00 to make the datetime object timezone aware
# Due to https://github.com/python/cpython/issues/80010; once we migrate to Python 3.11, we can use fromisoformat directly
query = GtfsDatasetFilter(
downloaded_at__lte=datetime.fromisoformat(downloaded_before.replace("Z", "+00:00"))
if downloaded_before
else None,
downloaded_at__gte=datetime.fromisoformat(downloaded_after.replace("Z", "+00:00"))
if downloaded_after
else None,
downloaded_at__lte=(
datetime.fromisoformat(downloaded_before.replace("Z", "+00:00")) if downloaded_before else None
),
downloaded_at__gte=(
datetime.fromisoformat(downloaded_after.replace("Z", "+00:00")) if downloaded_after else None
),
).filter(DatasetsApiImpl.create_dataset_query().filter(Feed.stable_id == gtfs_feed_id))

if latest:
query = query.filter(Gtfsdataset.latest)

return DatasetsApiImpl.get_datasets_gtfs(query, limit=limit, offset=offset)
return DatasetsApiImpl.get_datasets_gtfs(query, session=db_session, limit=limit, offset=offset)

@with_db_session
def get_gtfs_feeds(
self,
limit: int,
Expand All @@ -221,6 +217,7 @@ def get_gtfs_feeds(
dataset_latitudes: str,
dataset_longitudes: str,
bounding_filter_method: str,
db_session: Session,
) -> List[GtfsFeed]:
"""Get some (or all) GTFS feeds from the Mobility Database."""
gtfs_feed_filter = GtfsFeedFilter(
Expand All @@ -240,9 +237,7 @@ def get_gtfs_feeds(
).subquery()

feed_query = (
Database()
.get_session()
.query(Gtfsfeed)
db_session.query(Gtfsfeed)
.filter(Gtfsfeed.id.in_(subquery))
.filter(
or_(
Expand All @@ -261,12 +256,10 @@ def get_gtfs_feeds(
.limit(limit)
.offset(offset)
)
return self._get_response(feed_query, GtfsFeedImpl)
return self._get_response(feed_query, GtfsFeedImpl, db_session)

def get_gtfs_rt_feed(
self,
id: str,
) -> GtfsRTFeed:
@with_db_session
def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
"""Get the specified GTFS Realtime feed from the Mobility Database."""
gtfs_rt_feed_filter = GtfsRtFeedFilter(
stable_id=id,
Expand All @@ -276,9 +269,7 @@ def get_gtfs_rt_feed(
location=None,
)
results = gtfs_rt_feed_filter.filter(
Database()
.get_session()
.query(Gtfsrealtimefeed, t_location_with_translations_en)
db_session.query(Gtfsrealtimefeed, t_location_with_translations_en)
.filter(
or_(
Gtfsrealtimefeed.operational_status == None, # noqa: E711
Expand All @@ -301,6 +292,7 @@ def get_gtfs_rt_feed(
else:
raise_http_error(404, gtfs_rt_feed_not_found.format(id))

@with_db_session
def get_gtfs_rt_feeds(
self,
limit: int,
Expand All @@ -311,6 +303,7 @@ def get_gtfs_rt_feeds(
country_code: str,
subdivision_name: str,
municipality: str,
db_session: Session,
) -> List[GtfsRTFeed]:
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
entity_types_list = entity_types.split(",") if entity_types else None
Expand Down Expand Up @@ -342,9 +335,7 @@ def get_gtfs_rt_feeds(
.join(Entitytype, Gtfsrealtimefeed.entitytypes)
).subquery()
feed_query = (
Database()
.get_session()
.query(Gtfsrealtimefeed)
db_session.query(Gtfsrealtimefeed)
.filter(Gtfsrealtimefeed.id.in_(subquery))
.filter(
or_(
Expand All @@ -362,22 +353,20 @@ def get_gtfs_rt_feeds(
.limit(limit)
.offset(offset)
)
return self._get_response(feed_query, GtfsRTFeedImpl)
return self._get_response(feed_query, GtfsRTFeedImpl, db_session)

@staticmethod
def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]:
def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") -> List[T]:
"""Get the response for the feed query."""
results = feed_query.all()
location_translations = get_feeds_location_translations(results)
location_translations = get_feeds_location_translations(results, db_session)
response = [impl_cls.from_orm(feed, location_translations) for feed in results]
return list({feed.id: feed for feed in response}.values())

def get_gtfs_feed_gtfs_rt_feeds(
self,
id: str,
) -> List[GtfsRTFeed]:
@with_db_session
def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]:
"""Get a list of GTFS Realtime related to a GTFS feed."""
feed, translations = self._get_gtfs_feed(id)
feed, translations = self._get_gtfs_feed(id, db_session)
if feed:
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
else:
Expand Down
8 changes: 6 additions & 2 deletions api/src/feeds/impl/search_api_impl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List

from sqlalchemy import func, select
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, Session

from database.database import Database
from database.database import Database, with_db_session
from database.sql_functions.unaccent import unaccent
from database_gen.sqlacodegen_models import t_feedsearch
from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl
Expand Down Expand Up @@ -83,6 +83,7 @@ def create_search_query(status: List[str], feed_id: str, data_type: str, search_
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status)
return query.order_by(rank_expression.desc())

@with_db_session
def search_feeds(
self,
limit: int,
Expand All @@ -91,15 +92,18 @@ def search_feeds(
feed_id: str,
data_type: str,
search_query: str,
db_session: "Session",
) -> SearchFeeds200Response:
"""Search feeds using full-text search on feed, location and provider's information."""
query = self.create_search_query(status, feed_id, data_type, search_query)
feed_rows = Database().select(
session=db_session,
query=query,
limit=limit,
offset=offset,
)
feed_total_count = Database().select(
session=db_session,
query=self.create_count_search_query(status, feed_id, data_type, search_query),
)
if feed_rows is None or feed_total_count is None:
Expand Down
11 changes: 8 additions & 3 deletions api/src/scripts/populate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Type
from typing import Type, TYPE_CHECKING

import pandas
from dotenv import load_dotenv
Expand All @@ -11,6 +11,9 @@
from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed
from utils.logger import Logger

if TYPE_CHECKING:
from sqlalchemy.orm import Session

logging.basicConfig()
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)

Expand Down Expand Up @@ -56,12 +59,14 @@ def __init__(self, filepaths):

self.filter_data()

def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None:
def query_feed_by_stable_id(
self, session: "Session", stable_id: str, data_type: str | None
) -> Gtfsrealtimefeed | Gtfsfeed | None:
"""
Query the feed by stable id
"""
model = self.get_model(data_type)
return self.db.session.query(model).filter(model.stable_id == stable_id).first()
return session.query(model).filter(model.stable_id == stable_id).first()

@staticmethod
def get_model(data_type: str | None) -> Type[Feed]:
Expand Down
16 changes: 9 additions & 7 deletions api/src/scripts/populate_db_gbfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ def deprecate_feeds(self, deprecated_feeds):
if deprecated_feeds is None or deprecated_feeds.empty:
self.logger.info("No feeds to deprecate.")
return

self.logger.info(f"Deprecating {len(deprecated_feeds)} feed(s).")
for index, row in deprecated_feeds.iterrows():
stable_id = self.get_stable_id(row)
gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
if gbfs_feed:
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
gbfs_feed.status = "deprecated"
self.db.session.flush()
with self.db.start_db_session() as session:
for index, row in deprecated_feeds.iterrows():
stable_id = self.get_stable_id(row)
gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
if gbfs_feed:
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
gbfs_feed.status = "deprecated"
session.flush()

def populate_db(self):
"""Populate the database with the GBFS feeds"""
Expand Down
Loading

0 comments on commit 8789ae3

Please sign in to comment.