Skip to content

Commit

Permalink
add more typing to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertodonato committed Dec 29, 2024
1 parent dccc38c commit 0fcd90f
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 51 deletions.
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from query_exporter.db import DataBase, MetricResults, Query

__all__ = ["advance_time", "query_tracker"]
__all__ = ["QueryTracker", "advance_time", "query_tracker"]


@pytest.fixture(autouse=True)
Expand Down
162 changes: 112 additions & 50 deletions tests/loop_test.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import asyncio
from collections import defaultdict
from collections.abc import Callable, Iterator
from collections.abc import AsyncIterator, Callable, Iterator
from decimal import Decimal
from pathlib import Path
import typing as t
from unittest.mock import ANY

from prometheus_aioexporter import MetricsRegistry
import pytest
from pytest_mock import MockerFixture
from pytest_structlog import StructuredLogCapture
import yaml

from query_exporter import loop
from query_exporter.config import (
DataBaseConfig,
load_config,
)
from query_exporter.db import DataBase
from query_exporter.config import load_config
from query_exporter.db import DataBase, DataBaseConfig

from .conftest import QueryTracker

AdvanceTime = Callable[[float], t.Awaitable[None]]


@pytest.fixture
Expand All @@ -40,10 +42,13 @@ def registry() -> Iterator[MetricsRegistry]:
yield MetricsRegistry()


MakeQueryLoop = Callable[[], loop.QueryLoop]


@pytest.fixture
async def make_query_loop(
tmp_path: Path, config_data: dict[str, t.Any], registry: MetricsRegistry
) -> Iterator[Callable[[], MetricsRegistry]]:
) -> AsyncIterator[MakeQueryLoop]:
query_loops = []

def make_loop() -> loop.QueryLoop:
Expand All @@ -64,12 +69,15 @@ def make_loop() -> loop.QueryLoop:

@pytest.fixture
async def query_loop(
make_query_loop: Callable[[], loop.QueryLoop],
) -> Iterator[loop.QueryLoop]:
make_query_loop: MakeQueryLoop,
) -> AsyncIterator[loop.QueryLoop]:
yield make_query_loop()


def metric_values(metric, by_labels=()):
MetricValues = list[int | float] | dict[tuple[str], list[int | float]]


def metric_values(metric, by_labels: tuple[str] = ()) -> MetricValues:
"""Return values for the metric."""
if metric._type == "gauge":
suffix = ""
Expand Down Expand Up @@ -145,7 +153,9 @@ def test_expire_no_labels(self) -> None:


class TestQueryLoop:
async def test_start(self, query_tracker, query_loop) -> None:
async def test_start(
self, query_tracker: QueryTracker, query_loop
) -> None:
await query_loop.start()
timed_call = query_loop._timed_calls["q"]
assert timed_call.running
Expand All @@ -158,7 +168,7 @@ async def test_stop(self, query_loop) -> None:
assert not timed_call.running

