Introduce schema backup interfaces
aiven-anton committed May 3, 2023
1 parent 1e27cc1 commit 6bfadbb
Showing 10 changed files with 306 additions and 164 deletions.
1 change: 1 addition & 0 deletions .pylintrc
@@ -29,6 +29,7 @@ disable=
no-name-in-module,
use-list-literal,
use-dict-literal,
no-value-for-parameter,


[FORMAT]
169 changes: 78 additions & 91 deletions karapace/backup/api.py
@@ -6,39 +6,36 @@
"""
from __future__ import annotations

from .backend import BaseBackupReader, BaseBackupWriter, BaseItemsBackupReader
from .consumer import PollTimeout
from .encoders import encode_key, encode_value
from .errors import BackupError, PartitionCountError, StaleConsumerError
from .v1 import SchemaBackupV1Reader
from .v2 import AnonymizeAvroWriter, SchemaBackupV2Reader, SchemaBackupV2Writer
from enum import Enum
from functools import partial
from kafka import KafkaConsumer, KafkaProducer
from kafka.admin import KafkaAdminClient
from kafka.consumer.fetcher import ConsumerRecord
from kafka.errors import TopicAlreadyExistsError
from kafka.structs import PartitionMetadata, TopicPartition
from karapace import constants
from karapace.backup.consumer import PollTimeout
from karapace.backup.errors import BackupError, PartitionCountError, StaleConsumerError
from karapace.config import Config
from karapace.key_format import KeyFormatter
from karapace.schema_reader import new_schema_topic_from_config
from karapace.typing import JsonData, JsonObject
from karapace.utils import assert_never, json_decode, json_encode, KarapaceKafkaClient
from karapace.utils import assert_never, KarapaceKafkaClient
from pathlib import Path
from tempfile import mkstemp
from tenacity import retry, RetryCallState, stop_after_delay, wait_fixed
from typing import AbstractSet, Callable, Collection, Final, Generator, IO, Iterable, Literal, TextIO
from typing_extensions import TypeAlias
from typing import AbstractSet, Callable, Collection, Final, Generator, IO, Literal, TextIO

import contextlib
import karapace.backup.v1.disk_format
import karapace.backup.v2.disk_format
import logging
import os
import sys

LOG = logging.getLogger(__name__)

# The schema topic has a single partition.
# Using this in `producer.send` disables the partitioner, which would
# otherwise calculate which partition the data is sent to.
PARTITION_ZERO: Final = 0


class BackupVersion(Enum):
ANONYMIZE_AVRO = -1
@@ -53,22 +50,29 @@ def marker(self) -> str:

@classmethod
def by_marker(cls, marker: str) -> BackupVersion:
for version in cls:
if version is BackupVersion.ANONYMIZE_AVRO:
continue
try:
version_marker = version.marker
except AttributeError:
continue
if marker == version_marker:
return version
raise ValueError("No BackupVersion matches the given marker")
try:
return {BackupVersion.V2.marker: BackupVersion.V2}[marker]
except KeyError:
# pylint: disable=raise-missing-from
raise ValueError("No BackupVersion matches the given marker")

@property
def reader(self) -> type[BaseBackupReader]:
if self is BackupVersion.V2 or self is BackupVersion.ANONYMIZE_AVRO:
return SchemaBackupV2Reader
if self is BackupVersion.V1:
return SchemaBackupV1Reader
assert_never(self)

ValidCreateBackupVersion: TypeAlias = Literal[
BackupVersion.V2,
BackupVersion.ANONYMIZE_AVRO,
]
@property
def writer(self) -> type[BaseBackupWriter]:
if self is BackupVersion.V2:
return SchemaBackupV2Writer
if self is BackupVersion.ANONYMIZE_AVRO:
return AnonymizeAvroWriter
if self is BackupVersion.V1:
raise AttributeError("Cannot produce backups for V1")
assert_never(self)
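
A quick illustration (not part of the diff) of how the new `reader`/`writer` properties dispatch, using the names imported at the top of api.py:

    # Each version resolves to its backend class; V1 can be read but never written.
    assert BackupVersion.V2.reader is SchemaBackupV2Reader
    assert BackupVersion.ANONYMIZE_AVRO.writer is AnonymizeAvroWriter
    try:
        _ = BackupVersion.V1.writer
    except AttributeError:
        pass  # producing V1 backups is unsupported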


def __before_sleep(description: str) -> Callable[[RetryCallState], None]:
@@ -299,37 +303,39 @@ def check_dst() -> None:
pass


def _read_backup_file_version(fp: IO[str]) -> BackupVersion:
def _discover_reader_backend(fp: IO[str]) -> type[BaseBackupReader[IO[str]]]:
try:
version = BackupVersion.by_marker(fp.read(4))
except ValueError:
return BackupVersion.V1
return SchemaBackupV1Reader
finally:
# Seek back to start.
fp.seek(0)
# Consume until linefeed.
fp.readline()
return version
return version.reader


class SchemaBackup:
def __init__(self, config: Config, backup_path: str, topic_option: str | None = None) -> None:
self.config = config
self.backup_location = backup_path
self.topic_name: str = topic_option or self.config["topic_name"]
self.timeout_ms = 1000
self.timeout_kafka_producer = 5

def __init__(
self,
config: Config,
backup_path: str,
topic_option: str | None = None,
) -> None:
self.config: Final = config
self.backup_location: Final = backup_path
self.topic_name: Final[str] = topic_option or self.config["topic_name"]
self.timeout_ms: Final = 1000
self.timeout_kafka_producer: Final = 5
self.producer_exception: Exception | None = None

# Schema key formatter
self.key_formatter = None
if self.topic_name == constants.DEFAULT_SCHEMA_TOPIC or self.config.get("force_key_correction", False):
self.key_formatter = KeyFormatter()

def _restore_items(self, producer: KafkaProducer, items: Iterable[tuple[str, str]]) -> None:
for item in items:
self._handle_restore_message(producer, item)
self.key_formatter: Final = (
KeyFormatter()
if self.topic_name == constants.DEFAULT_SCHEMA_TOPIC or self.config.get("force_key_correction", False)
else None
)

def restore_backup(self) -> None:
if not os.path.exists(self.backup_location):
@@ -341,37 +347,41 @@ def restore_backup(self) -> None:
LOG.info("Starting backup restore for topic: %r", self.topic_name)

with open(self.backup_location, encoding="utf8") as fp:
version = _read_backup_file_version(fp)
LOG.info("Identified backup format version: %s", version.name)
if version is BackupVersion.V2:
items = karapace.backup.v2.disk_format.items_from_file(fp)
elif version is BackupVersion.V1:
items = karapace.backup.v1.disk_format.items_from_file(fp)
else:
assert version is not BackupVersion.ANONYMIZE_AVRO
assert_never(version)
self._restore_items(producer, items)
backend_type = _discover_reader_backend(fp)
backend = (
backend_type(
key_encoder=partial(encode_key, key_formatter=self.key_formatter),
value_encoder=encode_value,
)
if issubclass(backend_type, BaseItemsBackupReader)
else backend_type()
)
LOG.info("Identified backup backend: %s", backend.__class__.__name__)
for instruction in backend.read(
topic_name=self.topic_name,
buffer=fp,
):
LOG.debug(
"Sending kafka msg key: %r, value: %r",
instruction.key,
instruction.value,
)
producer.send(
instruction.topic_name,
key=instruction.key,
value=instruction.value,
partition=instruction.partition,
).add_errback(self.producer_error_callback)
producer.flush(timeout=self.timeout_kafka_producer)
if self.producer_exception is not None:
raise BackupError("Error while producing restored messages") from self.producer_exception

def producer_error_callback(self, exception: Exception) -> None:
self.producer_exception = exception

def _handle_restore_message(self, producer: KafkaProducer, item: tuple[str, str]) -> None:
key = self.encode_key(item[0])
value = encode_value(item[1])
LOG.debug("Sending kafka msg key: %r, value: %r", key, value)
producer.send(
self.topic_name,
key=key,
value=value,
partition=PARTITION_ZERO,
).add_errback(self.producer_error_callback)

def create(
self,
version: ValidCreateBackupVersion,
version: Literal[BackupVersion.V2, BackupVersion.ANONYMIZE_AVRO],
*,
poll_timeout: PollTimeout | None = None,
overwrite: bool | None = None,
@@ -389,12 +399,7 @@ def create(
:raises StaleConsumerError: if no records are received within the given ``poll_timeout`` and the target offset
has not been reached yet.
"""
if version is BackupVersion.V2:
serialize_record = karapace.backup.v2.disk_format.serialize_record
elif version is BackupVersion.ANONYMIZE_AVRO:
serialize_record = karapace.backup.v2.disk_format.anonymize_avro_schema_message
else:
assert_never(version)
backend = version.writer()

if poll_timeout is None:
poll_timeout = PollTimeout.default()
@@ -424,8 +429,9 @@
if len(records) == 0:
raise StaleConsumerError(topic_partition, start_offset, end_offset, last_offset, poll_timeout)
for record in records:
fp.write(serialize_record(record.key, record.value))
backend.store_record(fp, record)
record_count += 1

last_offset = record.offset # pylint: disable=undefined-loop-variable
if last_offset >= end_offset:
break
@@ -438,22 +444,3 @@
),
file=sys.stderr,
)

def encode_key(self, key: JsonObject | str) -> bytes | None:
if key == "null":
return None
if not self.key_formatter:
if isinstance(key, str):
return key.encode("utf8")
return json_encode(key, sort_keys=False, binary=True, compact=False)
if isinstance(key, str):
key = json_decode(key, JsonObject)
return self.key_formatter.format_key(key)


def encode_value(value: JsonData) -> bytes | None:
if value == "null":
return None
if isinstance(value, str):
return value.encode("utf8")
return json_encode(value, compact=True, sort_keys=False, binary=True)
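
An aside on the detection flow above: `_discover_reader_backend` sniffs the first four characters of the file and falls back to the marker-less V1 format on any mismatch. A minimal sketch of that behaviour, assuming the V2 marker is exactly the four characters `fp.read(4)` consumes (as the dictionary lookup in `by_marker` implies):

    import io

    from karapace.backup.api import BackupVersion, _discover_reader_backend
    from karapace.backup.v1 import SchemaBackupV1Reader
    from karapace.backup.v2 import SchemaBackupV2Reader

    # A buffer starting with the V2 marker resolves to the V2 reader; as a side
    # effect the buffer is left positioned just past the marker line.
    assert _discover_reader_backend(io.StringIO(BackupVersion.V2.marker)) is SchemaBackupV2Reader

    # Anything else (e.g. a bare JSON document) falls back to V1.
    assert _discover_reader_backend(io.StringIO("[]")) is SchemaBackupV1Reader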
104 changes: 104 additions & 0 deletions karapace/backup/backend.py
@@ -0,0 +1,104 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""
from __future__ import annotations

from kafka.consumer.fetcher import ConsumerRecord
from karapace.typing import JsonData, JsonObject
from typing import Callable, Final, Generator, Generic, IO, Iterator, Optional, Sequence, TypeVar, Union
from typing_extensions import TypeAlias

import abc
import dataclasses
import logging

logger = logging.getLogger(__name__)


# The schema topic has a single partition. Using this in `producer.send`
# disables the partitioner, which would otherwise calculate which partition
# the data is sent to.
PARTITION_ZERO: Final = 0


@dataclasses.dataclass(frozen=True)
class ProducerSend:
topic_name: str
value: bytes | None
key: bytes | None
headers: Sequence[tuple[bytes | None, bytes | None]] | None = None
partition: int | None = None
timestamp_ms: int | None = None


KeyEncoder: TypeAlias = Callable[[Union[JsonObject, str]], Optional[bytes]]
ValueEncoder: TypeAlias = Callable[[JsonData], Optional[bytes]]
B = TypeVar("B", bound="IO[bytes] | IO[str]")


class BaseBackupReader(abc.ABC, Generic[B]):
@abc.abstractmethod
def read(
self,
topic_name: str,
buffer: B,
) -> Iterator[ProducerSend]:
...


class BaseItemsBackupReader(BaseBackupReader[IO[str]]):
def __init__(
self,
key_encoder: KeyEncoder,
value_encoder: ValueEncoder,
) -> None:
self.key_encoder: Final = key_encoder
self.value_encoder: Final = value_encoder

@staticmethod
@abc.abstractmethod
def items_from_file(fp: IO[str]) -> Iterator[tuple[str, str]]:
...

def read(
self,
topic_name: str,
buffer: IO[str],
) -> Generator[ProducerSend, None, None]:
for item in self.items_from_file(buffer):
key, value = item
yield ProducerSend(
topic_name=topic_name,
key=self.key_encoder(key),
value=self.value_encoder(value),
partition=PARTITION_ZERO,
)


class BaseBackupWriter(abc.ABC, Generic[B]):
@classmethod
@abc.abstractmethod
def store_record(
cls,
buffer: B,
record: ConsumerRecord,
) -> None:
...


class BaseKVBackupWriter(BaseBackupWriter[IO[str]]):
@classmethod
def store_record(
cls,
buffer: IO[str],
record: ConsumerRecord,
) -> None:
buffer.write(cls.serialize_record(record.key, record.value))

@staticmethod
@abc.abstractmethod
def serialize_record(
key_bytes: bytes | None,
value_bytes: bytes | None,
) -> str:
...
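
To make the contract concrete, here is a minimal sketch of a hypothetical backend pair; `TabSeparatedReader` and `TabSeparatedWriter` are invented names, not part of this commit:

    from __future__ import annotations

    from karapace.backup.backend import BaseItemsBackupReader, BaseKVBackupWriter
    from typing import IO, Iterator


    class TabSeparatedReader(BaseItemsBackupReader):
        """Parses one `key<TAB>value` item per line."""

        @staticmethod
        def items_from_file(fp: IO[str]) -> Iterator[tuple[str, str]]:
            for line in fp:
                key, _, value = line.rstrip("\n").partition("\t")
                yield key, value


    class TabSeparatedWriter(BaseKVBackupWriter):
        """Serializes each consumed record as a `key<TAB>value` line."""

        @staticmethod
        def serialize_record(key_bytes: bytes | None, value_bytes: bytes | None) -> str:
            key = "null" if key_bytes is None else key_bytes.decode("utf8")
            value = "null" if value_bytes is None else value_bytes.decode("utf8")
            return f"{key}\t{value}\n"

The base classes supply `read` and `store_record` on top of these hooks, so a concrete backend only has to define its on-disk item format; presumably this is the same split the V2 backends in this commit rely on.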