From 99d143b161c2e1eafbfd346b21afb2039b989b1e Mon Sep 17 00:00:00 2001 From: Scott Searcy Date: Wed, 8 May 2024 14:41:29 -0700 Subject: [PATCH] Convert models to new ORM style Some queries have been updated, but not all. --- meshinfo/collector.py | 6 +- meshinfo/historical.py | 4 +- meshinfo/models/collector.py | 58 +++++++++--------- meshinfo/models/link.py | 64 +++++++++++--------- meshinfo/models/meta.py | 14 +++-- meshinfo/models/node.py | 111 ++++++++++++++++------------------- meshinfo/views/home.py | 44 +++++++------- meshinfo/views/map.py | 12 ++-- meshinfo/views/node.py | 44 +++++++------- pdm.lock | 4 +- pyproject.toml | 1 - 11 files changed, 182 insertions(+), 180 deletions(-) diff --git a/meshinfo/collector.py b/meshinfo/collector.py index 240c548..4a513c4 100644 --- a/meshinfo/collector.py +++ b/meshinfo/collector.py @@ -291,6 +291,7 @@ def save_nodes( """ if count is None: count = defaultdict(int) + timestamp = pendulum.now() node_models = [] for node in nodes: count["nodes: total"] += 1 @@ -311,7 +312,7 @@ def save_nodes( logger.debug("Updated node in database", model=model) node_models.append(model) - model.last_seen = pendulum.now() + model.last_seen = timestamp model.status = NodeStatus.ACTIVE for model_attr, node_attr in MODEL_TO_SYSINFO_ATTRS.items(): @@ -393,6 +394,7 @@ def save_links( for node in dbsession.query(Node).filter(Node.status == NodeStatus.ACTIVE) } + timestamp = pendulum.now() link_models = [] for link in links: count["links: total"] += 1 @@ -425,7 +427,7 @@ def save_links( link_models.append(model) model.status = LinkStatus.CURRENT - model.last_seen = pendulum.now() + model.last_seen = timestamp for attribute in [ "type", diff --git a/meshinfo/historical.py b/meshinfo/historical.py index 59a5205..b63fbb2 100644 --- a/meshinfo/historical.py +++ b/meshinfo/historical.py @@ -112,9 +112,7 @@ def update_node_stats(self, node: Node) -> bool: node.link_count, len(node.services), node.up_time_seconds, - node.load_averages[0] - if isinstance(node.load_averages, list) - else None, + node.load_averages[0] if node.load_averages is not None else None, node.radio_link_count, node.dtd_link_count, node.tunnel_link_count, diff --git a/meshinfo/models/collector.py b/meshinfo/models/collector.py index 878a58b..8e8efc3 100644 --- a/meshinfo/models/collector.py +++ b/meshinfo/models/collector.py @@ -1,9 +1,25 @@ +from __future__ import annotations + import pendulum -import sqlalchemy as sa -from sqlalchemy.orm import relationship +from sqlalchemy import JSON, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship from ..poller import PollingError -from .meta import Base, PDateTime +from .meta import Base + + +class NodeError(Base): + """Information about nodes with errors during collection.""" + + __tablename__ = "node_error" + + timestamp: Mapped[pendulum.DateTime] = mapped_column( + ForeignKey("collector_stat.started_at"), primary_key=True + ) + ip_address: Mapped[str] = mapped_column(String(15), primary_key=True) + dns_name: Mapped[str] = mapped_column(String(70)) + error_type: Mapped[PollingError] + details: Mapped[str] class CollectorStat(Base): @@ -11,32 +27,18 @@ class CollectorStat(Base): __tablename__ = "collector_stat" - started_at = sa.Column(PDateTime(), primary_key=True) - finished_at = sa.Column(PDateTime(), default=pendulum.now, nullable=False) - node_count = sa.Column(sa.Integer, nullable=False) - link_count = sa.Column(sa.Integer, nullable=False) - error_count = sa.Column(sa.Integer, nullable=False) - polling_duration = sa.Column(sa.Float, nullable=False) - total_duration = sa.Column(sa.Float, nullable=False) - other_stats = sa.Column(sa.JSON, nullable=False) - - node_errors = relationship( - "NodeError", foreign_keys="NodeError.timestamp", cascade="all, delete-orphan" + started_at: Mapped[pendulum.DateTime] = mapped_column(primary_key=True) + finished_at: Mapped[pendulum.DateTime] = mapped_column(default=pendulum.now) + node_count: Mapped[int] + link_count: Mapped[int] + error_count: Mapped[int] + polling_duration: Mapped[float] + total_duration: Mapped[float] + other_stats: Mapped[dict] = mapped_column(JSON) + + node_errors: Mapped[list[NodeError]] = relationship( + foreign_keys="NodeError.timestamp", cascade="all, delete-orphan" ) def __repr__(self): return f"" - - -class NodeError(Base): - """Information about nodes with errors during collection.""" - - __tablename__ = "node_error" - - timestamp = sa.Column( - PDateTime(), sa.ForeignKey("collector_stat.started_at"), primary_key=True - ) - ip_address = sa.Column(sa.String(15), primary_key=True) - dns_name = sa.Column(sa.String(70), nullable=False) - error_type = sa.Column(sa.Enum(PollingError, native_enum=False), nullable=False) - details = sa.Column(sa.UnicodeText, nullable=False) diff --git a/meshinfo/models/link.py b/meshinfo/models/link.py index ff80f8a..02899b3 100644 --- a/meshinfo/models/link.py +++ b/meshinfo/models/link.py @@ -1,9 +1,16 @@ +"""Database model(s) for representing links between nodes.""" + +from typing import TYPE_CHECKING, Optional + import pendulum -import sqlalchemy as sa -from sqlalchemy.orm import relationship +from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship from ..types import LinkId, LinkStatus, LinkType -from .meta import Base, PDateTime +from .meta import Base + +if TYPE_CHECKING: + from .node import Node class Link(Base): @@ -11,35 +18,34 @@ class Link(Base): __tablename__ = "link" - source_id = sa.Column(sa.Integer, sa.ForeignKey("node.node_id"), primary_key=True) - destination_id = sa.Column( - sa.Integer, sa.ForeignKey("node.node_id"), primary_key=True + source_id: Mapped[int] = mapped_column(ForeignKey("node.node_id"), primary_key=True) + destination_id: Mapped[int] = mapped_column( + ForeignKey("node.node_id"), primary_key=True ) - type = sa.Column(sa.Enum(LinkType, native_enum=False), primary_key=True) - status = sa.Column(sa.Enum(LinkStatus, native_enum=False), nullable=False) - last_seen = sa.Column(PDateTime(), nullable=False, default=pendulum.now) - - olsr_cost = sa.Column(sa.Float) - distance = sa.Column(sa.Float) - bearing = sa.Column(sa.Float) - - signal = sa.Column(sa.Float) - noise = sa.Column(sa.Float) - tx_rate = sa.Column(sa.Float) - rx_rate = sa.Column(sa.Float) - quality = sa.Column(sa.Float) - neighbor_quality = sa.Column(sa.Float) - - created_at = sa.Column(PDateTime(), default=pendulum.now, nullable=False) - last_updated_at = sa.Column( - PDateTime(), - default=pendulum.now, - onupdate=pendulum.now, - nullable=False, + type: Mapped[LinkType] = mapped_column(primary_key=True) + status: Mapped[LinkStatus] + last_seen: Mapped[pendulum.DateTime] = mapped_column(default=pendulum.now) + + olsr_cost: Mapped[float] + distance: Mapped[Optional[float]] + bearing: Mapped[Optional[float]] + + signal: Mapped[Optional[float]] + noise: Mapped[Optional[float]] + tx_rate: Mapped[Optional[float]] + rx_rate: Mapped[Optional[float]] + quality: Mapped[Optional[float]] + neighbor_quality: Mapped[Optional[float]] + + created_at: Mapped[pendulum.DateTime] = mapped_column(default=pendulum.now) + last_updated_at: Mapped[pendulum.DateTime] = mapped_column( + default=pendulum.now, onupdate=pendulum.now ) - source = relationship("Node", foreign_keys="Link.source_id", back_populates="links") - destination = relationship("Node", foreign_keys="Link.destination_id") + source: Mapped["Node"] = relationship( + foreign_keys="Link.source_id", back_populates="links" + ) + destination: Mapped["Node"] = relationship(foreign_keys="Link.destination_id") @property def signal_noise_ratio(self): diff --git a/meshinfo/models/meta.py b/meshinfo/models/meta.py index da1d4bf..b447b2e 100644 --- a/meshinfo/models/meta.py +++ b/meshinfo/models/meta.py @@ -1,7 +1,7 @@ from __future__ import annotations import pendulum -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.schema import MetaData from sqlalchemy.types import TIMESTAMP, TypeDecorator @@ -17,14 +17,13 @@ } metadata = MetaData(naming_convention=NAMING_CONVENTION) -Base = declarative_base(metadata=metadata) class PDateTime(TypeDecorator): """SQLAlchemy type to wrap `pendulum.datetime` instead of `datetime.datetime`.""" impl = TIMESTAMP(timezone=True) - cache_ok = False + cache_ok = True def process_bind_param(self, value, dialect): if value is not None: @@ -37,5 +36,12 @@ def process_bind_param(self, value, dialect): def process_result_value(self, value, dialect): if value is not None: value = pendulum.instance(value) - return value + + +class Base(DeclarativeBase): + metadata = metadata + + type_annotation_map = { + pendulum.DateTime: PDateTime, + } diff --git a/meshinfo/models/node.py b/meshinfo/models/node.py index cae566c..4f0246b 100644 --- a/meshinfo/models/node.py +++ b/meshinfo/models/node.py @@ -1,19 +1,14 @@ +from typing import TYPE_CHECKING, Optional + import pendulum -from sqlalchemy import ( - JSON, - Boolean, - Column, - Enum, - Float, - Index, - Integer, - String, - Unicode, -) -from sqlalchemy.orm import relationship +from sqlalchemy import JSON, Enum, Index, String, Unicode +from sqlalchemy.orm import Mapped, mapped_column, relationship from ..types import Band, NodeStatus -from .meta import Base, PDateTime +from .meta import Base + +if TYPE_CHECKING: + from .link import Link class Node(Base): @@ -21,64 +16,62 @@ class Node(Base): __tablename__ = "node" - id = Column("node_id", Integer, primary_key=True) - name = Column(String(70), nullable=False) - status = Column(Enum(NodeStatus, native_enum=False), nullable=False) - display_name = Column(String(70), nullable=False) + id: Mapped[int] = mapped_column("node_id", primary_key=True) + name: Mapped[str] = mapped_column(String(70)) + status: Mapped[NodeStatus] + display_name: Mapped[str] = mapped_column(String(70)) # store the wireless/primary IP address - ip_address = Column("wlan_ip", String(15), nullable=False) - description = Column(Unicode(1024), nullable=False) + ip_address: Mapped[str] = mapped_column("wlan_ip", String(15)) + description: Mapped[str] = mapped_column(Unicode(1024)) # store the MAC address (without colons) corresponding the primary interface - mac_address = Column("wlan_mac_address", String(12), nullable=False) - - last_seen = Column(PDateTime(), nullable=False) - - up_time = Column(String(25), nullable=False) - up_time_seconds = Column(Integer) - load_averages = Column(JSON()) - model = Column(String(50), nullable=False) - board_id = Column(String(50), nullable=False) - firmware_version = Column(String(50), nullable=False) - firmware_manufacturer = Column(String(100), nullable=False) - api_version = Column(String(5), nullable=False) - - latitude = Column(Float) - longitude = Column(Float) - grid_square = Column(String(20), nullable=False) - - ssid = Column(String(50), nullable=False) - channel = Column(String(50), nullable=False) - channel_bandwidth = Column(String(50), nullable=False) - band = Column( + mac_address: Mapped[str] = mapped_column("wlan_mac_address", String(12)) + + last_seen: Mapped[pendulum.DateTime] + + up_time: Mapped[str] = mapped_column(String(25)) + up_time_seconds: Mapped[Optional[int]] + load_averages: Mapped[Optional[list[float]]] = mapped_column(JSON) + model: Mapped[str] = mapped_column(String(50)) + board_id: Mapped[str] = mapped_column(String(50)) + firmware_version: Mapped[str] = mapped_column(String(50)) + firmware_manufacturer: Mapped[str] = mapped_column(String(100)) + api_version: Mapped[str] = mapped_column(String(5)) + + latitude: Mapped[Optional[float]] + longitude: Mapped[Optional[float]] + grid_square: Mapped[str] = mapped_column(String(20)) + + ssid: Mapped[str] = mapped_column(String(50)) + channel: Mapped[str] = mapped_column(String(50)) + channel_bandwidth: Mapped[str] = mapped_column(String(50)) + band: Mapped[Band] = mapped_column( Enum(Band, values_callable=lambda x: [e.value for e in x], native_enum=False), - nullable=False, ) - services = Column(JSON(), nullable=False) + services: Mapped[dict] = mapped_column(JSON) # As of API v1.10 this is irrelevant (because it is always enabled) # (probably worth deleting at some point in the future) - tunnel_installed = Column(Boolean(), nullable=False, default=True) - active_tunnel_count = Column(Integer(), nullable=False) - - link_count = Column(Integer()) - radio_link_count = Column(Integer()) - dtd_link_count = Column(Integer()) - tunnel_link_count = Column(Integer()) - - system_info = Column(JSON(), nullable=False) - - created_at = Column(PDateTime(), default=pendulum.now, nullable=False) - last_updated_at = Column( - PDateTime(), - default=pendulum.now, - onupdate=pendulum.now, - nullable=False, + tunnel_installed: Mapped[bool] = mapped_column(default=True) + active_tunnel_count: Mapped[int] + + link_count: Mapped[Optional[int]] + radio_link_count: Mapped[Optional[int]] + dtd_link_count: Mapped[Optional[int]] + tunnel_link_count: Mapped[Optional[int]] + + system_info: Mapped[dict] = mapped_column(JSON) + + created_at: Mapped[pendulum.DateTime] = mapped_column(default=pendulum.now) + last_updated_at: Mapped[pendulum.DateTime] = mapped_column( + default=pendulum.now, onupdate=pendulum.now ) - links = relationship("Link", foreign_keys="Link.source_id", back_populates="source") + links: Mapped["Link"] = relationship( + foreign_keys="Link.source_id", back_populates="source" + ) # Is this premature optimization? Index("idx_mac_name", mac_address, name) diff --git a/meshinfo/views/home.py b/meshinfo/views/home.py index 1d79d5b..3dd2456 100644 --- a/meshinfo/views/home.py +++ b/meshinfo/views/home.py @@ -3,9 +3,9 @@ import re from collections import defaultdict -import sqlalchemy as sa from pyramid.request import Request from pyramid.view import view_config +from sqlalchemy import sql from sqlalchemy.orm import Session from ..models import CollectorStat, Link, Node, NodeError @@ -24,48 +24,48 @@ def overview(request: Request): ) # Get node counts by firmware version - query = ( - dbsession.query( - Node.firmware_manufacturer, Node.firmware_version, sa.func.count(Node.id) + firmware_results = dbsession.execute( + sql.select( + Node.firmware_manufacturer, Node.firmware_version, sql.func.count(Node.id) ) - .filter(Node.status == NodeStatus.ACTIVE) + .where(Node.status == NodeStatus.ACTIVE) .group_by(Node.firmware_manufacturer, Node.firmware_version) ) firmware_stats: defaultdict[str, int] = defaultdict(int) - for manufacturer, version, count in query.all(): + for manufacturer, version, count in firmware_results: if manufacturer.lower() != "aredn": firmware_stats["Non-AREDN"] += 1 elif re.match(r"\d+\.\d+\.\d+\.\d+", version): firmware_stats[version] = count else: firmware_stats["Nightly"] += count + # Get node counts by API version - query = ( - dbsession.query(Node.api_version, sa.func.count(Node.id)) - .filter(Node.status == NodeStatus.ACTIVE) + api_results = dbsession.execute( + sql.select(Node.api_version, sql.func.count(Node.id)) + .where(Node.status == NodeStatus.ACTIVE) .group_by(Node.api_version) ) - api_version_stats = {version: count for version, count in query.all()} + api_version_stats = {version: count for version, count in api_results} + # Get node counts by band - query = ( - dbsession.query(Node.band, sa.func.count(Node.id)) - .filter(Node.status == NodeStatus.ACTIVE) + band_results = dbsession.execute( + sql.select(Node.band, sql.func.count(Node.id)) + .where(Node.status == NodeStatus.ACTIVE) .group_by(Node.band) ) - band_stats = {band: count for band, count in query.all()} + band_stats = {band: count for band, count in band_results} - last_run = ( - dbsession.query(CollectorStat) - .order_by(sa.desc(CollectorStat.started_at)) - .first() - ) + last_run = dbsession.execute( + sql.select(CollectorStat).order_by(sql.desc(CollectorStat.started_at)).limit(1) + ).scalar() node_errors_by_type: dict[str, list[NodeError]] = {} if last_run: - query = dbsession.query(NodeError).filter( - NodeError.timestamp == last_run.started_at + errors = dbsession.execute( + sql.select(NodeError).where(NodeError.timestamp == last_run.started_at) ) - for error in query.all(): + for error in errors.scalars(): node_errors_by_type.setdefault(str(error.error_type), []).append(error) return { diff --git a/meshinfo/views/map.py b/meshinfo/views/map.py index ba24785..505c711 100644 --- a/meshinfo/views/map.py +++ b/meshinfo/views/map.py @@ -95,8 +95,8 @@ class GeoNode: id: int name: str band: Band - latitude: float - longitude: float + latitude: float | None + longitude: float | None layer: NodeLayer @classmethod @@ -137,10 +137,10 @@ class GeoLink: type: LinkType status: LinkStatus cost: float - start_latitude: float - start_longitude: float - end_latitude: float - end_longitude: float + start_latitude: float | None + start_longitude: float | None + end_latitude: float | None + end_longitude: float | None layer: LinkLayer @property diff --git a/meshinfo/views/node.py b/meshinfo/views/node.py index b77a014..595aca6 100644 --- a/meshinfo/views/node.py +++ b/meshinfo/views/node.py @@ -1,10 +1,12 @@ from operator import attrgetter +from typing import Any from pyramid.httpexceptions import HTTPNotFound from pyramid.request import Request from pyramid.response import Response from pyramid.settings import asbool from pyramid.view import view_config, view_defaults +from sqlalchemy import sql from sqlalchemy.orm import Session, joinedload, load_only from ..aredn import LinkType, VersionChecker @@ -22,7 +24,7 @@ def node_detail(request: Request): dbsession: Session = request.dbsession version_checker: VersionChecker = request.find_service(VersionChecker) - node = dbsession.query(Node).get(node_id) + node = dbsession.get(Node, node_id) if node is None: raise HTTPNotFound("Sorry, the specified node could not be found") @@ -30,15 +32,15 @@ def node_detail(request: Request): firmware_status = version_checker.firmware(node.firmware_version) api_status = version_checker.api(node.api_version) - query = ( - dbsession.query(Link) + stmt = ( + sql.select(Link) .options(joinedload(Link.destination).load_only(Node.display_name)) .filter( Link.source_id == node.id, Link.status != LinkStatus.INACTIVE, ) ) - links = query.all() + links = dbsession.execute(stmt).scalars() graphs_by_link_type = { LinkType.RF: ("cost", "quality", "snr"), @@ -58,15 +60,13 @@ def node_detail(request: Request): @view_config(route_name="node-json", renderer="json") -def node_json(request: Request): +def node_json(request: Request) -> dict: """Dump most recent sysinfo.json for a node.""" node_id = int(request.matchdict["id"]) dbsession: Session = request.dbsession - node: Node = dbsession.query(Node).get(node_id) - - if node is None: + if not (node := dbsession.get(Node, node_id)): raise HTTPNotFound("Sorry, the specified node could not be found") return node.system_info @@ -83,9 +83,7 @@ def node_preview(request: Request): node_id = int(request.matchdict["id"]) dbsession: Session = request.dbsession - node: Node = dbsession.query(Node).get(node_id) - - if node is None: + if not (node := dbsession.get(Node, node_id)): raise HTTPNotFound("Sorry, the specified node could not be found") query = ( @@ -119,18 +117,14 @@ def node_preview(request: Request): @view_config(route_name="node-graphs", renderer="pages/node-graphs.jinja2") -def node_graphs(request: Request): +def node_graphs(request: Request) -> dict[str, Any]: """Display graphs of particular data for a node over different timeframes.""" node_id = int(request.matchdict["id"]) graph = request.matchdict["name"] dbsession: Session = request.dbsession - node = ( - dbsession.query(Node) - .options(load_only(Node.display_name, Node.id)) - .get(node_id) - ) + node = dbsession.get(Node, node_id, options=[load_only(Node.display_name, Node.id)]) return { "node": node, @@ -147,11 +141,13 @@ def __init__(self, request: Request): node_id = int(request.matchdict["id"]) dbsession: Session = request.dbsession - self.node = ( - dbsession.query(Node).options(load_only(Node.id, Node.name)).get(node_id) - ) - if self.node is None: + if not ( + node := dbsession.get( + Node, node_id, options=[load_only(Node.id, Node.name)] + ) + ): raise HTTPNotFound("Sorry, the specified node could not be found") + self.node = node self.graph_params = schema.graph_params(request.GET) self.name_in_title = asbool(request.GET.get("name_in_title", False)) @@ -159,7 +155,7 @@ def __init__(self, request: Request): self.stats: HistoricalStats = request.find_service(HistoricalStats) @view_config(match_param="name=links") - def links(self): + def links(self) -> Response: title_parts = ( self.node.name.lower() if self.name_in_title else "", "links", @@ -173,7 +169,7 @@ def links(self): ) @view_config(match_param="name=load") - def load(self): + def load(self) -> Response: title_parts = ( self.node.name.lower() if self.name_in_title else "", "load", @@ -187,7 +183,7 @@ def load(self): ) @view_config(match_param="name=uptime") - def uptime(self): + def uptime(self) -> Response: title_parts = ( self.node.name.lower() if self.name_in_title else "", "uptime", diff --git a/pdm.lock b/pdm.lock index c042e34..5e45938 100644 --- a/pdm.lock +++ b/pdm.lock @@ -4,8 +4,8 @@ [metadata] groups = ["default", "dev", "testing", "docs", "ruff", "mypy"] strategy = ["cross_platform", "inherit_metadata"] -lock_version = "4.4.1" -content_hash = "sha256:f4c52288407fc9947cf4ca54804c9115e09e5da7f94cf177544eca83f17a232a" +lock_version = "4.4.2" +content_hash = "sha256:324981210e91ae300258b863db247334dcaa39faed114bb9606674c9f3883eb6" [[package]] name = "aiohttp" diff --git a/pyproject.toml b/pyproject.toml index 278616b..3c0a82d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,6 @@ module = [ "hupper.*", "pyramid.*", "rrdtool.*", - "sqlalchemy.*", "transaction.*", "zope.*", ]