diff --git a/meshinfo/models/collector.py b/meshinfo/models/collector.py index 878a58b..1042857 100644 --- a/meshinfo/models/collector.py +++ b/meshinfo/models/collector.py @@ -1,6 +1,6 @@ import pendulum -import sqlalchemy as sa -from sqlalchemy.orm import relationship +from sqlalchemy import JSON, ForeignKey, String, UnicodeText +from sqlalchemy.orm import Mapped, mapped_column, relationship from ..poller import PollingError from .meta import Base, PDateTime @@ -11,17 +11,17 @@ class CollectorStat(Base): __tablename__ = "collector_stat" - started_at = sa.Column(PDateTime(), primary_key=True) - finished_at = sa.Column(PDateTime(), default=pendulum.now, nullable=False) - node_count = sa.Column(sa.Integer, nullable=False) - link_count = sa.Column(sa.Integer, nullable=False) - error_count = sa.Column(sa.Integer, nullable=False) - polling_duration = sa.Column(sa.Float, nullable=False) - total_duration = sa.Column(sa.Float, nullable=False) - other_stats = sa.Column(sa.JSON, nullable=False) - - node_errors = relationship( - "NodeError", foreign_keys="NodeError.timestamp", cascade="all, delete-orphan" + started_at: Mapped[PDateTime] = mapped_column(primary_key=True) + finished_at: Mapped[PDateTime] = mapped_column(default=pendulum.now) + node_count: Mapped[int] + link_count: Mapped[int] + error_count: Mapped[int] + polling_duration: Mapped[float] + total_duration: Mapped[float] + other_stats: Mapped[JSON] + + node_errors: Mapped[list["NodeError"]] = relationship( + foreign_keys="NodeError.timestamp", cascade="all, delete-orphan" ) def __repr__(self): @@ -33,10 +33,10 @@ class NodeError(Base): __tablename__ = "node_error" - timestamp = sa.Column( - PDateTime(), sa.ForeignKey("collector_stat.started_at"), primary_key=True + timestamp: Mapped[PDateTime] = mapped_column( + ForeignKey("collector_stat.started_at"), primary_key=True ) - ip_address = sa.Column(sa.String(15), primary_key=True) - dns_name = sa.Column(sa.String(70), nullable=False) - error_type = sa.Column(sa.Enum(PollingError, native_enum=False), nullable=False) - details = sa.Column(sa.UnicodeText, nullable=False) + ip_address: Mapped[str] = mapped_column(String(15), primary_key=True) + dns_name: Mapped[str] = mapped_column(String(70)) + error_type: Mapped[PollingError] + details: Mapped[UnicodeText] diff --git a/meshinfo/models/meta.py b/meshinfo/models/meta.py index da1d4bf..3fdd59f 100644 --- a/meshinfo/models/meta.py +++ b/meshinfo/models/meta.py @@ -1,7 +1,7 @@ from __future__ import annotations import pendulum -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.schema import MetaData from sqlalchemy.types import TIMESTAMP, TypeDecorator @@ -17,7 +17,10 @@ } metadata = MetaData(naming_convention=NAMING_CONVENTION) -Base = declarative_base(metadata=metadata) + + +class Base(DeclarativeBase): + metadata = metadata class PDateTime(TypeDecorator): @@ -37,5 +40,4 @@ def process_bind_param(self, value, dialect): def process_result_value(self, value, dialect): if value is not None: value = pendulum.instance(value) - return value diff --git a/meshinfo/views/node.py b/meshinfo/views/node.py index b77a014..d7a4d47 100644 --- a/meshinfo/views/node.py +++ b/meshinfo/views/node.py @@ -5,6 +5,7 @@ from pyramid.response import Response from pyramid.settings import asbool from pyramid.view import view_config, view_defaults +from sqlalchemy import sql from sqlalchemy.orm import Session, joinedload, load_only from ..aredn import LinkType, VersionChecker @@ -22,7 +23,7 @@ def node_detail(request: Request): dbsession: Session = request.dbsession version_checker: VersionChecker = request.find_service(VersionChecker) - node = dbsession.query(Node).get(node_id) + node = dbsession.get(Node, node_id) if node is None: raise HTTPNotFound("Sorry, the specified node could not be found") @@ -30,15 +31,15 @@ def node_detail(request: Request): firmware_status = version_checker.firmware(node.firmware_version) api_status = version_checker.api(node.api_version) - query = ( - dbsession.query(Link) + stmt = ( + sql.select(Link) .options(joinedload(Link.destination).load_only(Node.display_name)) .filter( Link.source_id == node.id, Link.status != LinkStatus.INACTIVE, ) ) - links = query.all() + links = dbsession.execute(stmt).scalars() graphs_by_link_type = { LinkType.RF: ("cost", "quality", "snr"), @@ -64,7 +65,7 @@ def node_json(request: Request): node_id = int(request.matchdict["id"]) dbsession: Session = request.dbsession - node: Node = dbsession.query(Node).get(node_id) + node: Node = dbsession.get(Node, node_id) if node is None: raise HTTPNotFound("Sorry, the specified node could not be found") @@ -83,7 +84,7 @@ def node_preview(request: Request): node_id = int(request.matchdict["id"]) dbsession: Session = request.dbsession - node: Node = dbsession.query(Node).get(node_id) + node: Node = dbsession.get(Node, node_id) if node is None: raise HTTPNotFound("Sorry, the specified node could not be found")