diff --git a/meshinfo/purge.py b/meshinfo/purge.py
index 32ca51e..dd41bff 100644
--- a/meshinfo/purge.py
+++ b/meshinfo/purge.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import pendulum
-import sqlalchemy as sa
+from sqlalchemy import sql

 from .historical import HistoricalStats
 from .models import CollectorStat, Link, Node, NodeError, session_scope
@@ -30,27 +30,27 @@ def main(

     with session_scope(dbsession_factory) as dbsession:
         print()
-        total_node_count = dbsession.query(Node).count()
-        nodes = dbsession.query(Node).filter(Node.last_seen < cutoff).all()
+        total_node_count = dbsession.scalar(
+            sql.select(sql.func.count()).select_from(Node)
+        )
+        nodes = dbsession.scalars(sql.select(Node).where(Node.last_seen < cutoff)).all()
         print(
             f"Identified {len(nodes):,d} nodes to purge (out of {total_node_count:,d})."
         )
         node_ids = [node.id for node in nodes]
-        links = (
-            dbsession.query(Link)
-            .filter(
-                sa.or_(Link.source_id.in_(node_ids), Link.destination_id.in_(node_ids))
+        links = dbsession.scalars(
+            sql.select(Link).where(
+                sql.or_(Link.source_id.in_(node_ids), Link.destination_id.in_(node_ids))
             )
-            .all()
-        )
+        ).all()
         print(f"Identified {len(links):,d} links to purge.")
-        stats = (
-            dbsession.query(CollectorStat)
-            .filter(CollectorStat.started_at < cutoff)
-            .all()
-        )
+        stats = dbsession.scalars(
+            sql.select(CollectorStat).where(CollectorStat.started_at < cutoff)
+        ).all()
         print(f"Identified {len(stats):,d} collector stats to purge.")
-        errors = dbsession.query(NodeError).filter(NodeError.timestamp < cutoff).all()
+        errors = dbsession.scalars(
+            sql.select(NodeError).where(NodeError.timestamp < cutoff)
+        ).all()
         print(f"Identified {len(errors):,d} node error details to purge.")

         print()
diff --git a/meshinfo/views/home.py b/meshinfo/views/home.py
index 3dd2456..86b8dd1 100644
--- a/meshinfo/views/home.py
+++ b/meshinfo/views/home.py
@@ -16,11 +16,15 @@ def overview(request: Request):

     dbsession: Session = request.dbsession

-    node_count = (
-        dbsession.query(Node).filter(Node.status != NodeStatus.INACTIVE).count()
+    node_count = dbsession.scalar(
+        sql.select(sql.func.count())
+        .select_from(Node)
+        .where(Node.status != NodeStatus.INACTIVE)
     )
-    link_count = (
-        dbsession.query(Link).filter(Link.status != LinkStatus.INACTIVE).count()
+    link_count = dbsession.scalar(
+        sql.select(sql.func.count())
+        .select_from(Link)
+        .where(Link.status != LinkStatus.INACTIVE)
     )

     # Get node counts by firmware version
@@ -56,16 +60,16 @@ def overview(request: Request):
     )
     band_stats = {band: count for band, count in band_results}

-    last_run = dbsession.execute(
+    last_run = dbsession.scalars(
         sql.select(CollectorStat).order_by(sql.desc(CollectorStat.started_at)).limit(1)
-    ).scalar()
+    ).first()

     node_errors_by_type: dict[str, list[NodeError]] = {}
     if last_run:
-        errors = dbsession.execute(
+        errors = dbsession.scalars(
             sql.select(NodeError).where(NodeError.timestamp == last_run.started_at)
         )
-        for error in errors.scalars():
+        for error in errors:
             node_errors_by_type.setdefault(str(error.error_type), []).append(error)

     return {
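A note on the count conversions in home.py: in the 2.0-style API, Query.count() becomes session.scalar(select(func.count()).select_from(...)), and the entity handed to select_from() has to be the same one the WHERE clause filters on; counting from the node table while filtering on link.status would pull both tables into the FROM clause and count a cartesian product. Below is a minimal sketch of that pattern, not part of the patch; it reuses the Link model and LinkStatus enum from this codebase, and the helper name and import path are assumptions.

from sqlalchemy import sql
from sqlalchemy.orm import Session

from ..models import Link, LinkStatus  # assumed import path; adjust to this project's modules


def count_active_links(dbsession: Session) -> int:
    # Hypothetical helper showing the 2.0-style count used in overview():
    # select_from(Link) names the counted table and matches the where() filter.
    return dbsession.scalar(
        sql.select(sql.func.count())
        .select_from(Link)
        .where(Link.status != LinkStatus.INACTIVE)
    )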
diff --git a/meshinfo/views/map.py b/meshinfo/views/map.py
index 505c711..393f98c 100644
--- a/meshinfo/views/map.py
+++ b/meshinfo/views/map.py
@@ -2,12 +2,12 @@

 from __future__ import annotations

-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence

 import attrs
-import sqlalchemy as sa
 from pyramid.request import Request
 from pyramid.view import view_config
+from sqlalchemy import sql
 from sqlalchemy.orm import Session, aliased

 from ..config import AppConfig
@@ -274,23 +274,22 @@ def network_map(request: Request):
 def map_data(request: Request):
     """Generate node and link data as GeoJSON to be loaded into Leaflet."""
     dbsession: Session = request.dbsession
-    node_query = dbsession.query(Node).filter(
+    node_query = sql.select(Node).where(
         Node.status != NodeStatus.INACTIVE,
-        Node.latitude != sa.null(),
-        Node.longitude != sa.null(),
+        Node.latitude != sql.null(),
+        Node.longitude != sql.null(),
     )
     source_nodes = aliased(Node, node_query.subquery())
     dest_nodes = aliased(Node, node_query.subquery())

-    nodes = node_query.all()
+    nodes = dbsession.scalars(node_query).all()

-    links = (
-        dbsession.query(Link)
+    links = dbsession.scalars(
+        sql.select(Link)
         .join(source_nodes, Link.source_id == source_nodes.id)
         .join(dest_nodes, Link.destination_id == dest_nodes.id)
-        .filter(Link.status != LinkStatus.INACTIVE)
-        .all()
-    )
+        .where(Link.status != LinkStatus.INACTIVE)
+    ).all()

     node_layers = {layer.key: layer for layer in _NODE_LAYERS}
     link_layers = {layer.key: layer for layer in _LINK_LAYERS}
@@ -322,7 +321,7 @@ def _calc_hue(value: float, *, red: float, green: float) -> int:
     return round(120 * percent)


-def _dedupe_links(links: list[Link]) -> Iterator[Link]:
+def _dedupe_links(links: Sequence[Link]) -> Iterator[Link]:
     """Filter out redundant tunnels and DTD links."""
     # while it is unlikely that two nodes are connected by both types, this is safer
     seen_tunnels = set()
diff --git a/meshinfo/views/network.py b/meshinfo/views/network.py
index 1569cc0..e505fed 100644
--- a/meshinfo/views/network.py
+++ b/meshinfo/views/network.py
@@ -3,6 +3,7 @@
 from pyramid.request import Request
 from pyramid.response import Response
 from pyramid.view import view_config, view_defaults
+from sqlalchemy import sql
 from sqlalchemy.orm import Session, subqueryload

 from ..historical import HistoricalStats
@@ -65,12 +66,11 @@ def network_errors(request: Request):

     marked_row = request.GET.get("highlight")

-    collector = (
-        dbsession.query(CollectorStat)
+    collector = dbsession.execute(
+        sql.select(CollectorStat)
         .options(subqueryload(CollectorStat.node_errors))
-        .filter(CollectorStat.started_at == timestamp)
-        .one_or_none()
-    )
+        .where(CollectorStat.started_at == timestamp)
+    ).scalar_one_or_none()
     if collector is None:
         raise HTTPNotFound(f"No collection statistics available for {timestamp}")

diff --git a/meshinfo/views/node.py b/meshinfo/views/node.py
index 595aca6..348e29f 100644
--- a/meshinfo/views/node.py
+++ b/meshinfo/views/node.py
@@ -86,24 +86,22 @@ def node_preview(request: Request):
     if not (node := dbsession.get(Node, node_id)):
         raise HTTPNotFound("Sorry, the specified node could not be found")

-    query = (
-        dbsession.query(Link)
+    current_links = dbsession.scalars(
+        sql.select(Link)
         .options(joinedload(Link.destination).load_only(Node.display_name))
-        .filter(
+        .where(
             Link.source_id == node.id,
             Link.status == LinkStatus.CURRENT,
         )
-    )
-    current_links = query.all()
-    query = (
-        dbsession.query(Link)
+    ).all()
+    recent_links = dbsession.scalars(
+        sql.select(Link)
         .options(joinedload(Link.destination).load_only(Node.display_name))
-        .filter(
+        .where(
             Link.source_id == node.id,
             Link.status == LinkStatus.RECENT,
         )
-    )
-    recent_links = query.all()
+    ).all()

     return {
         "node": node,
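The two statements in node_preview differ only in the link status they filter on, so with the select() construct in place they could share one builder. A sketch under that assumption, reusing Link, Node, and LinkStatus from this codebase; links_from_node is a hypothetical helper and the import path is assumed, not part of the patch.

from sqlalchemy import sql
from sqlalchemy.orm import Session, joinedload

from ..models import Link, LinkStatus, Node  # assumed import path


def links_from_node(
    dbsession: Session, node_id: int, status: LinkStatus
) -> list[Link]:
    # Same select()/scalars() shape as node_preview, parameterized on status.
    return list(
        dbsession.scalars(
            sql.select(Link)
            .options(joinedload(Link.destination).load_only(Node.display_name))
            .where(Link.source_id == node_id, Link.status == status)
        )
    )

node_preview would then call this twice, once with LinkStatus.CURRENT and once with LinkStatus.RECENT.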
diff --git a/meshinfo/views/nodes.py b/meshinfo/views/nodes.py
index 2b890d4..2e7210b 100644
--- a/meshinfo/views/nodes.py
+++ b/meshinfo/views/nodes.py
@@ -4,6 +4,7 @@

 from pyramid.request import Request, Response
 from pyramid.view import view_config, view_defaults
+from sqlalchemy import sql
 from sqlalchemy.orm import Session

 from ..models import Node
@@ -16,8 +17,10 @@ def __init__(self, request: Request):
         dbsession: Session = request.dbsession

         # TODO: parameters to determine which nodes to return
-        query = dbsession.query(Node).filter(Node.status != NodeStatus.INACTIVE)
-        self.nodes: list[Node] = sorted(query.all(), key=attrgetter("name"))
+        nodes = dbsession.scalars(
+            sql.select(Node).where(Node.status != NodeStatus.INACTIVE)
+        ).all()
+        self.nodes: list[Node] = sorted(nodes, key=attrgetter("name"))
         self.request = request

     @view_config(match_param="view=table", renderer="pages/nodes.jinja2")
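Taken together, the patch swaps the legacy Query API for select() constructs and then fetches results in one of three shapes: scalars(...).all() for lists of ORM objects, scalars(...).first() for a single optional row, and execute(...).scalar_one_or_none() where at most one match is expected. The sketch below collects those shapes against the Node model from this codebase; the helper name, the import path, and the "gateway" lookup value are assumptions made for illustration.

from sqlalchemy import sql
from sqlalchemy.orm import Session

from ..models import Node, NodeStatus  # assumed import path


def fetch_variants(dbsession: Session):
    # .scalars(...).all(): list of ORM objects (replaces Query.all()).
    active = dbsession.scalars(
        sql.select(Node).where(Node.status != NodeStatus.INACTIVE)
    ).all()
    # .scalars(...).first(): first row or None (replaces .limit(1) + .scalar()).
    newest = dbsession.scalars(
        sql.select(Node).order_by(sql.desc(Node.last_seen)).limit(1)
    ).first()
    # .execute(...).scalar_one_or_none(): one row or None, raising if the
    # filter matches more than one row (replaces Query.one_or_none()).
    by_name = dbsession.execute(
        sql.select(Node).where(Node.display_name == "gateway")
    ).scalar_one_or_none()
    return active, newest, by_name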