async def test_run_query(
self, query_tracker, query_loop, registry
self, query_tracker: QueryTracker, query_loop: loop.QueryLoop, registry
) -> None:
await query_loop.start()
await query_tracker.wait_results()
Expand All @@ -173,24 +183,24 @@ async def test_run_query(

async def test_run_scheduled_query(
self,
mocker,
advance_time,
query_tracker,
registry,
config_data,
make_query_loop,
mocker: MockerFixture,
advance_time: AdvanceTime,
query_tracker: QueryTracker,
registry: MetricsRegistry,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
event_loop = asyncio.get_running_loop()

def croniter(*args):
def croniter(*args: t.Any) -> float:
while True:
# sync croniter time with the loop one
yield event_loop.time() + 60

mock_croniter = mocker.patch.object(loop, "croniter")
mock_croniter.side_effect = croniter
# ensure that both clocks advance in sync
mocker.patch.object(loop.time, "time", lambda: event_loop.time())
mocker.patch.object(loop.time, "time", lambda: event_loop.time()) # type: ignore

del config_data["queries"]["q"]["interval"]
config_data["queries"]["q"]["schedule"] = "*/2 * * * *"
Expand All @@ -199,7 +209,11 @@ def croniter(*args):
mock_croniter.assert_called_once()

async def test_run_query_with_parameters(
self, query_tracker, registry, config_data, make_query_loop
self,
query_tracker: QueryTracker,
registry: MetricsRegistry,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
config_data["metrics"]["m"]["type"] = "counter"
config_data["metrics"]["m"]["labels"] = ["l"]
Expand All @@ -224,7 +238,11 @@ async def test_run_query_with_parameters(
}

async def test_run_query_null_value(
self, query_tracker, registry, config_data, make_query_loop
self,
query_tracker: QueryTracker,
registry: MetricsRegistry,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
config_data["queries"]["q"]["sql"] = "SELECT NULL AS m"
query_loop = make_query_loop()
Expand All @@ -242,10 +260,10 @@ async def test_run_query_null_value(
)
async def test_run_query_counter(
self,
query_tracker,
registry,
config_data,
make_query_loop,
query_tracker: QueryTracker,
registry: MetricsRegistry,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
increment: bool,
value: float,
) -> None:
Expand All @@ -264,7 +282,11 @@ async def test_run_query_counter(
assert metric_values(metric) == [value]

async def test_run_query_metrics_with_database_labels(
self, query_tracker, registry, config_data, make_query_loop
self,
query_tracker: QueryTracker,
registry: MetricsRegistry,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
config_data["databases"] = {
"db1": {"dsn": "sqlite://", "labels": {"l1": "v1", "l2": "v2"}},
Expand All @@ -281,7 +303,7 @@ async def test_run_query_metrics_with_database_labels(
}

async def test_update_metric_decimal_value(
self, registry, make_query_loop
self, registry: MetricsRegistry, make_query_loop
) -> None:
db = DataBase(DataBaseConfig(name="db", dsn="sqlite://"))
query_loop = make_query_loop()
Expand All @@ -294,8 +316,8 @@ async def test_update_metric_decimal_value(
async def test_run_query_log(
self,
log: StructuredLogCapture,
query_tracker,
query_loop,
query_tracker: QueryTracker,
query_loop: loop.QueryLoop,
) -> None:
await query_loop.start()
await query_tracker.wait_queries()
Expand Down Expand Up @@ -326,9 +348,9 @@ async def test_run_query_log(
async def test_run_query_log_labels(
self,
log: StructuredLogCapture,
query_tracker,
config_data,
make_query_loop,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
config_data["metrics"]["m"]["labels"] = ["l"]
config_data["queries"]["q"]["sql"] = 'SELECT 100.0 AS m, "foo" AS l'
Expand All @@ -345,7 +367,11 @@ async def test_run_query_log_labels(
)

async def test_run_query_increase_db_error_count(
self, query_tracker, config_data, make_query_loop, registry
self,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry,
) -> None:
config_data["databases"]["db"]["dsn"] = "sqlite:////invalid"
query_loop = make_query_loop()
Expand All @@ -355,7 +381,12 @@ async def test_run_query_increase_db_error_count(
assert metric_values(queries_metric) == [1.0]

async def test_run_query_increase_database_error_count(
self, mocker, query_tracker, config_data, make_query_loop, registry
self,
mocker,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry,
) -> None:
query_loop = make_query_loop()
db = query_loop._databases["db"]
Expand All @@ -367,7 +398,11 @@ async def test_run_query_increase_database_error_count(
assert metric_values(queries_metric) == [1.0]

async def test_run_query_increase_query_error_count(
self, query_tracker, config_data, make_query_loop, registry
self,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry,
) -> None:
config_data["queries"]["q"]["sql"] = "SELECT 100.0 AS a, 200.0 AS b"
query_loop = make_query_loop()
Expand All @@ -379,7 +414,11 @@ async def test_run_query_increase_query_error_count(
}

async def test_run_query_increase_timeout_count(
self, query_tracker, config_data, make_query_loop, registry
self,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry,
) -> None:
config_data["queries"]["q"]["timeout"] = 0.1
query_loop = make_query_loop()
Expand All @@ -399,7 +438,10 @@ async def execute(sql, parameters):
}

async def test_run_query_at_interval(
self, advance_time, query_tracker, query_loop
self,
advance_time: AdvanceTime,
query_tracker: QueryTracker,
query_loop: loop.QueryLoop,
) -> None:
await query_loop.start()
await advance_time(0) # kick the first run
Expand All @@ -413,7 +455,10 @@ async def test_run_query_at_interval(
assert len(query_tracker.queries) == 2

async def test_run_timed_queries_invalid_result_count(
self, query_tracker, config_data, make_query_loop
self,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
config_data["queries"]["q"]["sql"] = "SELECT 100.0 AS a, 200.0 AS b"
config_data["queries"]["q"]["interval"] = 1.0
Expand All @@ -432,7 +477,10 @@ async def test_run_timed_queries_invalid_result_count(
assert len(query_tracker.results) == 0

async def test_run_timed_queries_invalid_result_count_stop_task(
self, query_tracker, config_data, make_query_loop
self,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
config_data["queries"]["q"]["sql"] = "SELECT 100.0 AS a, 200.0 AS b"
config_data["queries"]["q"]["interval"] = 1.0
Expand All @@ -446,7 +494,11 @@ async def test_run_timed_queries_invalid_result_count_stop_task(
assert query_loop._timed_calls == {}

async def test_run_timed_queries_not_removed_if_not_failing_on_all_dbs(
self, tmp_path, query_tracker, config_data, make_query_loop
self,
tmp_path: Path,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
db1 = tmp_path / "db1.sqlite"
db2 = tmp_path / "db2.sqlite"
Expand Down Expand Up @@ -485,7 +537,10 @@ async def test_run_timed_queries_not_removed_if_not_failing_on_all_dbs(
assert len(query_tracker.failures) == 1

async def test_run_aperiodic_queries(
self, query_tracker, config_data, make_query_loop
self,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
del config_data["queries"]["q"]["interval"]
query_loop = make_query_loop()
Expand All @@ -495,7 +550,10 @@ async def test_run_aperiodic_queries(
assert len(query_tracker.queries) == 2

async def test_run_aperiodic_queries_invalid_result_count(
self, query_tracker, config_data, make_query_loop
self,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
config_data["queries"]["q"]["sql"] = "SELECT 100.0 AS a, 200.0 AS b"
del config_data["queries"]["q"]["interval"]
Expand All @@ -508,7 +566,11 @@ async def test_run_aperiodic_queries_invalid_result_count(
assert len(query_tracker.queries) == 1

async def test_run_aperiodic_queries_not_removed_if_not_failing_on_all_dbs(
self, tmp_path, query_tracker, config_data, make_query_loop
self,
tmp_path: Path,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
) -> None:
db1 = tmp_path / "db1.sqlite"
db2 = tmp_path / "db2.sqlite"
Expand Down Expand Up @@ -547,12 +609,12 @@ async def test_run_aperiodic_queries_not_removed_if_not_failing_on_all_dbs(

async def test_clear_expired_series(
self,
tmp_path,
advance_time,
query_tracker,
config_data,
make_query_loop,
registry,
tmp_path: Path,
advance_time: AdvanceTime,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry: MetricsRegistry,
) -> None:
db = tmp_path / "db.sqlite"
config_data["databases"]["db"]["dsn"] = f"sqlite:///{db}"
Expand Down

0 comments on commit 0fcd90f

Please sign in to comment.