Skip to content

Commit

Permalink
Update more code to use SQLAlchemy 2.0 API
Browse files Browse the repository at this point in the history
  • Loading branch information
smsearcy committed Jul 25, 2024
1 parent 5ebda4c commit d7a07ae
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 52 deletions.
30 changes: 15 additions & 15 deletions meshinfo/purge.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import pendulum
import sqlalchemy as sa
from sqlalchemy import sql

from .historical import HistoricalStats
from .models import CollectorStat, Link, Node, NodeError, session_scope
Expand Down Expand Up @@ -30,27 +30,27 @@ def main(

with session_scope(dbsession_factory) as dbsession:
print()
total_node_count = dbsession.query(Node).count()
nodes = dbsession.query(Node).filter(Node.last_seen < cutoff).all()
total_node_count = dbsession.scalar(
sql.select(sql.func.count()).select_from(Node)
)
nodes = dbsession.scalars(sql.select(Node).where(Node.last_seen < cutoff)).all()
print(
f"Identified {len(nodes):,d} nodes to purge (out of {total_node_count:,d})."
)
node_ids = [node.id for node in nodes]
links = (
dbsession.query(Link)
.filter(
sa.or_(Link.source_id.in_(node_ids), Link.destination_id.in_(node_ids))
links = dbsession.scalars(
sql.select(Link).where(
sql.or_(Link.source_id.in_(node_ids), Link.destination_id.in_(node_ids))
)
.all()
)
).all()
print(f"Identified {len(links):,d} links to purge.")
stats = (
dbsession.query(CollectorStat)
.filter(CollectorStat.started_at < cutoff)
.all()
)
stats = dbsession.scalars(
sql.select(CollectorStat).where(CollectorStat.started_at < cutoff)
).all()
print(f"Identified {len(stats):,d} collector stats to purge.")
errors = dbsession.query(NodeError).filter(NodeError.timestamp < cutoff).all()
errors = dbsession.scalars(
sql.select(NodeError).where(NodeError.timestamp < cutoff)
).all()
print(f"Identified {len(errors):,d} node error details to purge.")
print()

Expand Down
20 changes: 12 additions & 8 deletions meshinfo/views/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
def overview(request: Request):
dbsession: Session = request.dbsession

node_count = (
dbsession.query(Node).filter(Node.status != NodeStatus.INACTIVE).count()
node_count = dbsession.scalar(
sql.select(sql.func.count())
.select_from(Node)
.where(Node.status != NodeStatus.INACTIVE)
)
link_count = (
dbsession.query(Link).filter(Link.status != LinkStatus.INACTIVE).count()
link_count = dbsession.scalar(
sql.select(sql.func.count())
.select_from(Node)
.where(Link.status != LinkStatus.INACTIVE)
)

# Get node counts by firmware version
Expand Down Expand Up @@ -56,16 +60,16 @@ def overview(request: Request):
)
band_stats = {band: count for band, count in band_results}

last_run = dbsession.execute(
last_run = dbsession.scalars(
sql.select(CollectorStat).order_by(sql.desc(CollectorStat.started_at)).limit(1)
).scalar()
).first()

node_errors_by_type: dict[str, list[NodeError]] = {}
if last_run:
errors = dbsession.execute(
errors = dbsession.scalars(
sql.select(NodeError).where(NodeError.timestamp == last_run.started_at)
)
for error in errors.scalars():
for error in errors:
node_errors_by_type.setdefault(str(error.error_type), []).append(error)

return {
Expand Down
23 changes: 11 additions & 12 deletions meshinfo/views/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from __future__ import annotations

from collections.abc import Iterator
from collections.abc import Iterator, Sequence

import attrs
import sqlalchemy as sa
from pyramid.request import Request
from pyramid.view import view_config
from sqlalchemy import sql
from sqlalchemy.orm import Session, aliased

from ..config import AppConfig
Expand Down Expand Up @@ -274,23 +274,22 @@ def network_map(request: Request):
def map_data(request: Request):
"""Generate node and link data as GeoJSON to be loaded into Leaflet."""
dbsession: Session = request.dbsession
node_query = dbsession.query(Node).filter(
node_query = sql.select(Node).where(
Node.status != NodeStatus.INACTIVE,
Node.latitude != sa.null(),
Node.longitude != sa.null(),
Node.latitude != sql.null(),
Node.longitude != sql.null(),
)
source_nodes = aliased(Node, node_query.subquery())
dest_nodes = aliased(Node, node_query.subquery())

nodes = node_query.all()
nodes = dbsession.scalars(node_query).all()

links = (
dbsession.query(Link)
links = dbsession.scalars(
sql.select(Link)
.join(source_nodes, Link.source_id == source_nodes.id)
.join(dest_nodes, Link.destination_id == dest_nodes.id)
.filter(Link.status != LinkStatus.INACTIVE)
.all()
)
.where(Link.status != LinkStatus.INACTIVE)
).all()

node_layers = {layer.key: layer for layer in _NODE_LAYERS}
link_layers = {layer.key: layer for layer in _LINK_LAYERS}
Expand Down Expand Up @@ -322,7 +321,7 @@ def _calc_hue(value: float, *, red: float, green: float) -> int:
return round(120 * percent)


def _dedupe_links(links: list[Link]) -> Iterator[Link]:
def _dedupe_links(links: Sequence[Link]) -> Iterator[Link]:
"""Filter out redundant tunnels and DTD links."""
# while it is unlikely that two nodes are connected by both types, this is safer
seen_tunnels = set()
Expand Down
10 changes: 5 additions & 5 deletions meshinfo/views/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pyramid.request import Request
from pyramid.response import Response
from pyramid.view import view_config, view_defaults
from sqlalchemy import sql
from sqlalchemy.orm import Session, subqueryload

from ..historical import HistoricalStats
Expand Down Expand Up @@ -65,12 +66,11 @@ def network_errors(request: Request):

marked_row = request.GET.get("highlight")

collector = (
dbsession.query(CollectorStat)
collector = dbsession.execute(
sql.select(CollectorStat)
.options(subqueryload(CollectorStat.node_errors))
.filter(CollectorStat.started_at == timestamp)
.one_or_none()
)
.where(CollectorStat.started_at == timestamp)
).scalar_one_or_none()
if collector is None:
raise HTTPNotFound(f"No collection statistics available for {timestamp}")

Expand Down
18 changes: 8 additions & 10 deletions meshinfo/views/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,22 @@ def node_preview(request: Request):
if not (node := dbsession.get(Node, node_id)):
raise HTTPNotFound("Sorry, the specified node could not be found")

query = (
dbsession.query(Link)
current_links = dbsession.scalars(
sql.select(Link)
.options(joinedload(Link.destination).load_only(Node.display_name))
.filter(
.where(
Link.source_id == node.id,
Link.status == LinkStatus.CURRENT,
)
)
current_links = query.all()
query = (
dbsession.query(Link)
).all()
recent_links = dbsession.scalars(
sql.select(Link)
.options(joinedload(Link.destination).load_only(Node.display_name))
.filter(
.where(
Link.source_id == node.id,
Link.status == LinkStatus.RECENT,
)
)
recent_links = query.all()
).all()

return {
"node": node,
Expand Down
7 changes: 5 additions & 2 deletions meshinfo/views/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pyramid.request import Request, Response
from pyramid.view import view_config, view_defaults
from sqlalchemy import sql
from sqlalchemy.orm import Session

from ..models import Node
Expand All @@ -16,8 +17,10 @@ def __init__(self, request: Request):
dbsession: Session = request.dbsession

# TODO: parameters to determine which nodes to return
query = dbsession.query(Node).filter(Node.status != NodeStatus.INACTIVE)
self.nodes: list[Node] = sorted(query.all(), key=attrgetter("name"))
nodes = dbsession.scalars(
sql.select(Node).where(Node.status != NodeStatus.INACTIVE)
).all()
self.nodes: list[Node] = sorted(nodes, key=attrgetter("name"))
self.request = request

@view_config(match_param="view=table", renderer="pages/nodes.jinja2")
Expand Down

0 comments on commit d7a07ae

Please sign in to comment.