From 7cd6251f530d85d5d6b6e039ac44846f860c459e Mon Sep 17 00:00:00 2001
From: Jonas Keeling <jonas.keeling@aiven.io>
Date: Thu, 19 Sep 2024 14:31:57 +0200
Subject: [PATCH] feat: improve health check to fail if schema_reader raises
 exceptions

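The /_health endpoint previously answered 200 OK unconditionally. Health
hooks now return a HealthCheck(status, healthy) pair, and the endpoint
responds with 503 Service Unavailable when any hook reports unhealthy.
The schema reader reports itself unhealthy once handle_messages has
raised UNHEALTHY_CONSECUTIVE_ERRORS consecutive exceptions spanning
UNHEALTHY_TIMEOUT_SECONDS, or when the schemas topic no longer exists;
the topic is checked explicitly via AdminClient.describe_topics because
the consumer silently returns empty batches for a missing topic.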
---
 karapace/karapace.py                     | 18 ++++-
 karapace/schema_reader.py                | 39 ++++++++++-
 karapace/schema_registry_apis.py         | 11 ++--
 stubs/confluent_kafka/__init__.pyi       |  2 +
 stubs/confluent_kafka/admin/__init__.pyi |  3 +-
 stubs/confluent_kafka/cimpl.pyi          |  7 ++
 tests/integration/test_health_check.py   | 27 ++++++++
 tests/unit/test_schema_reader.py         | 83 +++++++++++++++++++++++-
 8 files changed, 179 insertions(+), 11 deletions(-)
 create mode 100644 tests/integration/test_health_check.py

diff --git a/karapace/karapace.py b/karapace/karapace.py
index 28e26cf91..4afd3dc08 100644
--- a/karapace/karapace.py
+++ b/karapace/karapace.py
@@ -11,6 +11,7 @@
 from functools import partial
 from http import HTTPStatus
 from karapace.config import Config
+from karapace.dataclasses import default_dataclass
 from karapace.rapu import HTTPRequest, HTTPResponse, RestApp
 from karapace.typing import JsonObject
 from karapace.utils import json_encode
@@ -21,7 +22,14 @@
 import aiohttp.web
 import time
 
-HealthHook: TypeAlias = Callable[[], Awaitable[JsonObject]]
+
+@default_dataclass
+class HealthCheck:
+    status: JsonObject
+    healthy: bool
+
+
+HealthHook: TypeAlias = Callable[[], Awaitable[HealthCheck]]
 
 
 class KarapaceBase(RestApp):
@@ -95,11 +103,15 @@ async def health(self, _request: Request) -> aiohttp.web.Response:
             "process_uptime_sec": int(time.monotonic() - self._process_start_time),
             "karapace_version": __version__,
         }
+        status_code = HTTPStatus.OK
         for hook in self.health_hooks:
-            resp.update(await hook())
+            check = await hook()
+            resp.update(check.status)
+            if not check.healthy:
+                status_code = HTTPStatus.SERVICE_UNAVAILABLE
         return aiohttp.web.Response(
             body=json_encode(resp, binary=True, compact=True),
-            status=HTTPStatus.OK.value,
+            status=status_code.value,
             headers={"Content-Type": "application/json"},
         )
 
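For orientation, a health hook under the new contract returns a HealthCheck
instead of a bare JsonObject. A minimal hypothetical hook (the function name
and status key are illustrative, not part of this patch):

    from karapace.karapace import HealthCheck

    async def my_component_health() -> HealthCheck:
        # Status entries are merged into the /_health response body via
        # resp.update(check.status); healthy=False flips the whole response
        # to 503 Service Unavailable.
        return HealthCheck(status={"my_component_ok": True}, healthy=True)

    # Registered like the existing hooks:
    # self.health_hooks.append(my_component_health)
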
diff --git a/karapace/schema_reader.py b/karapace/schema_reader.py
index 85e822ed1..293b9db87 100644
--- a/karapace/schema_reader.py
+++ b/karapace/schema_reader.py
@@ -20,7 +20,7 @@
     UnknownTopicOrPartitionError,
 )
 from avro.schema import Schema as AvroSchema
-from confluent_kafka import Message, TopicPartition
+from confluent_kafka import Message, TopicCollection, TopicPartition
 from contextlib import closing, ExitStack
 from enum import Enum
 from jsonschema.validators import Draft7Validator
@@ -46,6 +46,7 @@
 from threading import Event, Thread
 from typing import Final, Mapping, Sequence
 
+import asyncio
 import json
 import logging
 import time
@@ -60,6 +61,11 @@
 KAFKA_CLIENT_CREATION_TIMEOUT_SECONDS: Final = 2.0
 SCHEMA_TOPIC_CREATION_TIMEOUT_SECONDS: Final = 5.0
 
+# If handle_messages raises at least UNHEALTHY_CONSECUTIVE_ERRORS consecutive
+# exceptions spanning UNHEALTHY_TIMEOUT_SECONDS, the SchemaReader is reported unhealthy.
+UNHEALTHY_TIMEOUT_SECONDS: Final = 10.0
+UNHEALTHY_CONSECUTIVE_ERRORS: Final = 3
+
 # For good startup performance the consumption of multiple
 # records for each consume round is essential.
 # Consumer default is 1 message for each consume call and after
@@ -174,6 +180,9 @@ def __init__(
         self.start_time = time.monotonic()
         self.startup_previous_processed_offset = 0
 
+        self.consecutive_unexpected_errors: int = 0
+        self.consecutive_unexpected_errors_start: float = 0
+
     def close(self) -> None:
         LOG.info("Closing schema_reader")
         self._stop_schema_reader.set()
@@ -247,13 +256,41 @@ def run(self) -> None:
                     self.offset = self._get_beginning_offset()
                 try:
                     self.handle_messages()
+                    self.consecutive_unexpected_errors = 0
                 except ShutdownException:
                     self._stop_schema_reader.set()
                     shutdown()
                 except Exception as e:  # pylint: disable=broad-except
                     self.stats.unexpected_exception(ex=e, where="schema_reader_loop")
+                    self.consecutive_unexpected_errors += 1
+                    if self.consecutive_unexpected_errors == 1:
+                        self.consecutive_unexpected_errors_start = time.monotonic()
                     LOG.warning("Unexpected exception in schema reader loop - %s", e)
 
+    async def is_healthy(self) -> bool:
+        if (
+            self.consecutive_unexpected_errors >= UNHEALTHY_CONSECUTIVE_ERRORS
+            and (duration := time.monotonic() - self.consecutive_unexpected_errors_start) >= UNHEALTHY_TIMEOUT_SECONDS
+        ):
+            LOG.warning(
+                "Health check failed with %s consecutive errors in %s seconds", self.consecutive_unexpected_errors, duration
+            )
+            return False
+
+        try:
+            # Explicitly check that the topic exists.
+            # This is needed because, for a missing topic, the consumer does not repeat
+            # the error on consecutive consume calls and instead returns an empty list.
+            assert self.admin_client is not None
+            topic = self.config["topic_name"]
+            futures = self.admin_client.describe_topics(TopicCollection([topic]))
+            await asyncio.wrap_future(futures[topic])
+        except Exception as e:  # pylint: disable=broad-except
+            LOG.warning("Health check failed with %r", e)
+            return False
+
+        return True
+
     def _get_beginning_offset(self) -> int:
         assert self.consumer is not None, "Thread must be started"
 
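is_healthy() bridges the admin client into asyncio with asyncio.wrap_future
because confluent_kafka's AdminClient returns concurrent.futures.Future
objects, which are not awaitable directly. A self-contained sketch of that
pattern (the stand-in future is hand-built for illustration):

    import asyncio
    from concurrent.futures import Future

    def describe_topics_stand_in() -> Future:
        # Stand-in for AdminClient.describe_topics(...)[topic], which likewise
        # hands back a concurrent.futures.Future.
        future = Future()
        future.set_result("topic description")
        return future

    async def main() -> None:
        # wrap_future adapts the thread-based future into an awaitable one.
        result = await asyncio.wrap_future(describe_topics_stand_in())
        print(result)

    asyncio.run(main())
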
diff --git a/karapace/schema_registry_apis.py b/karapace/schema_registry_apis.py
index d3b90dac6..6776e8ba6 100644
--- a/karapace/schema_registry_apis.py
+++ b/karapace/schema_registry_apis.py
@@ -28,13 +28,13 @@
     SubjectSoftDeletedException,
     VersionNotFoundException,
 )
-from karapace.karapace import KarapaceBase
+from karapace.karapace import HealthCheck, KarapaceBase
 from karapace.protobuf.exception import ProtobufUnresolvedDependencyException
 from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE, SERVER_NAME
 from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema, Versioner
 from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping
 from karapace.schema_registry import KarapaceSchemaRegistry
-from karapace.typing import JsonData, JsonObject, SchemaId, Subject, Version
+from karapace.typing import JsonData, SchemaId, Subject, Version
 from karapace.utils import JSONDecodeError
 from typing import Any
 
@@ -98,7 +98,7 @@ def __init__(self, config: Config) -> None:
         self.app.on_startup.append(self._create_forward_client)
         self.health_hooks.append(self.schema_registry_health)
 
-    async def schema_registry_health(self) -> JsonObject:
+    async def schema_registry_health(self) -> HealthCheck:
         resp = {}
         if self._auth is not None:
             resp["schema_registry_authfile_timestamp"] = self._auth.authfile_last_modified
@@ -115,7 +115,10 @@ async def schema_registry_health(self) -> JsonObject:
         resp["schema_registry_primary_url"] = cs.primary_url
         resp["schema_registry_coordinator_running"] = cs.is_running
         resp["schema_registry_coordinator_generation_id"] = cs.group_generation_id
-        return resp
+
+        healthy = await self.schema_registry.schema_reader.is_healthy()
+
+        return HealthCheck(status=resp, healthy=healthy)
 
     async def _start_schema_registry(self, app: aiohttp.web.Application) -> None:  # pylint: disable=unused-argument
         """Callback for aiohttp.Application.on_startup"""
diff --git a/stubs/confluent_kafka/__init__.pyi b/stubs/confluent_kafka/__init__.pyi
index 175569fb4..e27cf4880 100644
--- a/stubs/confluent_kafka/__init__.pyi
+++ b/stubs/confluent_kafka/__init__.pyi
@@ -8,6 +8,7 @@ from .cimpl import (
     TIMESTAMP_CREATE_TIME,
     TIMESTAMP_LOG_APPEND_TIME,
     TIMESTAMP_NOT_AVAILABLE,
+    TopicCollection,
     TopicPartition,
 )
 
@@ -22,4 +23,5 @@ __all__ = (
     "TIMESTAMP_LOG_APPEND_TIME",
     "TIMESTAMP_NOT_AVAILABLE",
     "TopicPartition",
+    "TopicCollection",
 )
diff --git a/stubs/confluent_kafka/admin/__init__.pyi b/stubs/confluent_kafka/admin/__init__.pyi
index 02abcc033..1dafa51b8 100644
--- a/stubs/confluent_kafka/admin/__init__.pyi
+++ b/stubs/confluent_kafka/admin/__init__.pyi
@@ -4,7 +4,7 @@ from ._listoffsets import ListOffsetsResultInfo, OffsetSpec
 from ._metadata import BrokerMetadata, ClusterMetadata, PartitionMetadata, TopicMetadata
 from ._resource import ResourceType
 from concurrent.futures import Future
-from confluent_kafka import IsolationLevel, TopicPartition
+from confluent_kafka import IsolationLevel, TopicCollection, TopicPartition
 from typing import Callable
 
 __all__ = (
@@ -52,3 +52,4 @@ class AdminClient:
     def describe_configs(
         self, resources: list[ConfigResource], request_timeout: float = -1
     ) -> dict[ConfigResource, Future[dict[str, ConfigEntry]]]: ...
+    def describe_topics(self, topics: TopicCollection) -> dict[str, Future]: ...
diff --git a/stubs/confluent_kafka/cimpl.pyi b/stubs/confluent_kafka/cimpl.pyi
index 74760897c..72e83ed00 100644
--- a/stubs/confluent_kafka/cimpl.pyi
+++ b/stubs/confluent_kafka/cimpl.pyi
@@ -47,6 +47,13 @@ class TopicPartition:
         self.leader_epoch: int | None
         self.error: KafkaError | None
 
+class TopicCollection:
+    def __init__(
+        self,
+        topic_names: list[str],
+    ) -> None:
+        self.topic_names: list[str]
+
 class Message:
     def offset(self) -> int: ...
     def timestamp(self) -> tuple[int, int]: ...
diff --git a/tests/integration/test_health_check.py b/tests/integration/test_health_check.py
new file mode 100644
index 000000000..c4958651e
--- /dev/null
+++ b/tests/integration/test_health_check.py
@@ -0,0 +1,27 @@
+"""
+Copyright (c) 2024 Aiven Ltd
+See LICENSE for details
+"""
+
+from karapace.client import Client
+from karapace.kafka.admin import KafkaAdminClient
+from tenacity import retry, stop_after_delay, wait_fixed
+from tests.integration.utils.cluster import RegistryDescription
+
+import http
+
+
+async def test_health_check(
+    registry_cluster: RegistryDescription, registry_async_client: Client, admin_client: KafkaAdminClient
+) -> None:
+    res = await registry_async_client.get("/_health")
+    assert res.ok
+
+    admin_client.delete_topic(registry_cluster.schemas_topic)
+
+    @retry(stop=stop_after_delay(10), wait=wait_fixed(1), reraise=True)
+    async def check_health() -> None:
+        res = await registry_async_client.get("/_health")
+        assert res.status_code == http.HTTPStatus.SERVICE_UNAVAILABLE, "should report unhealthy after topic has been deleted"
+
+    await check_health()
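The tenacity decorator above re-runs the assertion once per second for up to
ten seconds, and reraise=True surfaces the final AssertionError instead of a
RetryError if the registry never turns unhealthy. The same pattern in
isolation (the counter-based check is illustrative):

    import asyncio
    from tenacity import retry, stop_after_delay, wait_fixed

    attempts = 0

    @retry(stop=stop_after_delay(10), wait=wait_fixed(1), reraise=True)
    async def eventually_true() -> None:
        # Fails on the first two attempts, then passes; tenacity retries the
        # coroutine and re-raises the last AssertionError on timeout.
        global attempts
        attempts += 1
        assert attempts >= 3

    asyncio.run(eventually_true())
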
diff --git a/tests/unit/test_schema_reader.py b/tests/unit/test_schema_reader.py
index 7026ea853..7968fa12e 100644
--- a/tests/unit/test_schema_reader.py
+++ b/tests/unit/test_schema_reader.py
@@ -6,7 +6,7 @@
 """
 
 from _pytest.logging import LogCaptureFixture
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
 from confluent_kafka import Message
 from dataclasses import dataclass
 from karapace.config import DEFAULTS
@@ -25,9 +25,10 @@
 )
 from karapace.schema_type import SchemaType
 from karapace.typing import SchemaId, Version
+from pytest import MonkeyPatch
 from tests.base_testcase import BaseTestCase
 from tests.utils import schema_protobuf_invalid
-from typing import Callable, List, Tuple
+from typing import Callable, List, Optional, Tuple
 from unittest.mock import Mock
 
 import confluent_kafka
@@ -325,6 +326,84 @@ def test_handle_msg_delete_subject_logs(caplog: LogCaptureFixture) -> None:
             assert log.message == "Hard delete: version: Version(2) for subject: 'test-subject' did not exist, should have"
 
 
+@dataclass
+class HealthCheckTestCase(BaseTestCase):
+    current_time: float
+    consecutive_unexpected_errors: int
+    consecutive_unexpected_errors_start: float
+    healthy: bool
+    check_topic_error: Optional[Exception] = None
+
+
+@pytest.mark.parametrize(
+    "testcase",
+    [
+        HealthCheckTestCase(
+            test_name="No errors",
+            current_time=0,
+            consecutive_unexpected_errors=0,
+            consecutive_unexpected_errors_start=0,
+            healthy=True,
+        ),
+        HealthCheckTestCase(
+            test_name="10 errors in 5 seconds",
+            current_time=5,
+            consecutive_unexpected_errors=10,
+            consecutive_unexpected_errors_start=0,
+            healthy=True,
+        ),
+        HealthCheckTestCase(
+            test_name="1 error in 20 seconds",
+            current_time=20,
+            consecutive_unexpected_errors=1,
+            consecutive_unexpected_errors_start=0,
+            healthy=True,
+        ),
+        HealthCheckTestCase(
+            test_name="3 errors in 10 seconds",
+            current_time=10,
+            consecutive_unexpected_errors=3,
+            consecutive_unexpected_errors_start=0,
+            healthy=False,
+        ),
+        HealthCheckTestCase(
+            test_name="check topic error",
+            current_time=5,
+            consecutive_unexpected_errors=1,
+            consecutive_unexpected_errors_start=0,
+            healthy=False,
+            check_topic_error=Exception("Something's wrong"),
+        ),
+    ],
+)
+async def test_schema_reader_health_check(testcase: HealthCheckTestCase, monkeypatch: MonkeyPatch) -> None:
+    offset_watcher = OffsetWatcher()
+    key_formatter_mock = Mock()
+    admin_client_mock = Mock()
+
+    empty_future = Future()
+    if testcase.check_topic_error:
+        empty_future.set_exception(testcase.check_topic_error)
+    else:
+        empty_future.set_result(None)
+    admin_client_mock.describe_topics.return_value = {DEFAULTS["topic_name"]: empty_future}
+
+    schema_reader = KafkaSchemaReader(
+        config=DEFAULTS,
+        offset_watcher=offset_watcher,
+        key_formatter=key_formatter_mock,
+        master_coordinator=None,
+        database=InMemoryDatabase(),
+    )
+
+    monkeypatch.setattr(time, "monotonic", lambda: testcase.current_time)
+    schema_reader.admin_client = admin_client_mock
+    schema_reader.consecutive_unexpected_errors = testcase.consecutive_unexpected_errors
+    schema_reader.consecutive_unexpected_errors_start = testcase.consecutive_unexpected_errors_start
+
+    assert await schema_reader.is_healthy() == testcase.healthy
+
+
 @dataclass
 class KafkaMessageHandlingErrorTestCase(BaseTestCase):
     key: bytes