From a73b320ea95a5b3b4513fdd6c9c6100ed73508d7 Mon Sep 17 00:00:00 2001 From: Scott Searcy Date: Thu, 25 Jul 2024 12:29:50 -0700 Subject: [PATCH] WIP: optimizing DB writes Started work on optimized DB writes --- meshinfo/collector.py | 356 ++++++++++++++++++++++++------------------ 1 file changed, 201 insertions(+), 155 deletions(-) diff --git a/meshinfo/collector.py b/meshinfo/collector.py index 4a513c4..901fd97 100644 --- a/meshinfo/collector.py +++ b/meshinfo/collector.py @@ -9,9 +9,12 @@ from collections import defaultdict from collections.abc import Iterable from operator import attrgetter +from typing import Sequence +import attrs import pendulum import structlog +from sqlalchemy import sql, Row from sqlalchemy.orm import Session from structlog.contextvars import bound_contextvars @@ -25,37 +28,6 @@ logger = structlog.get_logger() -# TODO: align names so that this can just be a list -MODEL_TO_SYSINFO_ATTRS = { - "name": "node_name", - "display_name": "display_name", - "ip_address": "ip_address", - "description": "description", - "mac_address": "mac_address", - "up_time": "up_time", - "up_time_seconds": "up_time_seconds", - "load_averages": "load_averages", - "model": "model", - "board_id": "board_id", - "firmware_version": "firmware_version", - "firmware_manufacturer": "firmware_manufacturer", - "api_version": "api_version", - "latitude": "latitude", - "longitude": "longitude", - "grid_square": "grid_square", - "ssid": "ssid", - "channel": "channel", - "channel_bandwidth": "channel_bandwidth", - "band": "band", - "services": "services_json", - "active_tunnel_count": "active_tunnel_count", - "system_info": "source_json", - "link_count": "link_count", - "radio_link_count": "radio_link_count", - "dtd_link_count": "dtd_link_count", - "tunnel_link_count": "tunnel_link_count", -} - def main( local_node: str, @@ -164,7 +136,8 @@ async def collector( summary: defaultdict[str, int] = defaultdict(int) with models.session_scope(session_factory) as dbsession: - 
node_models = save_nodes(nodes, dbsession, count=summary) + node_updater = NodeUpdater(dbsession) + summary |= await node_updater.save_nodes(nodes) link_models = save_links(links, dbsession, count=summary) # expire data after the data has been refreshed # (otherwise the first run after a long gap will mark current stuff expired) @@ -224,149 +197,173 @@ async def collector( return +@attrs.define +class NodeUpdater: + dbsession: Session + count: defaultdict[str, int] = attrs.field(init=False) + timestamp: pendulum.DateTime = attrs.field(init=False) + _insert_nodes: list[dict] = attrs.field(factory=list, init=False) + _update_nodes: list[dict] = attrs.field(factory=list, init=False) -def expire_data( - dbsession: Session, - *, - nodes_expire: int, - links_expire: int, - count: defaultdict[str, int] | None = None, -): - """Update the status of nodes/links that have not been seen recently. + def __attrs_post_init__(self): + self.count = defaultdict(int) + self.timestamp = pendulum.now() - Args: - dbsession: SQLAlchemy database session - nodes_expire: Number of days a node is not seen before marked inactive - links_expire: Number of days a link is not seen before marked inactive - count: Default dictionary for tracking statistics + async def save_nodes(self, nodes: Iterable[SystemInfo]) -> dict[str, int]: + """Save nodes to the database. 
- """ + Args: + nodes: Iterable of SystemInfo objects - timestamp = pendulum.now() + Returns: + Dictionary of counts for nodes added, updated, and total - if count is None: - count = defaultdict(int) - inactive_cutoff = timestamp.subtract(days=links_expire) - count["expired: links"] = ( - dbsession.query(Link) - .filter( - Link.status == LinkStatus.RECENT, - Link.last_seen < inactive_cutoff, - ) - .update({Link.status: LinkStatus.INACTIVE}) - ) - logger.info( - "Marked inactive links", - count=count["expired: links"], - cutoff=inactive_cutoff, - ) + """ + self._insert_nodes = [] + self._update_nodes = [] - inactive_cutoff = timestamp.subtract(days=nodes_expire) - count["expired: nodes"] = ( - dbsession.query(Node) - .filter( - Node.status == NodeStatus.ACTIVE, - Node.last_seen < inactive_cutoff, - ) - .update({Node.status: NodeStatus.INACTIVE}) - ) - logger.info( - "Marked inactive nodes", - count=count["expired: nodes"], - cutoff=inactive_cutoff, - ) - return + for node in nodes: + await self._process(node) + + if self._insert_nodes: + self.dbsession.execute( + sql.insert(Node), + self._insert_nodes, + ) + if self._update_nodes: + self.dbsession.execute( + sql.update(Node), + self._update_nodes, + ) + return dict(self.count) + async def _process(self, node: SystemInfo) -> None: + """Identify data to insert/update into database. -def save_nodes( - nodes: Iterable[SystemInfo], - dbsession: Session, - *, - count: defaultdict[str, int] | None = None, -) -> list[Node]: - """Saves node information to the database. + Looks for existing nodes by WLAN MAC address and name. - Looks for existing nodes by WLAN MAC address and name. 
+ """ - """ - if count is None: - count = defaultdict(int) - timestamp = pendulum.now() - node_models = [] - for node in nodes: - count["nodes: total"] += 1 + self.count["nodes: total"] += 1 # check to see if node exists in database by name and WLAN MAC address + data = { + "name": node.node_name, + "display_name": node.display_name, + "ip_address": node.ip_address, + "description": node.description, + "mac_address": node.mac_address, + "up_time": node.up_time, + "up_time_seconds": node.up_time_seconds, + "load_averages": node.load_averages, + "model": node.model, + "board_id": node.board_id, + "firmware_version": node.firmware_version, + "firmware_manufacturer": node.firmware_manufacturer, + "api_version": node.api_version, + "latitude": node.latitude, + "longitude": node.longitude, + "grid_square": node.grid_square, + "ssid": node.ssid, + "channel": node.channel, + "channel_bandwidth": node.channel_bandwidth, + "band": node.band, + "services": node.services_json, + "active_tunnel_count": node.active_tunnel_count, + "system_info": node.source_json, + "link_count": node.link_count, + "radio_link_count": node.radio_link_count, + "dtd_link_count": node.dtd_link_count, + "tunnel_link_count": node.tunnel_link_count, + "last_seen": self.timestamp, + "status": NodeStatus.ACTIVE, + } + with bound_contextvars(node=node.node_name): - model = get_db_model(dbsession, node) + model_id = await self._get_node_id(node) - if model is None: + if model_id is None: # create new database model - count["nodes: added"] += 1 + self.count["nodes: added"] += 1 logger.debug("Added node to database") - model = Node() - dbsession.add(model) + self._insert_nodes.append(data) else: # update database model - count["nodes: updated"] += 1 - logger.debug("Updated node in database", model=model) - node_models.append(model) - - model.last_seen = timestamp - model.status = NodeStatus.ACTIVE - - for model_attr, node_attr in MODEL_TO_SYSINFO_ATTRS.items(): - setattr(model, model_attr, getattr(node, 
node_attr)) + self.count["nodes: updated"] += 1 + logger.debug("Updated node in database", model=model_id) + data["id"] = model_id + self._update_nodes.append(data) + + + async def _get_node_id(self, node: SystemInfo) -> int | None: + """Get the best match database record for this node.""" + # Find the most recently seen node that matches both name and MAC address + results = self.dbsession.execute( + sql.select(Node.id, Node.status, Node.last_seen).where( + Node.mac_address == node.mac_address, + Node.name == node.node_name, + ) + ).all() + if model_id := self._get_most_recent(results): + return model_id + + # Find active node with same hardware + results = self.dbsession.execute( + sql.select(Node.id, Node.status, Node.last_seen).where( + Node.mac_address == node.mac_address, + Node.status == NodeStatus.ACTIVE, + Node.mac_address != "", + ) + ).all() + if model_id := self._get_most_recent(results): + return model_id + + # Find active node with same name + results = self.dbsession.execute( + sql.select(Node.id, Node.status, Node.last_seen).where( + Node.name == node.node_name, + Node.status == NodeStatus.ACTIVE, + ) + ).all() + if model_id := self._get_most_recent(results): + return model_id - logger.info("Nodes saved to database", summary=dict(count)) - return node_models + # Nothing found, treat as a new node + return None + def _get_most_recent(self, results: Sequence[Row]) -> int | None: + """Get the most recently seen node, marking the others inactive.""" + if len(results) == 0: + return None -def get_db_model(dbsession: Session, node: SystemInfo) -> Node | None: - """Get the best match database record for this node.""" - # Find the most recently seen node that matches both name and MAC address - query = dbsession.query(Node).filter( - Node.mac_address == node.mac_address, - Node.name == node.node_name, - ) - model = _get_most_recent(query.all()) - if model: - return model - - # Find active node with same hardware - query = dbsession.query(Node).filter( - 
Node.mac_address == node.mac_address, - Node.status == NodeStatus.ACTIVE, - Node.mac_address != "", - ) - model = _get_most_recent(query.all()) - if model: - return model + results = sorted(results, key=attrgetter("last_seen"), reverse=True) + for row in results[1:]: + if row.status == NodeStatus.ACTIVE: + logger.debug("Marking older match inactive", model=row) + self.count["nodes: inactive"] += 1 + self._update_nodes.append({"id": row.id, "status": NodeStatus.INACTIVE}) - # Find active node with same name - query = dbsession.query(Node).filter( - Node.name == node.node_name, Node.status == NodeStatus.ACTIVE - ) - model = _get_most_recent(query.all()) - if model: - return model + return results[0].id - # Nothing found, treat as a new node - return None -def _get_most_recent(results: list[Node]) -> Node | None: - """Get the most recently seen node, marking the others inactive.""" - if len(results) == 0: - return None +@attrs.define +class LinkUpdater: + dbsession: Session + count: defaultdict[str, int] = attrs.field(init=False) + timestamp: pendulum.DateTime = attrs.field(init=False) + _insert_links: list[dict] = attrs.field(factory=list, init=False) + _update_links: list[dict] = attrs.field(factory=list, init=False) - results = sorted(results, key=attrgetter("last_seen"), reverse=True) - for model in results[1:]: - if model.status == NodeStatus.ACTIVE: - logger.debug("Marking older match inactive", model=model) - model.status = NodeStatus.INACTIVE + def __attrs_post_init__(self): + self.count = defaultdict(int) + self.timestamp = pendulum.now() - return results[0] + def save(self, links: Iterable[LinkInfo]) -> dict[str, int]: + # Downgrade all "current" links to "recent" so that only ones updated are "current" + self.dbsession.execute( + sql.update(Link).where(Link.status == LinkStatus.CURRENT).values(status=LinkStatus.RECENT) + ) def save_links( @@ -381,13 +378,7 @@ def save_links( rather than using SQL triggers (i.e. how MeshMap does it). 
""" - if count is None: - count = defaultdict(int) - # Downgrade all "current" links to "recent" so that only ones updated are "current" - dbsession.query(Link).filter(Link.status == LinkStatus.CURRENT).update( - {Link.status: LinkStatus.RECENT} - ) active_nodes: dict[str, Node] = { node.name: node @@ -470,6 +461,61 @@ def save_links( return link_models +def expire_data( + dbsession: Session, + *, + nodes_expire: int, + links_expire: int, + count: defaultdict[str, int], +): + """Update the status of nodes/links that have not been seen recently. + + Args: + dbsession: SQLAlchemy database session + nodes_expire: Number of days a node is not seen before marked inactive + links_expire: Number of days a link is not seen before marked inactive + count: Default dictionary for tracking statistics + + """ + + timestamp = pendulum.now() + + inactive_cutoff = timestamp.subtract(days=links_expire) + stmt = ( + sql.update(Link) + .where( + Link.status == LinkStatus.RECENT, + Link.last_seen < inactive_cutoff, + ) + .values(status=LinkStatus.INACTIVE) + ) + link_count = dbsession.execute(stmt).rowcount + logger.info( + "Marked inactive links", + count=link_count, + cutoff=inactive_cutoff, + ) + + inactive_cutoff = timestamp.subtract(days=nodes_expire) + stmt = ( + sql.update(Node) + .where( + Node.status == NodeStatus.ACTIVE, + Node.last_seen < inactive_cutoff, + ) + .values(status=NodeStatus.INACTIVE) + ) + node_count = dbsession.execute(stmt).rowcount + logger.info( + "Marked inactive nodes", + count=node_count, + cutoff=inactive_cutoff, + ) + + count["expired: links"] = link_count + count["expired: nodes"] = node_count + return + def distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float: """Distance between two points in kilometers via haversine.""" # convert from degrees to radians @@ -513,7 +559,7 @@ def bearing(lat1: float, lon1: float, lat2: float, lon2: float) -> float: async def save_historical_data( - nodes: list[Node], links: list[Link], stats: 
HistoricalStats + nodes: list[SystemInfo], links: list[LinkInfo], stats: HistoricalStats ): """Save current node and link data to our time series storage.