From 6111c6dce0843d129ecb43d38498a4c1b202e80e Mon Sep 17 00:00:00 2001 From: Amrutha Shanbhag <70739259+amrutha-shanbhag@users.noreply.github.com> Date: Fri, 28 Jan 2022 23:25:13 +1100 Subject: [PATCH] Implement Protobuf support (#296) Protobuf support --- .github/workflows/tests.yml | 6 + karapace/compatibility/__init__.py | 27 + karapace/compatibility/protobuf/__init__.py | 0 karapace/compatibility/protobuf/checks.py | 33 + karapace/config.py | 3 +- karapace/kafka_rest_apis/__init__.py | 8 +- karapace/kafka_rest_apis/consumer_manager.py | 4 +- karapace/protobuf/__init__.py | 0 karapace/protobuf/compare_result.py | 78 + karapace/protobuf/compare_type_storage.py | 134 + karapace/protobuf/encoding_variants.py | 68 + karapace/protobuf/enum_constant_element.py | 26 + karapace/protobuf/enum_element.py | 68 + karapace/protobuf/exception.py | 48 + karapace/protobuf/extend_element.py | 27 + karapace/protobuf/extensions_element.py | 39 + karapace/protobuf/field.py | 12 + karapace/protobuf/field_element.py | 151 + karapace/protobuf/group_element.py | 32 + karapace/protobuf/io.py | 134 + karapace/protobuf/kotlin_wrapper.py | 32 + karapace/protobuf/location.py | 60 + karapace/protobuf/message_element.py | 142 + karapace/protobuf/one_of_element.py | 55 + karapace/protobuf/option_element.py | 92 + karapace/protobuf/option_reader.py | 153 + karapace/protobuf/proto_file_element.py | 162 + karapace/protobuf/proto_parser.py | 589 ++++ karapace/protobuf/proto_type.py | 221 ++ karapace/protobuf/protobuf_to_dict.py | 335 +++ karapace/protobuf/reserved_element.py | 35 + karapace/protobuf/rpc_element.py | 48 + karapace/protobuf/schema.py | 162 + karapace/protobuf/service_element.py | 33 + karapace/protobuf/syntax.py | 20 + karapace/protobuf/syntax_reader.py | 364 +++ karapace/protobuf/type_element.py | 32 + karapace/protobuf/utils.py | 67 + karapace/rapu.py | 8 +- karapace/schema_reader.py | 38 +- karapace/schema_registry_apis.py | 20 +- karapace/serialization.py | 34 +- requirements.txt | 4 +- runtime/.gitignore | 4 + tests/integration/test_client_protobuf.py | 34 + .../test_rest_consumer_protobuf.py | 77 + tests/integration/test_schema.py | 3 +- tests/integration/test_schema_protobuf.py | 193 ++ tests/schemas/protobuf.py | 73 + tests/unit/conftest.py | 28 +- tests/unit/test_any_tool.py | 70 + tests/unit/test_compare_elements.py | 76 + tests/unit/test_compatibility.py | 186 ++ tests/unit/test_enum_element.py | 139 + tests/unit/test_extend_element.py | 105 + tests/unit/test_extensions_element.py | 37 + tests/unit/test_field_element.py | 86 + tests/unit/test_message_element.py | 448 +++ tests/unit/test_option_element.py | 57 + tests/unit/test_parsing_tester.py | 31 + tests/unit/test_proto_file_element.py | 469 +++ tests/unit/test_proto_parser.py | 2610 +++++++++++++++++ tests/unit/test_protobuf_schema.py | 288 ++ tests/unit/test_protobuf_serialization.py | 72 + tests/unit/test_serialization.py | 3 + tests/unit/test_service_element.py | 151 + tests/utils.py | 122 +- 67 files changed, 8942 insertions(+), 24 deletions(-) create mode 100644 karapace/compatibility/protobuf/__init__.py create mode 100644 karapace/compatibility/protobuf/checks.py create mode 100644 karapace/protobuf/__init__.py create mode 100644 karapace/protobuf/compare_result.py create mode 100644 karapace/protobuf/compare_type_storage.py create mode 100644 karapace/protobuf/encoding_variants.py create mode 100644 karapace/protobuf/enum_constant_element.py create mode 100644 karapace/protobuf/enum_element.py create mode 100644 karapace/protobuf/exception.py create mode 100644 karapace/protobuf/extend_element.py create mode 100644 karapace/protobuf/extensions_element.py create mode 100644 karapace/protobuf/field.py create mode 100644 karapace/protobuf/field_element.py create mode 100644 karapace/protobuf/group_element.py create mode 100644 karapace/protobuf/io.py create mode 100644 karapace/protobuf/kotlin_wrapper.py create mode 100644 karapace/protobuf/location.py create mode 100644 karapace/protobuf/message_element.py create mode 100644 karapace/protobuf/one_of_element.py create mode 100644 karapace/protobuf/option_element.py create mode 100644 karapace/protobuf/option_reader.py create mode 100644 karapace/protobuf/proto_file_element.py create mode 100644 karapace/protobuf/proto_parser.py create mode 100644 karapace/protobuf/proto_type.py create mode 100644 karapace/protobuf/protobuf_to_dict.py create mode 100644 karapace/protobuf/reserved_element.py create mode 100644 karapace/protobuf/rpc_element.py create mode 100644 karapace/protobuf/schema.py create mode 100644 karapace/protobuf/service_element.py create mode 100644 karapace/protobuf/syntax.py create mode 100644 karapace/protobuf/syntax_reader.py create mode 100644 karapace/protobuf/type_element.py create mode 100644 karapace/protobuf/utils.py create mode 100644 runtime/.gitignore create mode 100644 tests/integration/test_client_protobuf.py create mode 100644 tests/integration/test_rest_consumer_protobuf.py create mode 100644 tests/integration/test_schema_protobuf.py create mode 100644 tests/schemas/protobuf.py create mode 100644 tests/unit/test_any_tool.py create mode 100644 tests/unit/test_compare_elements.py create mode 100644 tests/unit/test_compatibility.py create mode 100644 tests/unit/test_enum_element.py create mode 100644 tests/unit/test_extend_element.py create mode 100644 tests/unit/test_extensions_element.py create mode 100644 tests/unit/test_field_element.py create mode 100644 tests/unit/test_message_element.py create mode 100644 tests/unit/test_option_element.py create mode 100644 tests/unit/test_parsing_tester.py create mode 100644 tests/unit/test_proto_file_element.py create mode 100644 tests/unit/test_proto_parser.py create mode 100644 tests/unit/test_protobuf_schema.py create mode 100644 tests/unit/test_protobuf_serialization.py create mode 100644 tests/unit/test_service_element.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5bfd52bb9..2ef29594f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -28,6 +28,12 @@ jobs: - name: Install dependencies run: python -m pip install -r requirements-dev.txt + - name: Install Protoc + uses: arduino/setup-protoc@v1 + with: + version: '3.13.0' + + - name: Execute unit-tests run: make unittest diff --git a/karapace/compatibility/__init__.py b/karapace/compatibility/__init__.py index 058565b14..af035fcd2 100644 --- a/karapace/compatibility/__init__.py +++ b/karapace/compatibility/__init__.py @@ -12,6 +12,7 @@ SchemaIncompatibilityType ) from karapace.compatibility.jsonschema.checks import compatibility as jsonschema_compatibility +from karapace.compatibility.protobuf.checks import check_protobuf_schema_compatibility from karapace.schema_reader import SchemaType, TypedSchema import logging @@ -63,6 +64,10 @@ def check_jsonschema_compatibility(reader: Draft7Validator, writer: Draft7Valida return jsonschema_compatibility(reader, writer) +def check_protobuf_compatibility(reader, writer) -> SchemaCompatibilityResult: + return check_protobuf_schema_compatibility(reader, writer) + + def check_compatibility( old_schema: TypedSchema, new_schema: TypedSchema, compatibility_mode: CompatibilityModes ) -> SchemaCompatibilityResult: @@ -128,6 +133,28 @@ def check_compatibility( ) ) + elif old_schema.schema_type is SchemaType.PROTOBUF: + if compatibility_mode in {CompatibilityModes.BACKWARD, CompatibilityModes.BACKWARD_TRANSITIVE}: + result = check_protobuf_compatibility( + reader=new_schema.schema, + writer=old_schema.schema, + ) + elif compatibility_mode in {CompatibilityModes.FORWARD, CompatibilityModes.FORWARD_TRANSITIVE}: + result = check_protobuf_compatibility( + reader=old_schema.schema, + writer=new_schema.schema, + ) + + elif compatibility_mode in {CompatibilityModes.FULL, CompatibilityModes.FULL_TRANSITIVE}: + result = check_protobuf_compatibility( + reader=new_schema.schema, + writer=old_schema.schema, + ) + result = result.merged_with(check_protobuf_compatibility( + reader=old_schema.schema, + writer=new_schema.schema, + )) + else: result = SchemaCompatibilityResult.incompatible( incompat_type=SchemaIncompatibilityType.type_mismatch, diff --git a/karapace/compatibility/protobuf/__init__.py b/karapace/compatibility/protobuf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/karapace/compatibility/protobuf/checks.py b/karapace/compatibility/protobuf/checks.py new file mode 100644 index 000000000..54661318e --- /dev/null +++ b/karapace/compatibility/protobuf/checks.py @@ -0,0 +1,33 @@ +from karapace.avro_compatibility import SchemaCompatibilityResult, SchemaCompatibilityType +from karapace.protobuf.compare_result import CompareResult +from karapace.protobuf.schema import ProtobufSchema + +import logging + +log = logging.getLogger(__name__) + + +def check_protobuf_schema_compatibility(reader: ProtobufSchema, writer: ProtobufSchema) -> SchemaCompatibilityResult: + result = CompareResult() + log.debug("READER: %s", reader.to_schema()) + log.debug("WRITER: %s", writer.to_schema()) + writer.compare(reader, result) + log.debug("IS_COMPATIBLE %s", result.is_compatible()) + if result.is_compatible(): + return SchemaCompatibilityResult.compatible() + + incompatibilities = [] + locations = set() + messages = set() + for record in result.result: + if not record.modification.is_compatible(): + incompatibilities.append(record.modification.__str__()) + locations.add(record.path) + messages.add(record.message) + + return SchemaCompatibilityResult( + compatibility=SchemaCompatibilityType.incompatible, + incompatibilities=list(incompatibilities), + locations=set(locations), + messages=set(messages), + ) diff --git a/karapace/config.py b/karapace/config.py index c3736b687..098ad4dce 100644 --- a/karapace/config.py +++ b/karapace/config.py @@ -52,7 +52,8 @@ "session_timeout_ms": 10000, "karapace_rest": False, "karapace_registry": False, - "master_election_strategy": "lowest" + "master_election_strategy": "lowest", + "protobuf_runtime_directory": "runtime", } DEFAULT_LOG_FORMAT_JOURNAL = "%(name)-20s\t%(threadName)s\t%(levelname)-8s\t%(message)s" diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index 3e3c86368..8be851143 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -25,9 +25,9 @@ RECORD_KEYS = ["key", "value", "partition"] PUBLISH_KEYS = {"records", "value_schema", "value_schema_id", "key_schema", "key_schema_id"} RECORD_CODES = [42201, 42202] -KNOWN_FORMATS = {"json", "avro", "binary"} +KNOWN_FORMATS = {"json", "avro", "protobuf", "binary"} OFFSET_RESET_STRATEGIES = {"latest", "earliest"} -SCHEMA_MAPPINGS = {"avro": SchemaType.AVRO, "jsonschema": SchemaType.JSONSCHEMA} +SCHEMA_MAPPINGS = {"avro": SchemaType.AVRO, "jsonschema": SchemaType.JSONSCHEMA, "protobuf": SchemaType.PROTOBUF} TypedConsumer = namedtuple("TypedConsumer", ["consumer", "serialization_format", "config"]) @@ -536,7 +536,7 @@ async def serialize( return json.dumps(obj).encode("utf8") if ser_format == "binary": return base64.b64decode(obj) - if ser_format in {"avro", "jsonschema"}: + if ser_format in {"avro", "jsonschema", "protobuf"}: return await self.schema_serialize(obj, schema_id) raise FormatError(f"Unknown format: {ser_format}") @@ -565,7 +565,7 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte sub_code=RESTErrorCodes.HTTP_UNPROCESSABLE_ENTITY.value, ) # disallow missing id and schema for any key/value list that has at least one populated element - if formats["embedded_format"] in {"avro", "jsonschema"}: + if formats["embedded_format"] in {"avro", "jsonschema", "protobuf"}: for prefix, code in zip(RECORD_KEYS, RECORD_CODES): if self.all_empty(data, prefix): continue diff --git a/karapace/kafka_rest_apis/consumer_manager.py b/karapace/kafka_rest_apis/consumer_manager.py index c278060f2..d33e0cd72 100644 --- a/karapace/kafka_rest_apis/consumer_manager.py +++ b/karapace/kafka_rest_apis/consumer_manager.py @@ -20,7 +20,7 @@ import time import uuid -KNOWN_FORMATS = {"json", "avro", "binary", "jsonschema"} +KNOWN_FORMATS = {"json", "avro", "binary", "jsonschema", "protobuf"} OFFSET_RESET_STRATEGIES = {"latest", "earliest"} TypedConsumer = namedtuple("TypedConsumer", ["consumer", "serialization_format", "config"]) @@ -481,7 +481,7 @@ async def fetch(self, internal_name: Tuple[str, str], content_type: str, formats async def deserialize(self, bytes_: bytes, fmt: str): if not bytes_: return None - if fmt in {"avro", "jsonschema"}: + if fmt in {"avro", "jsonschema", "protobuf"}: return await self.deserializer.deserialize(bytes_) if fmt == "json": return json.loads(bytes_.decode('utf-8')) diff --git a/karapace/protobuf/__init__.py b/karapace/protobuf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/karapace/protobuf/compare_result.py b/karapace/protobuf/compare_result.py new file mode 100644 index 000000000..6df031962 --- /dev/null +++ b/karapace/protobuf/compare_result.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass, field +from enum import auto, Enum + + +class Modification(Enum): + PACKAGE_ALTER = auto() + SYNTAX_ALTER = auto() + MESSAGE_ADD = auto() + MESSAGE_DROP = auto() + MESSAGE_MOVE = auto() + ENUM_CONSTANT_ADD = auto() + ENUM_CONSTANT_ALTER = auto() + ENUM_CONSTANT_DROP = auto() + ENUM_ADD = auto() + ENUM_DROP = auto() + TYPE_ALTER = auto() + FIELD_ADD = auto() + FIELD_DROP = auto() + FIELD_MOVE = auto() + FIELD_LABEL_ALTER = auto() + FIELD_NAME_ALTER = auto() + FIELD_KIND_ALTER = auto() + FIELD_TYPE_ALTER = auto() + ONE_OF_ADD = auto() + ONE_OF_DROP = auto() + ONE_OF_MOVE = auto() + ONE_OF_FIELD_ADD = auto() + ONE_OF_FIELD_DROP = auto() + ONE_OF_FIELD_MOVE = auto() + FEW_FIELDS_CONVERTED_TO_ONE_OF = auto() + + # protobuf compatibility issues is described in at + # https://yokota.blog/2021/08/26/understanding-protobuf-compatibility/ + def is_compatible(self) -> bool: + return self not in [ + self.MESSAGE_MOVE, self.MESSAGE_DROP, self.FIELD_LABEL_ALTER, self.FIELD_KIND_ALTER, self.FIELD_TYPE_ALTER, + self.ONE_OF_FIELD_DROP, self.FEW_FIELDS_CONVERTED_TO_ONE_OF + ] + + +@dataclass +class ModificationRecord: + modification: Modification + path: str + message: str = field(init=False) + + def __post_init__(self) -> None: + if self.modification.is_compatible(): + self.message = f"Compatible modification {self.modification} found" + else: + self.message = f"Incompatible modification {self.modification} found" + + def to_str(self) -> str: + return self.message + + +class CompareResult: + def __init__(self) -> None: + self.result = [] + self.path = [] + self.canonical_name = [] + + def push_path(self, name_element: str, canonical: bool = False) -> None: + if canonical: + self.canonical_name.append(name_element) + self.path.append(name_element) + + def pop_path(self, canonical: bool = False) -> None: + if canonical: + self.canonical_name.pop() + self.path.pop() + + def add_modification(self, modification: Modification) -> None: + record = ModificationRecord(modification, ".".join(self.path)) + self.result.append(record) + + def is_compatible(self) -> bool: + return all(record.modification.is_compatible() for record in self.result) diff --git a/karapace/protobuf/compare_type_storage.py b/karapace/protobuf/compare_type_storage.py new file mode 100644 index 000000000..8a5263a2a --- /dev/null +++ b/karapace/protobuf/compare_type_storage.py @@ -0,0 +1,134 @@ +from dataclasses import dataclass +from karapace.protobuf.compare_result import CompareResult +from karapace.protobuf.exception import IllegalArgumentException +from karapace.protobuf.proto_type import ProtoType +from karapace.protobuf.type_element import TypeElement +from typing import Dict, List, Optional, TYPE_CHECKING, Union + +if TYPE_CHECKING: + from karapace.protobuf.message_element import MessageElement + from karapace.protobuf.field_element import FieldElement + + +def compute_name(t: ProtoType, result_path: List[str], package_name: str, types: dict) -> Optional[str]: + string = t.string + + if string.startswith("."): + name = string[1:] + if types.get(name): + return name + return None + canonical_name = list(result_path) + if package_name: + canonical_name.insert(0, package_name) + while len(canonical_name) > 0: + pretender: str = ".".join(canonical_name) + "." + string + pt = types.get(pretender) + if pt is not None: + return pretender + canonical_name.pop() + if types.get(string): + return string + return None + + +class CompareTypes: + def __init__(self, self_package_name: str, other_package_name: str, result: CompareResult) -> None: + + self.self_package_name = self_package_name + self.other_package_name = other_package_name + self.self_types: Dict[str, Union[TypeRecord, TypeRecordMap]] = {} + self.other_types: Dict[str, Union[TypeRecord, TypeRecordMap]] = {} + self.locked_messages: List['MessageElement'] = [] + self.environment: List['MessageElement'] = [] + self.result = result + + def add_a_type(self, prefix: str, package_name: str, type_element: TypeElement, types: dict) -> None: + name: str + if prefix: + name = prefix + "." + type_element.name + else: + name = type_element.name + from karapace.protobuf.message_element import MessageElement + from karapace.protobuf.field_element import FieldElement + + if isinstance(type_element, MessageElement): # add support of MapEntry messages + if "map_entry" in type_element.options: + key: Optional[FieldElement] = next((f for f in type_element.fields if f.name == "key"), None) + value: Optional[FieldElement] = next((f for f in type_element.fields if f.name == "value"), None) + types[name] = TypeRecordMap(package_name, type_element, key, value) + else: + types[name] = TypeRecord(package_name, type_element) + else: + types[name] = TypeRecord(package_name, type_element) + + for t in type_element.nested_types: + self.add_a_type(name, package_name, t, types) + + def add_self_type(self, package_name: str, type_element: TypeElement) -> None: + self.add_a_type(package_name, package_name, type_element, self.self_types) + + def add_other_type(self, package_name: str, type_element: TypeElement) -> None: + self.add_a_type(package_name, package_name, type_element, self.other_types) + + def get_self_type(self, t: ProtoType) -> Union[None, 'TypeRecord', 'TypeRecordMap']: + name = compute_name(t, self.result.path, self.self_package_name, self.self_types) + if name is not None: + type_record = self.self_types.get(name) + return type_record + return None + + def get_other_type(self, t: ProtoType) -> Union[None, 'TypeRecord', 'TypeRecordMap']: + name = compute_name(t, self.result.path, self.other_package_name, self.other_types) + if name is not None: + type_record = self.other_types.get(name) + return type_record + return None + + def self_type_short_name(self, t: ProtoType) -> Optional[str]: + name = compute_name(t, self.result.path, self.self_package_name, self.self_types) + if name is None: + raise IllegalArgumentException(f"Cannot determine message type {t}") + type_record: TypeRecord = self.self_types.get(name) + if name.startswith(type_record.package_name): + return name[(len(type_record.package_name) + 1):] + return name + + def other_type_short_name(self, t: ProtoType) -> Optional[str]: + name = compute_name(t, self.result.path, self.other_package_name, self.other_types) + if name is None: + raise IllegalArgumentException(f"Cannot determine message type {t}") + type_record: TypeRecord = self.other_types.get(name) + if name.startswith(type_record.package_name): + return name[(len(type_record.package_name) + 1):] + return name + + def lock_message(self, message: 'MessageElement') -> bool: + if message in self.locked_messages: + return False + self.locked_messages.append(message) + return True + + def unlock_message(self, message: 'MessageElement') -> bool: + if message in self.locked_messages: + self.locked_messages.remove(message) + return True + return False + + +@dataclass +class TypeRecord: + package_name: str + type_element: TypeElement + + +class TypeRecordMap(TypeRecord): + def __init__( + self, package_name: str, type_element: TypeElement, key: Optional['FieldElement'], value: Optional['FieldElement'] + ) -> None: + super().__init__(package_name, type_element) + self.key = key + self.value = value + + def map_type(self) -> ProtoType: + return ProtoType.get2(f"map<{self.key.element_type}, {self.value.element_type}>") diff --git a/karapace/protobuf/encoding_variants.py b/karapace/protobuf/encoding_variants.py new file mode 100644 index 000000000..3511f8195 --- /dev/null +++ b/karapace/protobuf/encoding_variants.py @@ -0,0 +1,68 @@ +# Workaround to encode/decode indexes in protobuf messages +# Based on https://developers.google.com/protocol-buffers/docs/encoding#varints + +from io import BytesIO +from karapace.protobuf.exception import IllegalArgumentException +from typing import List + +ZERO_BYTE = b'\x00' + + +def read_varint(bio: BytesIO) -> int: + """Read a variable-length integer. + """ + varint = 0 + read_bytes = 0 + + while True: + char = bio.read(1) + if len(char) == 0: + if read_bytes == 0: + return 0 + raise EOFError(f"EOF while reading varint, value is {varint} so far") + + byte = ord(char) + varint += (byte & 0x7F) << (7 * read_bytes) + + read_bytes += 1 + + if not byte & 0x80: + return varint + + +def read_indexes(bio: BytesIO) -> List[int]: + try: + size: int = read_varint(bio) + except EOFError: + # TODO: change exception + raise IllegalArgumentException("problem with reading binary data") + if size == 0: + return [0] + return [read_varint(bio) for _ in range(size)] + + +def write_varint(bio: BytesIO, value: int) -> int: + if value < 0: + raise ValueError(f"value must not be negative, got {value}") + + if value == 0: + bio.write(ZERO_BYTE) + return 1 + + written_bytes = 0 + while value > 0: + to_write = value & 0x7f + value = value >> 7 + + if value > 0: + to_write |= 0x80 + + bio.write(bytearray(to_write)[0]) + written_bytes += 1 + + return written_bytes + + +def write_indexes(bio: BytesIO, indexes: List[int]) -> None: + for i in indexes: + write_varint(bio, i) diff --git a/karapace/protobuf/enum_constant_element.py b/karapace/protobuf/enum_constant_element.py new file mode 100644 index 000000000..b49f80512 --- /dev/null +++ b/karapace/protobuf/enum_constant_element.py @@ -0,0 +1,26 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/EnumConstantElement.kt +from attr import dataclass +from karapace.protobuf.location import Location +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.utils import append_documentation, append_options +from typing import List + + +@dataclass +class EnumConstantElement: + location: Location + name: str + tag: int + documentation: str = "" + options: List[OptionElement] = [] + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + result.append(f"{self.name} = {self.tag}") + if self.options: + result.append(" ") + append_options(result, self.options) + result.append(";\n") + return "".join(result) diff --git a/karapace/protobuf/enum_element.py b/karapace/protobuf/enum_element.py new file mode 100644 index 000000000..72a84b5be --- /dev/null +++ b/karapace/protobuf/enum_element.py @@ -0,0 +1,68 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/EnumElement.kt +from itertools import chain +from karapace.protobuf.compare_result import CompareResult, Modification +from karapace.protobuf.compare_type_storage import CompareTypes +from karapace.protobuf.enum_constant_element import EnumConstantElement +from karapace.protobuf.location import Location +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.type_element import TypeElement +from karapace.protobuf.utils import append_documentation, append_indented +from typing import List + + +class EnumElement(TypeElement): + def __init__( + self, + location: Location, + name: str, + documentation: str = "", + options: List[OptionElement] = None, + constants: List[EnumConstantElement] = None + ) -> None: + # Enums do not allow nested type declarations. + super().__init__(location, name, documentation, options or [], []) + self.constants = constants or [] + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + result.append(f"enum {self.name} {{") + + if self.options or self.constants: + result.append("\n") + + if self.options: + for option in self.options: + append_indented(result, option.to_schema_declaration()) + + if self.constants: + for constant in self.constants: + append_indented(result, constant.to_schema()) + + result.append("}\n") + return "".join(result) + + def compare(self, other: "EnumElement", result: CompareResult, types: CompareTypes) -> None: + self_tags = {} + other_tags = {} + constant: EnumConstantElement + if types: + pass + + for constant in self.constants: + self_tags[constant.tag] = constant + + for constant in other.constants: + other_tags[constant.tag] = constant + + for tag in chain(self_tags.keys(), other_tags.keys() - self_tags.keys()): + result.push_path(str(tag)) + if self_tags.get(tag) is None: + result.add_modification(Modification.ENUM_CONSTANT_ADD) + elif other_tags.get(tag) is None: + result.add_modification(Modification.ENUM_CONSTANT_DROP) + else: + if self_tags.get(tag).name == other_tags.get(tag).name: + result.add_modification(Modification.ENUM_CONSTANT_ALTER) + result.pop_path() diff --git a/karapace/protobuf/exception.py b/karapace/protobuf/exception.py new file mode 100644 index 000000000..43baa95fe --- /dev/null +++ b/karapace/protobuf/exception.py @@ -0,0 +1,48 @@ +import json + + +class ProtobufParserRuntimeException(Exception): + pass + + +class IllegalStateException(Exception): + def __init__(self, message="IllegalStateException") -> None: + self.message = message + super().__init__(self.message) + + +class IllegalArgumentException(Exception): + def __init__(self, message="IllegalArgumentException") -> None: + self.message = message + super().__init__(self.message) + + +class Error(Exception): + """Base class for errors in this module.""" + + +class ProtobufException(Error): + """Generic Protobuf schema error.""" + + +class ProtobufTypeException(Error): + """Generic Protobuf type error.""" + + +class SchemaParseException(ProtobufException): + """Error while parsing a Protobuf schema descriptor.""" + + +def pretty_print_json(obj: str) -> str: + return json.dumps(json.loads(obj), indent=2) + + +class ProtobufSchemaResolutionException(ProtobufException): + def __init__(self, fail_msg: str, writer_schema=None, reader_schema=None) -> None: + writer_dump = pretty_print_json(str(writer_schema)) + reader_dump = pretty_print_json(str(reader_schema)) + if writer_schema: + fail_msg += "\nWriter's Schema: %s" % writer_dump + if reader_schema: + fail_msg += "\nReader's Schema: %s" % reader_dump + super().__init__(self, fail_msg) diff --git a/karapace/protobuf/extend_element.py b/karapace/protobuf/extend_element.py new file mode 100644 index 000000000..8f49c765f --- /dev/null +++ b/karapace/protobuf/extend_element.py @@ -0,0 +1,27 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/ExtendElement.kt +from dataclasses import dataclass +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.location import Location +from karapace.protobuf.utils import append_documentation, append_indented +from typing import List + + +@dataclass +class ExtendElement: + location: Location + name: str + documentation: str = "" + fields: List[FieldElement] = None + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + result.append(f"extend {self.name} {{") + if self.fields: + result.append("\n") + for field in self.fields: + append_indented(result, field.to_schema()) + + result.append("}\n") + return "".join(result) diff --git a/karapace/protobuf/extensions_element.py b/karapace/protobuf/extensions_element.py new file mode 100644 index 000000000..45afb2e40 --- /dev/null +++ b/karapace/protobuf/extensions_element.py @@ -0,0 +1,39 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/ExtensionsElement.kt +from dataclasses import dataclass +from karapace.protobuf.kotlin_wrapper import KotlinRange +from karapace.protobuf.location import Location +from karapace.protobuf.utils import append_documentation, MAX_TAG_VALUE +from typing import List, Union + + +@dataclass +class ExtensionsElement: + location: Location + documentation: str = "" + values: List[Union[int, KotlinRange]] = None + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + result.append("extensions ") + + for index in range(0, len(self.values)): + value = self.values[index] + if index > 0: + result.append(", ") + if isinstance(value, int): + result.append(str(value)) + # TODO: maybe replace Kotlin IntRange by list? + elif isinstance(value, KotlinRange): + result.append(f"{value.minimum} to ") + last_value = value.maximum + if last_value < MAX_TAG_VALUE: + result.append(str(last_value)) + else: + result.append("max") + else: + raise AssertionError() + + result.append(";\n") + return "".join(result) diff --git a/karapace/protobuf/field.py b/karapace/protobuf/field.py new file mode 100644 index 000000000..c24a25257 --- /dev/null +++ b/karapace/protobuf/field.py @@ -0,0 +1,12 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Field.kt + +from enum import Enum + + +class Field: + class Label(Enum): + OPTIONAL = 1 + REQUIRED = 2 + REPEATED = 3 + ONE_OF = 4 diff --git a/karapace/protobuf/field_element.py b/karapace/protobuf/field_element.py new file mode 100644 index 000000000..58766afa8 --- /dev/null +++ b/karapace/protobuf/field_element.py @@ -0,0 +1,151 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/FieldElement.kt +from karapace.protobuf.compare_result import CompareResult, Modification +from karapace.protobuf.compare_type_storage import TypeRecordMap +from karapace.protobuf.field import Field +from karapace.protobuf.location import Location +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.proto_type import ProtoType +from karapace.protobuf.utils import append_documentation, append_options +from typing import List + + +class FieldElement: + from karapace.protobuf.compare_type_storage import CompareTypes + + def __init__( + self, + location: Location, + label: Field.Label = None, + element_type: str = "", + name: str = None, + default_value: str = None, + json_name: str = None, + tag: int = None, + documentation: str = "", + options: list = None + ) -> None: + self.location = location + self.label = label + self.element_type = element_type + self.name = name + self.default_value = default_value + self.json_name = json_name + self.tag = tag + self.documentation = documentation + self.options = options or [] + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + + if self.label: + result.append(f"{self.label.name.lower()} ") + + result.append(f"{self.element_type} {self.name} = {self.tag}") + + options_with_default = self.options_with_special_values() + if options_with_default: + result.append(' ') + append_options(result, options_with_default) + result.append(";\n") + + return "".join(result) + + def options_with_special_values(self) -> List[OptionElement]: + """ Both `default` and `json_name` are defined in the schema like options but they are actually + not options themselves as they're missing from `google.protobuf.FieldOptions`. + """ + + options: list = self.options.copy() + + if self.default_value: + proto_type: ProtoType = ProtoType.get2(self.element_type) + options.append(OptionElement("default", proto_type.to_kind(), self.default_value, False)) + + if self.json_name: + options.append(OptionElement("json_name", OptionElement.Kind.STRING, self.json_name, False)) + + return options + + # Only non-repeated scalar types and Enums support default values. + + def compare(self, other: "FieldElement", result: CompareResult, types: CompareTypes) -> None: + + if self.name != other.name: + result.add_modification(Modification.FIELD_NAME_ALTER) + + self.compare_type(ProtoType.get2(self.element_type), ProtoType.get2(other.element_type), other.label, result, types) + + def compare_map(self, self_map: ProtoType, other_map: ProtoType, result: CompareResult, types: CompareTypes) -> None: + self.compare_type(self_map.key_type, other_map.key_type, "", result, types) + self.compare_type(self_map.value_type, other_map.value_type, "", result, types) + + def compare_type( + self, self_type: ProtoType, other_type: ProtoType, other_label: str, result: CompareResult, types: CompareTypes + ) -> None: + from karapace.protobuf.enum_element import EnumElement + self_type_record = types.get_self_type(self_type) + other_type_record = types.get_other_type(other_type) + self_is_scalar: bool = False + other_is_scalar: bool = False + + if isinstance(self_type_record, TypeRecordMap): + self_type = self_type_record.map_type() + + if isinstance(other_type_record, TypeRecordMap): + other_type = other_type_record.map_type() + + self_is_enum: bool = False + other_is_enum: bool = False + + if self_type_record and isinstance(self_type_record.type_element, EnumElement): + self_is_enum = True + + if other_type_record and isinstance(other_type_record.type_element, EnumElement): + other_is_enum = True + + if self_type.is_scalar or self_is_enum: + self_is_scalar = True + + if other_type.is_scalar or other_is_enum: + other_is_scalar = True + if self_is_scalar == other_is_scalar and \ + self_type.is_map == other_type.is_map: + if self_type.is_map: + self.compare_map(self_type, other_type, result, types) + elif self_is_scalar: + self_compatibility_kind = self_type.compatibility_kind(self_is_enum) + other_compatibility_kind = other_type.compatibility_kind(other_is_enum) + if other_label == '': + other_label = None + if self.label != other_label \ + and self_compatibility_kind in \ + [ProtoType.CompatibilityKind.VARIANT, + ProtoType.CompatibilityKind.DOUBLE, + ProtoType.CompatibilityKind.FLOAT, + ProtoType.CompatibilityKind.FIXED64, + ProtoType.CompatibilityKind.FIXED32, + ProtoType.CompatibilityKind.SVARIANT]: + result.add_modification(Modification.FIELD_LABEL_ALTER) + if self_compatibility_kind != other_compatibility_kind: + result.add_modification(Modification.FIELD_KIND_ALTER) + else: + self.compare_message(self_type, other_type, result, types) + else: + result.add_modification(Modification.FIELD_KIND_ALTER) + + @classmethod + def compare_message( + cls, self_type: ProtoType, other_type: ProtoType, result: CompareResult, types: CompareTypes + ) -> None: + from karapace.protobuf.message_element import MessageElement + self_type_record = types.get_self_type(self_type) + other_type_record = types.get_other_type(other_type) + self_type_element: MessageElement = self_type_record.type_element + other_type_element: MessageElement = other_type_record.type_element + + if types.self_type_short_name(self_type) != types.other_type_short_name(other_type): + result.add_modification(Modification.FIELD_NAME_ALTER) + else: + self_type_element.compare(other_type_element, result, types) diff --git a/karapace/protobuf/group_element.py b/karapace/protobuf/group_element.py new file mode 100644 index 000000000..258d81064 --- /dev/null +++ b/karapace/protobuf/group_element.py @@ -0,0 +1,32 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/GroupElement.kt +from dataclasses import dataclass +from karapace.protobuf.field import Field +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.location import Location +from karapace.protobuf.utils import append_documentation, append_indented +from typing import List, Optional + + +@dataclass +class GroupElement: + label: Optional[Field.Label] + location: Location + name: str + tag: int + documentation: str = "" + fields: List[FieldElement] = None + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + + if self.label: + result.append(f"{str(self.label.name).lower()} ") + result.append(f"group {self.name} = {self.tag} {{") + if self.fields: + result.append("\n") + for field in self.fields: + append_indented(result, field.to_schema()) + result.append("}\n") + return "".join(result) diff --git a/karapace/protobuf/io.py b/karapace/protobuf/io.py new file mode 100644 index 000000000..c23c17959 --- /dev/null +++ b/karapace/protobuf/io.py @@ -0,0 +1,134 @@ +from io import BytesIO +from karapace import config +from karapace.protobuf.encoding_variants import read_indexes, write_indexes +from karapace.protobuf.exception import IllegalArgumentException, ProtobufSchemaResolutionException, ProtobufTypeException +from karapace.protobuf.message_element import MessageElement +from karapace.protobuf.protobuf_to_dict import dict_to_protobuf, protobuf_to_dict +from karapace.protobuf.schema import ProtobufSchema +from karapace.protobuf.type_element import TypeElement +from typing import Any, Dict, List + +import hashlib +import importlib +import importlib.util +import logging +import os +import subprocess + +logger = logging.getLogger(__name__) + + +def calculate_class_name(name: str) -> str: + return "c_" + hashlib.md5(name.encode('utf-8')).hexdigest() + + +def match_schemas(writer_schema: ProtobufSchema, reader_schema: ProtobufSchema) -> bool: + # TODO (serge): is schema comparison by fields required? + + return str(writer_schema) == str(reader_schema) + + +def find_message_name(schema: ProtobufSchema, indexes: List[int]) -> str: + result: List[str] = [] + types = schema.proto_file_element.types + for index in indexes: + try: + message = types[index] + except IndexError: + raise IllegalArgumentException(f"Invalid message indexes: {indexes}") + + if message and isinstance(message, MessageElement): + result.append(message.name) + types = message.nested_types + else: + raise IllegalArgumentException(f"Invalid message indexes: {indexes}") + + # for java we also need package name. But in case we will use protoc + # for compiling to python we can ignore it at all + return ".".join(result) + + +def get_protobuf_class_instance(schema: ProtobufSchema, class_name: str, cfg: Dict) -> Any: + directory = cfg["protobuf_runtime_directory"] + proto_name = calculate_class_name(str(schema)) + proto_path = f"{directory}/{proto_name}.proto" + class_path = f"{directory}/{proto_name}_pb2.py" + if not os.path.isfile(proto_path): + with open(f"{directory}/{proto_name}.proto", "w") as proto_text: + proto_text.write(str(schema)) + + if not os.path.isfile(class_path): + subprocess.run([ + "protoc", + "--python_out=./", + proto_path, + ], check=True) + + spec = importlib.util.spec_from_file_location(f"{proto_name}_pb2", class_path) + tmp_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(tmp_module) + class_to_call = getattr(tmp_module, class_name) + return class_to_call() + + +def read_data(writer_schema: ProtobufSchema, reader_schema: ProtobufSchema, bio: BytesIO) -> Any: + if not match_schemas(writer_schema, reader_schema): + fail_msg = 'Schemas do not match.' + raise ProtobufSchemaResolutionException(fail_msg, writer_schema, reader_schema) + + indexes = read_indexes(bio) + name = find_message_name(writer_schema, indexes) + + class_instance = get_protobuf_class_instance(writer_schema, name, config.DEFAULTS) + class_instance.ParseFromString(bio.read()) + + return class_instance + + +class ProtobufDatumReader: + """Deserialize Protobuf-encoded data into a Python data structure.""" + + def __init__(self, writer_schema: ProtobufSchema = None, reader_schema: ProtobufSchema = None) -> None: + """ As defined in the Protobuf specification, we call the schema encoded + in the data the "writer's schema", and the schema expected by the + reader the "reader's schema". + """ + self._writer_schema = writer_schema + self._reader_schema = reader_schema + + def read(self, bio: BytesIO) -> None: + if self._reader_schema is None: + self._reader_schema = self._writer_schema + return protobuf_to_dict(read_data(self._writer_schema, self._reader_schema, bio), True) + + +class ProtobufDatumWriter: + """ProtobufDatumWriter for generic python objects.""" + + def __init__(self, writer_schema: ProtobufSchema = None): + self._writer_schema = writer_schema + a: ProtobufSchema = writer_schema + el: TypeElement + self._message_name = "" + for idx, el in enumerate(a.proto_file_element.types): + if isinstance(el, MessageElement): + self._message_name = el.name + self._message_index = idx + break + + if self._message_name == "": + raise ProtobufTypeException("No message in protobuf schema") + + def write_index(self, writer: BytesIO) -> None: + write_indexes(writer, [self._message_index]) + + def write(self, datum: dict, writer: BytesIO) -> None: + + class_instance = get_protobuf_class_instance(self._writer_schema, self._message_name, config.DEFAULTS) + + try: + dict_to_protobuf(class_instance, datum) + except Exception: + raise ProtobufTypeException(self._writer_schema, datum) + + writer.write(class_instance.SerializeToString()) diff --git a/karapace/protobuf/kotlin_wrapper.py b/karapace/protobuf/kotlin_wrapper.py new file mode 100644 index 000000000..044417b57 --- /dev/null +++ b/karapace/protobuf/kotlin_wrapper.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass + +import textwrap + + +def trim_margin(s: str) -> str: + lines = s.split("\n") + new_lines = [] + + if not textwrap.dedent(lines[0]): + del lines[0] + + if not textwrap.dedent(lines[-1]): + del lines[-1] + + for line in lines: + idx = line.find("|") + if idx < 0: + new_lines.append(line) + else: + new_lines.append(line[idx + 1:]) + + return "\n".join(new_lines) + + +@dataclass +class KotlinRange: + minimum: int + maximum: int + + def __str__(self) -> str: + return f"{self.minimum}..{self.maximum}" diff --git a/karapace/protobuf/location.py b/karapace/protobuf/location.py new file mode 100644 index 000000000..a87f0626b --- /dev/null +++ b/karapace/protobuf/location.py @@ -0,0 +1,60 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Location.kt +from typing import Optional + + +class Location: + """ Locates a .proto file, or a self.position within a .proto file, on the file system """ + + def __init__(self, base: str, path: str, line: int = -1, column: int = -1) -> None: + """ str - The base directory of this location; + path - The path to this location relative to [base] + line - The line number of this location, or -1 for no specific line number + column - The column on the line of this location, or -1 for no specific column + """ + self.base = base + self.path = path + self.line = line + self.column = column + + def at(self, line: int, column: int) -> 'Location': + return Location(self.base, self.path, line, column) + + def without_base(self) -> 'Location': + """ Returns a copy of this location with an empty base. """ + return Location("", self.path, self.line, self.column) + + def with_path_only(self) -> 'Location': + """ Returns a copy of this location including only its path. """ + return Location("", self.path, -1, -1) + + def __str__(self) -> str: + result = "" + if self.base: + result += self.base + "/" + + result += self.path + + if self.line != -1: + result += ":" + result += str(self.line) + if self.column != -1: + result += ":" + result += str(self.column) + + return result + + @staticmethod + def get(*args) -> Optional['Location']: + result = None + if len(args) == 1: # (path) + path = args[0] + result = Location.get("", path) + if len(args) == 2: # (base,path) + path: str = args[1] + base: str = args[0] + if base.endswith("/"): + base = base[:-1] + result = Location(base, path) + + return result diff --git a/karapace/protobuf/message_element.py b/karapace/protobuf/message_element.py new file mode 100644 index 000000000..5c7a460ab --- /dev/null +++ b/karapace/protobuf/message_element.py @@ -0,0 +1,142 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/MessageElement.kt +# compatibility routine added +from itertools import chain +from karapace.protobuf.compare_result import CompareResult, Modification +from karapace.protobuf.compare_type_storage import CompareTypes +from karapace.protobuf.extensions_element import ExtensionsElement +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.group_element import GroupElement +from karapace.protobuf.location import Location +from karapace.protobuf.one_of_element import OneOfElement +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.reserved_element import ReservedElement +from karapace.protobuf.type_element import TypeElement +from karapace.protobuf.utils import append_documentation, append_indented +from typing import List + + +class MessageElement(TypeElement): + def __init__( + self, + location: Location, + name: str, + documentation: str = "", + nested_types: List[str] = None, + options: List[OptionElement] = None, + reserveds: List[ReservedElement] = None, + fields: List[FieldElement] = None, + one_ofs: List[OneOfElement] = None, + extensions: List[ExtensionsElement] = None, + groups: List[GroupElement] = None, + ) -> None: + super().__init__(location, name, documentation, options or [], nested_types or []) + self.reserveds = reserveds or [] + self.fields = fields or [] + self.one_ofs = one_ofs or [] + self.extensions = extensions or [] + self.groups = groups or [] + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + result.append(f"message {self.name} {{") + if self.reserveds: + result.append("\n") + for reserved in self.reserveds: + append_indented(result, reserved.to_schema()) + + if self.options: + result.append("\n") + for option in self.options: + append_indented(result, option.to_schema_declaration()) + + if self.fields: + for field in self.fields: + result.append("\n") + append_indented(result, field.to_schema()) + + if self.one_ofs: + for one_of in self.one_ofs: + result.append("\n") + append_indented(result, one_of.to_schema()) + + if self.groups: + for group in self.groups: + result.append("\n") + append_indented(result, group.to_schema()) + + if self.extensions: + result.append("\n") + for extension in self.extensions: + append_indented(result, extension.to_schema()) + + if self.nested_types: + result.append("\n") + for nested_type in self.nested_types: + append_indented(result, nested_type.to_schema()) + + result.append("}\n") + return "".join(result) + + def compare(self, other: 'MessageElement', result: CompareResult, types: CompareTypes) -> None: + + if types.lock_message(self): + field: FieldElement + subfield: FieldElement + one_of: OneOfElement + self_tags = {} + other_tags = {} + self_one_ofs = {} + other_one_ofs = {} + + for field in self.fields: + self_tags[field.tag] = field + + for field in other.fields: + other_tags[field.tag] = field + + for one_of in self.one_ofs: + self_one_ofs[one_of.name] = one_of + + for one_of in other.one_ofs: + other_one_ofs[one_of.name] = one_of + + for field in other.one_ofs: + result.push_path(str(field.name)) + convert_count = 0 + for subfield in field.fields: + tag = subfield.tag + if self_tags.get(tag): + self_tags.pop(tag) + convert_count += 1 + if convert_count > 1: + result.add_modification(Modification.FEW_FIELDS_CONVERTED_TO_ONE_OF) + result.pop_path() + + # Compare fields + for tag in chain(self_tags.keys(), other_tags.keys() - self_tags.keys()): + result.push_path(str(tag)) + + if self_tags.get(tag) is None: + result.add_modification(Modification.FIELD_ADD) + elif other_tags.get(tag) is None: + result.add_modification(Modification.FIELD_DROP) + else: + self_tags[tag].compare(other_tags[tag], result, types) + + result.pop_path() + # Compare OneOfs + for name in chain(self_one_ofs.keys(), other_one_ofs.keys() - self_one_ofs.keys()): + result.push_path(str(name)) + + if self_one_ofs.get(name) is None: + result.add_modification(Modification.ONE_OF_ADD) + elif other_one_ofs.get(name) is None: + result.add_modification(Modification.ONE_OF_DROP) + else: + self_one_ofs[name].compare(other_one_ofs[name], result, types) + + result.pop_path() + + types.unlock_message(self) diff --git a/karapace/protobuf/one_of_element.py b/karapace/protobuf/one_of_element.py new file mode 100644 index 000000000..0f1f16a61 --- /dev/null +++ b/karapace/protobuf/one_of_element.py @@ -0,0 +1,55 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/OneOfElement.kt +from itertools import chain +from karapace.protobuf.compare_result import CompareResult, Modification +from karapace.protobuf.compare_type_storage import CompareTypes +from karapace.protobuf.utils import append_documentation, append_indented + + +class OneOfElement: + def __init__(self, name: str, documentation: str = "", fields=None, groups=None, options=None) -> None: + self.name = name + self.documentation = documentation + self.fields = fields or [] + self.options = options or [] + self.groups = groups or [] + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + result.append(f"oneof {self.name} {{") + if self.options: + result.append("\n") + for option in self.options: + append_indented(result, option.to_schema_declaration()) + + if self.fields: + result.append("\n") + for field in self.fields: + append_indented(result, field.to_schema()) + if self.groups: + result.append("\n") + for group in self.groups: + append_indented(result, group.to_schema()) + result.append("}\n") + return "".join(result) + + def compare(self, other: 'OneOfElement', result: CompareResult, types: CompareTypes) -> None: + self_tags = {} + other_tags = {} + + for field in self.fields: + self_tags[field.tag] = field + for field in other.fields: + other_tags[field.tag] = field + + for tag in chain(self_tags.keys(), other_tags.keys() - self_tags.keys()): + result.push_path(str(tag)) + + if self_tags.get(tag) is None: + result.add_modification(Modification.ONE_OF_FIELD_ADD) + elif other_tags.get(tag) is None: + result.add_modification(Modification.ONE_OF_FIELD_DROP) + else: + self_tags[tag].compare(other_tags[tag], result, types) + result.pop_path() diff --git a/karapace/protobuf/option_element.py b/karapace/protobuf/option_element.py new file mode 100644 index 000000000..8c6228bcf --- /dev/null +++ b/karapace/protobuf/option_element.py @@ -0,0 +1,92 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/OptionElement.kt + +from enum import Enum +from karapace.protobuf.utils import append_indented, append_options, try_to_schema + + +class OptionElement: + class Kind(Enum): + STRING = 1 + BOOLEAN = 2 + NUMBER = 3 + ENUM = 4 + MAP = 5 + LIST = 6 + OPTION = 7 + + def __init__(self, name: str, kind: Kind, value, is_parenthesized: bool = None) -> None: + self.name = name + self.kind = kind + self.value = value + """ If true, this [OptionElement] is a custom option. """ + self.is_parenthesized = is_parenthesized or False + self.formattedName = f"({self.name})" if is_parenthesized else self.name + + def to_schema(self) -> str: + aline = None + if self.kind == self.Kind.STRING: + aline = f"{self.formattedName} = \"{self.value}\"" + elif self.kind in [self.Kind.BOOLEAN, self.Kind.NUMBER, self.Kind.ENUM]: + aline = f"{self.formattedName} = {self.value}" + elif self.kind == self.Kind.OPTION: + aline = f"{self.formattedName}.{try_to_schema(self.value)}" + elif self.kind == self.Kind.MAP: + aline = [f"{self.formattedName} = {{\n", self.format_option_map(self.value), "}"] + elif self.kind == self.Kind.LIST: + aline = [f"{self.formattedName} = ", self.append_options(self.value)] + + if isinstance(aline, list): + return "".join(aline) + return aline + + def to_schema_declaration(self) -> str: + return f"option {self.to_schema()};\n" + + @staticmethod + def append_options(options: list) -> str: + data = [] + append_options(data, options) + return "".join(data) + + def format_option_map(self, value: dict) -> str: + keys = list(value.keys()) + last_index = len(keys) - 1 + result = [] + for index, key in enumerate(keys): + endl = "," if (index != last_index) else "" + append_indented(result, f"{key}: {self.format_option_map_value(value[key])}{endl}") + return "".join(result) + + def format_option_map_value(self, value) -> str: + aline = value + if isinstance(value, str): + aline = f"\"{value}\"" + elif isinstance(value, dict): + aline = ["{\n", self.format_option_map(value), "}"] + elif isinstance(value, list): + aline = ["[\n", self.format_list_map_value(value), "]"] + + if isinstance(aline, list): + return "".join(aline) + if isinstance(aline, str): + return aline + return value + + def format_list_map_value(self, value) -> str: + + last_index = len(value) - 1 + result = [] + for index, elm in enumerate(value): + endl = "," if (index != last_index) else "" + append_indented(result, f"{self.format_option_map_value(elm)}{endl}") + return "".join(result) + + def __repr__(self) -> str: + return self.to_schema() + + def __eq__(self, other) -> bool: + return str(self) == str(other) + + +PACKED_OPTION_ELEMENT = OptionElement("packed", OptionElement.Kind.BOOLEAN, value="true", is_parenthesized=False) diff --git a/karapace/protobuf/option_reader.py b/karapace/protobuf/option_reader.py new file mode 100644 index 000000000..1003dcdde --- /dev/null +++ b/karapace/protobuf/option_reader.py @@ -0,0 +1,153 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/OptionReader.kt +from dataclasses import dataclass +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.syntax_reader import SyntaxReader +from typing import Union + + +@dataclass +class KindAndValue: + kind: OptionElement.Kind + value: object + + +class OptionReader: + reader: SyntaxReader + + def __init__(self, reader: SyntaxReader) -> None: + self.reader = reader + + def read_options(self) -> list: + """ Reads options enclosed in '[' and ']' if they are present and returns them. Returns an empty + list if no options are present. + """ + if not self.reader.peek_char('['): + return [] + result = [] + while True: + result.append(self.read_option('=')) + + # Check for closing ']' + if self.reader.peek_char(']'): + break + + # Discard optional ','. + self.reader.expect(self.reader.peek_char(','), "Expected ',' or ']") + return result + + def read_option(self, key_value_separator: str) -> OptionElement: + """ Reads a option containing a name, an '=' or ':', and a value. + """ + is_extension = (self.reader.peek_char() == '[') + is_parenthesized = (self.reader.peek_char() == '(') + name = self.reader.read_name() # Option name. + if is_extension: + name = f"[{name}]" + + sub_names = [] + c = self.reader.read_char() + if c == '.': + # Read nested field name. For example "baz" in "(foo.bar).baz = 12". + sub_names = self.reader.read_name().split(".") + c = self.reader.read_char() + + if key_value_separator == ':' and c == '{': + # In text format, values which are maps can omit a separator. Backtrack so it can be re-read. + self.reader.push_back('{') + else: + self.reader.expect(c == key_value_separator, f"expected '{key_value_separator}' in option") + + kind_and_value = self.read_kind_and_value() + kind = kind_and_value.kind + value = kind_and_value.value + sub_names.reverse() + for sub_name in sub_names: + value = OptionElement(sub_name, kind, value, False) + kind = OptionElement.Kind.OPTION + return OptionElement(name, kind, value, is_parenthesized) + + def read_kind_and_value(self) -> KindAndValue: + """ Reads a value that can be a map, list, string, number, boolean or enum. """ + peeked = self.reader.peek_char() + result: KindAndValue + if peeked == '{': + result = KindAndValue(OptionElement.Kind.MAP, self.read_map('{', '}', ':')) + elif peeked == '[': + result = KindAndValue(OptionElement.Kind.LIST, self.read_list()) + elif peeked in ('"', "'"): + result = KindAndValue(OptionElement.Kind.STRING, self.reader.read_string()) + elif ord(str(peeked)) in range(ord("0"), ord("9")) or peeked == '-': + result = KindAndValue(OptionElement.Kind.NUMBER, self.reader.read_word()) + else: + word = self.reader.read_word() + if word == "true": + result = KindAndValue(OptionElement.Kind.BOOLEAN, "true") + elif word == "false": + result = KindAndValue(OptionElement.Kind.BOOLEAN, "false") + else: + result = KindAndValue(OptionElement.Kind.ENUM, word) + return result + + def read_map(self, open_brace: str, close_brace: str, key_value_separator: str) -> dict: + """ Returns a map of string keys and values. This is similar to a JSON object, with ':' and '}' + surrounding the map, ':' separating keys from values, and ',' or ';' separating entries. + """ + if self.reader.read_char() != open_brace: + raise AssertionError() + result = {} + while True: + if self.reader.peek_char(close_brace): + # If we see the close brace, finish immediately. This handles :}/[] and ,}/,] cases. + return result + + option = self.read_option(key_value_separator) + name = option.name + value = option.value + if isinstance(value, OptionElement): + nested = result.get(name) + if not nested: + nested = {} + result[name] = nested + nested[value.name] = value.value + else: + # Add the value(s) to any previous values with the same key + previous = result.get(name) + if not previous: + result[name] = value + elif isinstance(previous, list): # Add to previous List + self.add_to_list(previous, value) + else: + new_list = [] + new_list.append(previous) + self.add_to_list(new_list, value) + result[name] = new_list + # Discard optional separator. + if not self.reader.peek_char(','): + self.reader.peek_char(';') + + @staticmethod + def add_to_list(_list: list, value: Union[list, str]) -> None: + """ Adds an object or objects to a List. """ + if isinstance(value, list): + for v in list(value): + _list.append(v) + else: + _list.append(value) + + def read_list(self) -> list: + """ Returns a list of values. This is similar to JSON with '[' and ']' surrounding the list and ',' + separating values. + """ + self.reader.require('[') + result = [] + while True: + # If we see the close brace, finish immediately. This handles [] and ,] cases. + if self.reader.peek_char(']'): + return result + + result.append(self.read_kind_and_value().value) + + if self.reader.peek_char(','): + continue + self.reader.expect(self.reader.peek_char() == ']', "expected ',' or ']'") diff --git a/karapace/protobuf/proto_file_element.py b/karapace/protobuf/proto_file_element.py new file mode 100644 index 000000000..c798104fb --- /dev/null +++ b/karapace/protobuf/proto_file_element.py @@ -0,0 +1,162 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/ProtoFileElement.kt +from itertools import chain +from karapace.protobuf.compare_result import CompareResult, Modification +from karapace.protobuf.compare_type_storage import CompareTypes +from karapace.protobuf.enum_element import EnumElement +from karapace.protobuf.exception import IllegalStateException +from karapace.protobuf.location import Location +from karapace.protobuf.message_element import MessageElement +from karapace.protobuf.syntax import Syntax +from karapace.protobuf.type_element import TypeElement + + +class ProtoFileElement: + def __init__( + self, + location: Location, + package_name: str = None, + syntax: Syntax = None, + imports: list = None, + public_imports: list = None, + types=None, + services: list = None, + extend_declarations: list = None, + options: list = None + ) -> None: + if types is None: + types = [] + self.location = location + self.package_name = package_name + self.syntax = syntax + self.options = options or [] + self.extend_declarations = extend_declarations or [] + self.services = services or [] + self.types = types or [] + self.public_imports = public_imports or [] + self.imports = imports or [] + + def to_schema(self) -> str: + strings: list = [ + "// Proto schema formatted by Wire, do not edit.\n", "// Source: ", + str(self.location.with_path_only()), "\n" + ] + if self.syntax: + strings.append("\n") + strings.append("syntax = \"") + strings.append(str(self.syntax)) + strings.append("\";\n") + + if self.package_name: + strings.append("\n") + strings.append("package " + str(self.package_name) + ";\n") + + if self.imports or self.public_imports: + strings.append("\n") + + for file in self.imports: + strings.append("import \"" + str(file) + "\";\n") + + for file in self.public_imports: + strings.append("import public \"" + str(file) + "\";\n") + + if self.options: + strings.append("\n") + for option in self.options: + strings.append(str(option.to_schema_declaration())) + + if self.types: + for type_element in self.types: + strings.append("\n") + strings.append(str(type_element.to_schema())) + + if self.extend_declarations: + for extend_declaration in self.extend_declarations: + strings.append("\n") + strings.append(str(extend_declaration.to_schema())) + + if self.services: + for service in self.services: + strings.append("\n") + strings.append(str(service.to_schema())) + + return "".join(strings) + + @staticmethod + def empty(path) -> 'ProtoFileElement': + return ProtoFileElement(Location.get(path)) + + # TODO: there maybe be faster comparison workaround + def __eq__(self, other: 'ProtoFileElement') -> bool: # type: ignore + a = self.to_schema() + b = other.to_schema() + + return a == b + + def __repr__(self) -> str: + return self.to_schema() + + def compare(self, other: 'ProtoFileElement', result: CompareResult) -> CompareResult: + + if self.package_name != other.package_name: + result.add_modification(Modification.PACKAGE_ALTER) + # TODO: do we need syntax check? + if self.syntax != other.syntax: + result.add_modification(Modification.SYNTAX_ALTER) + + self_types = {} + other_types = {} + self_indexes = {} + other_indexes = {} + compare_types = CompareTypes(self.package_name, other.package_name, result) + type_: TypeElement + for i, type_ in enumerate(self.types): + self_types[type_.name] = type_ + self_indexes[type_.name] = i + package_name = self.package_name or '' + compare_types.add_self_type(package_name, type_) + + for i, type_ in enumerate(other.types): + other_types[type_.name] = type_ + other_indexes[type_.name] = i + package_name = other.package_name or '' + compare_types.add_other_type(package_name, type_) + + for name in chain(self_types.keys(), other_types.keys() - self_types.keys()): + + result.push_path(str(name), True) + + if self_types.get(name) is None and other_types.get(name) is not None: + if isinstance(other_types[name], MessageElement): + result.add_modification(Modification.MESSAGE_ADD) + elif isinstance(other_types[name], EnumElement): + result.add_modification(Modification.ENUM_ADD) + else: + raise IllegalStateException("Instance of element is not applicable") + elif self_types.get(name) is not None and other_types.get(name) is None: + if isinstance(self_types[name], MessageElement): + result.add_modification(Modification.MESSAGE_DROP) + elif isinstance(self_types[name], EnumElement): + result.add_modification(Modification.ENUM_DROP) + else: + raise IllegalStateException("Instance of element is not applicable") + else: + if other_indexes[name] != self_indexes[name]: + if isinstance(self_types[name], MessageElement): + # incompatible type + result.add_modification(Modification.MESSAGE_MOVE) + else: + raise IllegalStateException("Instance of element is not applicable") + else: + if isinstance(self_types[name], MessageElement) \ + and isinstance(other_types[name], MessageElement): + self_types[name].compare(other_types[name], result, compare_types) + elif isinstance(self_types[name], EnumElement) \ + and isinstance(other_types[name], EnumElement): + self_types[name].compare(other_types[name], result, compare_types) + else: + # incompatible type + result.add_modification(Modification.TYPE_ALTER) + result.pop_path(True) + + return result diff --git a/karapace/protobuf/proto_parser.py b/karapace/protobuf/proto_parser.py new file mode 100644 index 000000000..68f73aba8 --- /dev/null +++ b/karapace/protobuf/proto_parser.py @@ -0,0 +1,589 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/ProtoParser.kt + +from builtins import str +from enum import Enum +from karapace.protobuf.enum_constant_element import EnumConstantElement +from karapace.protobuf.enum_element import EnumElement +from karapace.protobuf.exception import IllegalArgumentException, SchemaParseException +from karapace.protobuf.extend_element import ExtendElement +from karapace.protobuf.extensions_element import ExtensionsElement +from karapace.protobuf.field import Field +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.group_element import GroupElement +from karapace.protobuf.kotlin_wrapper import KotlinRange +from karapace.protobuf.location import Location +from karapace.protobuf.message_element import MessageElement +from karapace.protobuf.one_of_element import OneOfElement +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.option_reader import OptionReader +from karapace.protobuf.proto_file_element import ProtoFileElement +from karapace.protobuf.reserved_element import ReservedElement +from karapace.protobuf.rpc_element import RpcElement +from karapace.protobuf.service_element import ServiceElement +from karapace.protobuf.syntax import Syntax +from karapace.protobuf.syntax_reader import SyntaxReader +from karapace.protobuf.type_element import TypeElement +from karapace.protobuf.utils import MAX_TAG_VALUE +from typing import List, Union + + +class Context(Enum): + FILE = 1 + MESSAGE = 2 + ENUM = 3 + RPC = 4 + EXTEND = 5 + SERVICE = 6 + + def permits_package(self) -> bool: + return self == Context.FILE + + def permits_syntax(self) -> bool: + return self == Context.FILE + + def permits_import(self) -> bool: + return self == Context.FILE + + def permits_extensions(self) -> bool: + return self == Context.MESSAGE + + def permits_rpc(self) -> bool: + return self == Context.SERVICE + + def permits_one_of(self) -> bool: + return self == Context.MESSAGE + + def permits_message(self) -> bool: + return self in [Context.FILE, Context.MESSAGE] + + def permits_service(self) -> bool: + return self in [Context.FILE] + + def permits_enum(self) -> bool: + return self in [Context.FILE, Context.MESSAGE] + + def permits_extend(self) -> bool: + return self in [Context.FILE, Context.MESSAGE] + + +class ProtoParser: + def __init__(self, location: Location, data: str) -> None: + self.location = location + self.imports: List[str] = [] + self.nested_types: List[str] = [] + self.services: List[str] = [] + self.extends_list: List[str] = [] + self.options: List[str] = [] + self.declaration_count = 0 + self.syntax: Union[Syntax, None] = None + self.package_name: Union[str, None] = None + self.prefix = "" + self.data = data + self.public_imports: List[str] = [] + self.reader = SyntaxReader(data, location) + + def read_proto_file(self) -> ProtoFileElement: + while True: + documentation = self.reader.read_documentation() + if self.reader.exhausted(): + return ProtoFileElement( + self.location, self.package_name, self.syntax, self.imports, self.public_imports, self.nested_types, + self.services, self.extends_list, self.options + ) + declaration = self.read_declaration(documentation, Context.FILE) + if isinstance(declaration, TypeElement): + # TODO: add check for exception? + duplicate = next((x for x in iter(self.nested_types) if x.name == declaration.name), None) + if duplicate: + raise SchemaParseException( + f"{declaration.name} ({declaration.location}) is already defined at {duplicate.location}" + ) + self.nested_types.append(declaration) + + elif isinstance(declaration, ServiceElement): + duplicate = next((x for x in iter(self.services) if x.name == declaration.name), None) + if duplicate: + raise SchemaParseException( + f"{declaration.name} ({declaration.location}) is already defined at {duplicate.location}" + ) + self.services.append(declaration) + + elif isinstance(declaration, OptionElement): + self.options.append(declaration) + + elif isinstance(declaration, ExtendElement): + self.extends_list.append(declaration) + + def read_declaration( + self, documentation: str, context: Context + ) -> Union[None, OptionElement, ReservedElement, RpcElement, MessageElement, EnumElement, EnumConstantElement, + ServiceElement, ExtendElement, ExtensionsElement, OneOfElement, GroupElement, FieldElement]: + index = self.declaration_count + self.declaration_count += 1 + + # Skip unnecessary semicolons, occasionally used after a nested message declaration. + if self.reader.peek_char(';'): + return None + + location = self.reader.location() + label = self.reader.read_word() + + # TODO(benoit) Let's better parse the proto keywords. We are pretty weak when field/constants + # are named after any of the label we check here. + + result: Union[None, OptionElement, ReservedElement, RpcElement, MessageElement, EnumElement, EnumConstantElement, + ServiceElement, ExtendElement, ExtensionsElement, OneOfElement, GroupElement, FieldElement] = None + # pylint no-else-return + if label == "package" and context.permits_package(): + self.package_name = self.reader.read_name() + self.prefix = f"{self.package_name}." + self.reader.require(';') + elif label == "import" and context.permits_import(): + import_string = self.reader.read_string() + if import_string == "public": + self.public_imports.append(self.reader.read_string()) + + else: + self.imports.append(import_string) + self.reader.require(';') + elif label == "syntax" and context.permits_syntax(): + self.reader.expect_with_location(not self.syntax, location, "too many syntax definitions") + self.reader.require("=") + self.reader.expect_with_location( + index == 0, location, "'syntax' element must be the first declaration in a file" + ) + + syntax_string = self.reader.read_quoted_string() + try: + self.syntax = Syntax(syntax_string) + except IllegalArgumentException as e: + self.reader.unexpected(str(e), location) + self.reader.require(";") + result = None + elif label == "option": + result = OptionReader(self.reader).read_option("=") + self.reader.require(";") + elif label == "reserved": + result = self.read_reserved(location, documentation) + elif label == "message" and context.permits_message(): + result = self.read_message(location, documentation) + elif label == "enum" and context.permits_enum(): + result = self.read_enum_element(location, documentation) + elif label == "service" and context.permits_service(): + result = self.read_service(location, documentation) + elif label == "extend" and context.permits_extend(): + result = self.read_extend(location, documentation) + elif label == "rpc" and context.permits_rpc(): + result = self.read_rpc(location, documentation) + elif label == "oneof" and context.permits_one_of(): + result = self.read_one_of(documentation) + elif label == "extensions" and context.permits_extensions(): + result = self.read_extensions(location, documentation) + elif context in [Context.MESSAGE, Context.EXTEND]: + result = self.read_field(documentation, location, label) + elif context == Context.ENUM: + result = self.read_enum_constant(documentation, location, label) + else: + self.reader.unexpected(f"unexpected label: {label}", location) + return result + + def read_message(self, location: Location, documentation: str) -> MessageElement: + """ Reads a message declaration. """ + name: str = self.reader.read_name() + fields: List[FieldElement] = [] + one_ofs: List[OneOfElement] = [] + nested_types: List[TypeElement] = [] + extensions: List[ExtensionsElement] = [] + options: List[OptionElement] = [] + reserveds: List[ReservedElement] = [] + groups: List[GroupElement] = [] + + previous_prefix = self.prefix + self.prefix = f"{self.prefix}{name}." + + self.reader.require("{") + while True: + nested_documentation = self.reader.read_documentation() + if self.reader.peek_char("}"): + break + declared = self.read_declaration(nested_documentation, Context.MESSAGE) + + if isinstance(declared, FieldElement): + fields.append(declared) + elif isinstance(declared, OneOfElement): + one_ofs.append(declared) + elif isinstance(declared, GroupElement): + groups.append(declared) + elif isinstance(declared, TypeElement): + nested_types.append(declared) + elif isinstance(declared, ExtensionsElement): + extensions.append(declared) + elif isinstance(declared, OptionElement): + options.append(declared) + # Extend declarations always add in a global scope regardless of nesting. + elif isinstance(declared, ExtendElement): + self.extends_list.append(declared) + elif isinstance(declared, ReservedElement): + reserveds.append(declared) + + self.prefix = previous_prefix + + return MessageElement( + location, + name, + documentation, + nested_types, + options, + reserveds, + fields, + one_ofs, + extensions, + groups, + ) + + def read_extend(self, location: Location, documentation: str) -> ExtendElement: + """ Reads an extend declaration. """ + name = self.reader.read_name() + fields = [] + self.reader.require("{") + while True: + nested_documentation = self.reader.read_documentation() + if self.reader.peek_char("}"): + break + + declared = self.read_declaration(nested_documentation, Context.EXTEND) + if isinstance(declared, FieldElement): + fields.append(declared) + # TODO: add else clause to catch unexpected declarations. + else: + pass + + return ExtendElement( + location, + name, + documentation, + fields, + ) + + def read_service(self, location: Location, documentation: str) -> ServiceElement: + """ Reads a service declaration and returns it. """ + name = self.reader.read_name() + rpcs = [] + options = [] + self.reader.require('{') + while True: + rpc_documentation = self.reader.read_documentation() + if self.reader.peek_char("}"): + break + declared = self.read_declaration(rpc_documentation, Context.SERVICE) + if isinstance(declared, RpcElement): + rpcs.append(declared) + elif isinstance(declared, OptionElement): + options.append(declared) + # TODO: add else clause to catch unexpected declarations. + else: + pass + + return ServiceElement( + location, + name, + documentation, + rpcs, + options, + ) + + def read_enum_element(self, location: Location, documentation: str) -> EnumElement: + """ Reads an enumerated atype declaration and returns it. """ + name = self.reader.read_name() + constants = [] + options = [] + self.reader.require("{") + while True: + value_documentation = self.reader.read_documentation() + if self.reader.peek_char("}"): + break + declared = self.read_declaration(value_documentation, Context.ENUM) + + if isinstance(declared, EnumConstantElement): + constants.append(declared) + elif isinstance(declared, OptionElement): + options.append(declared) + # TODO: add else clause to catch unexpected declarations. + else: + pass + return EnumElement(location, name, documentation, options, constants) + + def read_field(self, documentation: str, location: Location, word: str) -> Union[GroupElement, FieldElement]: + label: Union[None, Field.Label] + atype: str + if word == "required": + self.reader.expect_with_location( + self.syntax != Syntax.PROTO_3, location, "'required' label forbidden in proto3 field declarations" + ) + label = Field.Label.REQUIRED + atype = self.reader.read_data_type() + elif word == "optional": + label = Field.Label.OPTIONAL + + atype = self.reader.read_data_type() + + elif word == "repeated": + label = Field.Label.REPEATED + atype = self.reader.read_data_type() + else: + self.reader.expect_with_location( + self.syntax == Syntax.PROTO_3 or (word == "map" and self.reader.peek_char() == "<"), location, + f"unexpected label: {word}" + ) + + label = None + atype = self.reader.read_data_type_by_name(word) + + self.reader.expect_with_location(not atype.startswith("map<") or not label, location, "'map' type cannot have label") + if atype == "group": + return self.read_group(location, documentation, label) + return self.read_field_with_label(location, documentation, label, atype) + + def read_field_with_label( + self, location: Location, documentation: str, label: Union[None, Field.Label], atype: str + ) -> FieldElement: + """ Reads an field declaration and returns it. """ + name = self.reader.read_name() + self.reader.require('=') + tag = self.reader.read_int() + + # Mutable copy to extract the default value, and add packed if necessary. + options = OptionReader(self.reader).read_options() + + default_value = self.strip_default(options) + json_name = self.strip_json_name(options) + self.reader.require(';') + + documentation = self.reader.try_append_trailing_documentation(documentation) + + return FieldElement( + location, + label, + atype, + name, + default_value, + json_name, + tag, + documentation, + options, + ) + + def strip_default(self, options: list) -> Union[str, None]: + """ Defaults aren't options. """ + return self.strip_value("default", options) + + def strip_json_name(self, options: list) -> Union[None, str]: + """ `json_name` isn't an option. """ + return self.strip_value("json_name", options) + + @staticmethod + def strip_value(name: str, options: list) -> Union[None, str]: + """ This finds an option named [name], removes, and returns it. + Returns None if no [name] option is present. + """ + result: Union[None, str] = None + for element in options[:]: + if element.name == name: + options.remove(element) + result = str(element.value) + return result + + def read_one_of(self, documentation: str) -> OneOfElement: + name: str = self.reader.read_name() + fields = [] + groups = [] + options = [] + + self.reader.require("{") + while True: + nested_documentation = self.reader.read_documentation() + if self.reader.peek_char("}"): + break + + location = self.reader.location() + atype = self.reader.read_data_type() + if atype == "group": + groups.append(self.read_group(location, nested_documentation, None)) + elif atype == "option": + options.append(OptionReader(self.reader).read_option("=")) + self.reader.require(";") + else: + fields.append(self.read_field_with_label(location, nested_documentation, None, atype)) + + return OneOfElement( + name, + documentation, + fields, + groups, + options, + ) + + def read_group( + self, + location: Location, + documentation: str, + label: Union[None, Field.Label], + ) -> GroupElement: + name = self.reader.read_word() + self.reader.require("=") + tag = self.reader.read_int() + fields = [] + self.reader.require("{") + + while True: + nested_documentation = self.reader.read_documentation() + if self.reader.peek_char("}"): + break + + field_location = self.reader.location() + field_label = self.reader.read_word() + field = self.read_field(nested_documentation, field_location, field_label) + if isinstance(field, FieldElement): + fields.append(field) + else: + self.reader.unexpected(f"expected field declaration, was {field}") + + return GroupElement(label, location, name, tag, documentation, fields) + + def read_reserved(self, location: Location, documentation: str) -> ReservedElement: + """ Reads a reserved tags and names list like "reserved 10, 12 to 14, 'foo';". """ + values = [] + while True: + ch = self.reader.peek_char() + if ch in ["\"", "'"]: + values.append(self.reader.read_quoted_string()) + else: + tag_start = self.reader.read_int() + ch = self.reader.peek_char() + if ch in [",", ";"]: + values.append(tag_start) + else: + self.reader.expect_with_location(self.reader.read_word() == "to", location, "expected ',', ';', or 'to'") + tag_end = self.reader.read_int() + values.append(KotlinRange(tag_start, tag_end)) + + ch = self.reader.read_char() + # pylint: disable=no-else-break + if ch == ";": + break + elif ch == ",": + continue + else: + self.reader.unexpected("expected ',' or ';'") + a = False + if values: + a = True + + self.reader.expect_with_location(a, location, "'reserved' must have at least one field name or tag") + my_documentation = self.reader.try_append_trailing_documentation(documentation) + + return ReservedElement(location, my_documentation, values) + + def read_extensions(self, location: Location, documentation: str) -> ExtensionsElement: + """ Reads extensions like "extensions 101;" or "extensions 101 to max;". """ + values = [] + while True: + start: int = self.reader.read_int() + ch = self.reader.peek_char() + end: int + if ch in [",", ";"]: + values.append(start) + else: + self.reader.expect_with_location(self.reader.read_word() == "to", location, "expected ',', ';' or 'to'") + s = self.reader.read_word() + if s == "max": + end = MAX_TAG_VALUE + else: + end = int(s) + values.append(KotlinRange(start, end)) + + ch = self.reader.read_char() + # pylint: disable=no-else-break + if ch == ";": + break + elif ch == ",": + continue + else: + self.reader.unexpected("expected ',' or ';'") + + return ExtensionsElement(location, documentation, values) + + def read_enum_constant(self, documentation: str, location: Location, label: str) -> EnumConstantElement: + """ Reads an enum constant like "ROCK = 0;". The label is the constant name. """ + self.reader.require('=') + tag = self.reader.read_int() + + options: list = OptionReader(self.reader).read_options() + self.reader.require(';') + + documentation = self.reader.try_append_trailing_documentation(documentation) + + return EnumConstantElement( + location, + label, + tag, + documentation, + options, + ) + + def read_rpc(self, location: Location, documentation: str) -> RpcElement: + """ Reads an rpc and returns it. """ + name = self.reader.read_name() + + self.reader.require('(') + request_streaming = False + + word = self.reader.read_word() + if word == "stream": + request_streaming = True + request_type = self.reader.read_data_type() + else: + request_type = self.reader.read_data_type_by_name(word) + + self.reader.require(')') + + self.reader.expect_with_location(self.reader.read_word() == "returns", location, "expected 'returns'") + + self.reader.require('(') + response_streaming = False + + word = self.reader.read_word() + if word == "stream": + response_streaming = True + response_type = self.reader.read_data_type() + else: + response_type = self.reader.read_data_type_by_name(word) + + self.reader.require(')') + + options = [] + if self.reader.peek_char('{'): + while True: + rpc_documentation = self.reader.read_documentation() + if self.reader.peek_char('}'): + break + declared = self.read_declaration(rpc_documentation, Context.RPC) + if isinstance(declared, OptionElement): + options.append(declared) + # TODO: add else clause to catch unexpected declarations. + else: + pass + + else: + self.reader.require(';') + + return RpcElement( + location, name, documentation, request_type, response_type, request_streaming, response_streaming, options + ) + + @staticmethod + def parse(location: Location, data: str) -> ProtoFileElement: + """ Parse a named `.proto` schema. """ + proto_parser = ProtoParser(location, data) + return proto_parser.read_proto_file() diff --git a/karapace/protobuf/proto_type.py b/karapace/protobuf/proto_type.py new file mode 100644 index 000000000..74ad913d5 --- /dev/null +++ b/karapace/protobuf/proto_type.py @@ -0,0 +1,221 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/ProtoType.kt +""" +Names a protocol buffer message, enumerated type, service, map, or a scalar. This class models a +fully-qualified name using the protocol buffer package. +""" +from enum import auto, Enum +from karapace.protobuf.exception import IllegalArgumentException, IllegalStateException +from karapace.protobuf.option_element import OptionElement +from typing import Optional + + +def static_init(cls) -> object: + if getattr(cls, "static_init", None): + cls.static_init() + return cls + + +@static_init +class ProtoType: + @property + def simple_name(self) -> str: + dot = self.string.rfind(".") + return self.string[dot + 1:] + + @classmethod + def static_init(cls) -> None: + cls.BOOL = cls(True, "bool") + cls.BYTES = cls(True, "bytes") + cls.DOUBLE = cls(True, "double") + cls.FLOAT = cls(True, "float") + cls.FIXED32 = cls(True, "fixed32") + cls.FIXED64 = cls(True, "fixed64") + cls.INT32 = cls(True, "int32") + cls.INT64 = cls(True, "int64") + cls.SFIXED32 = cls(True, "sfixed32") + cls.SFIXED64 = cls(True, "sfixed64") + cls.SINT32 = cls(True, "sint32") + cls.SINT64 = cls(True, "sint64") + cls.STRING = cls(True, "string") + cls.UINT32 = cls(True, "uint32") + cls.UINT64 = cls(True, "uint64") + cls.ANY = cls(False, "google.protobuf.Any") + cls.DURATION = cls(False, "google.protobuf.Duration") + cls.TIMESTAMP = cls(False, "google.protobuf.Timestamp") + cls.EMPTY = cls(False, "google.protobuf.Empty") + cls.STRUCT_MAP = cls(False, "google.protobuf.Struct") + cls.STRUCT_VALUE = cls(False, "google.protobuf.Value") + cls.STRUCT_NULL = cls(False, "google.protobuf.NullValue") + cls.STRUCT_LIST = cls(False, "google.protobuf.ListValue") + cls.DOUBLE_VALUE = cls(False, "google.protobuf.DoubleValue") + cls.FLOAT_VALUE = cls(False, "google.protobuf.FloatValue") + cls.INT64_VALUE = cls(False, "google.protobuf.Int64Value") + cls.UINT64_VALUE = cls(False, "google.protobuf.UInt64Value") + cls.INT32_VALUE = cls(False, "google.protobuf.Int32Value") + cls.UINT32_VALUE = cls(False, "google.protobuf.UInt32Value") + cls.BOOL_VALUE = cls(False, "google.protobuf.BoolValue") + cls.STRING_VALUE = cls(False, "google.protobuf.StringValue") + cls.BYTES_VALUE = cls(False, "google.protobuf.BytesValue") + + cls.SCALAR_TYPES_ = [ + cls.BOOL, cls.BYTES, cls.DOUBLE, cls.FLOAT, cls.FIXED32, cls.FIXED64, cls.INT32, cls.INT64, cls.SFIXED32, + cls.SFIXED64, cls.SINT32, cls.SINT64, cls.STRING, cls.UINT32, cls.UINT64 + ] + + cls.SCALAR_TYPES = {} + + for a in cls.SCALAR_TYPES_: + cls.SCALAR_TYPES[a.string] = a + + cls.NUMERIC_SCALAR_TYPES: tuple = ( + cls.DOUBLE, cls.FLOAT, cls.FIXED32, cls.FIXED64, cls.INT32, cls.INT64, cls.SFIXED32, cls.SFIXED64, cls.SINT32, + cls.SINT64, cls.UINT32, cls.UINT64 + ) + + def __init__( + self, is_scalar: bool, string: str, key_type: Optional['ProtoType'] = None, value_type: Optional['ProtoType'] = None + ) -> None: + """ Creates a scalar or message type. """ + if not key_type and not value_type: + self.is_scalar = is_scalar + self.string = string + self.is_map = False + """ The type of the map's keys. Only present when [is_map] is True. """ + self.key_type = None + """ The type of the map's values. Only present when [is_map] is True. """ + self.value_type = None + else: + if key_type.is_scalar and key_type != self.BYTES and key_type != self.DOUBLE and key_type != self.FLOAT: + self.is_scalar = False + self.string = string + self.is_map = True + self.key_type = key_type + self.value_type = value_type + else: + raise IllegalArgumentException(f"map key must be non-byte, non-floating point scalar: {key_type}") + + def to_kind(self) -> Optional[OptionElement.Kind]: + return { + "bool": OptionElement.Kind.BOOLEAN, + "string": OptionElement.Kind.STRING, + "bytes": OptionElement.Kind.NUMBER, + "double": OptionElement.Kind.NUMBER, + "float": OptionElement.Kind.NUMBER, + "fixed32": OptionElement.Kind.NUMBER, + "fixed64": OptionElement.Kind.NUMBER, + "int32": OptionElement.Kind.NUMBER, + "int64": OptionElement.Kind.NUMBER, + "sfixed32": OptionElement.Kind.NUMBER, + "sfixed64": OptionElement.Kind.NUMBER, + "sint32": OptionElement.Kind.NUMBER, + "sint64": OptionElement.Kind.NUMBER, + "uint32": OptionElement.Kind.NUMBER, + "uint64": OptionElement.Kind.NUMBER + }.get(self.simple_name, OptionElement.Kind.ENUM) + + @property + def enclosing_type_or_package(self) -> Optional[str]: + """Returns the enclosing type, or None if self type is not nested in another type.""" + dot = self.string.rfind(".") + return None if (dot == -1) else self.string[:dot] + + @property + def type_url(self) -> Optional[str]: + """Returns a string like "type.googleapis.com/packagename.messagename" or None if self type is + a scalar or a map. + + Note:: Returns a string for enums because it doesn't know + if the named type is a message or an enum. + """ + return None if self.is_scalar or self.is_map else f"type.googleapis.com/{self.string}" + + def nested_type(self, name: str) -> 'ProtoType': + + if self.is_scalar: + raise IllegalStateException("scalar cannot have a nested type") + + if self.is_map: + raise IllegalStateException("map cannot have a nested type") + + if not (name and name.rfind(".") == -1 and len(name) != 0): + raise IllegalArgumentException(f"unexpected name: {name}") + + return ProtoType(False, f"{self.string}.{name}") + + def __eq__(self, other) -> bool: + return isinstance(other, ProtoType) and self.string == other.string + + def __ne__(self, other) -> bool: + return not isinstance(other, ProtoType) or self.string != other.string + + def __str__(self) -> str: + return self.string + + def hash_code(self) -> int: + return hash(self.string) + + @staticmethod + def get(enclosing_type_or_package: str, type_name: str) -> 'ProtoType': + return ProtoType.get2(f"{enclosing_type_or_package}.{type_name}") \ + if enclosing_type_or_package else ProtoType.get2(type_name) + + @staticmethod + def get2(name: str) -> 'ProtoType': + scalar = ProtoType.SCALAR_TYPES.get(name) + if scalar: + return scalar + if not (name and len(name) != 0 and name.rfind("#") == -1): + raise IllegalArgumentException(f"unexpected name: {name}") + + if name.startswith("map<") and name.endswith(">"): + comma = name.rfind(",") + if not comma != -1: + raise IllegalArgumentException(f"expected ',' in map type: {name}") + key = ProtoType.get2(name[4:comma].strip()) + value = ProtoType.get2(name[comma + 1:len(name) - 1].strip()) + return ProtoType(False, name, key, value) + return ProtoType(False, name) + + @staticmethod + def get3(key_type: 'ProtoType', value_type: 'ProtoType', name: str) -> 'ProtoType': + return ProtoType(False, name, key_type, value_type) + + # schema compatibility check functionality karapace addon + # Based on table https://developers.google.com/protocol-buffers/docs/proto3#scalar + + class CompatibilityKind(Enum): + VARIANT = auto() + SVARIANT = auto() # sint has incompatible format with int but compatible with it by size + FIXED64 = auto() + LENGTH_DELIMITED = auto() + FIXED32 = auto() + DOUBLE = auto() + FLOAT = auto() + + def compatibility_kind(self, is_enum: bool) -> 'ProtoType.CompatibilityKind': + if is_enum: + return ProtoType.CompatibilityKind.VARIANT + + result = { + "int32": ProtoType.CompatibilityKind.VARIANT, + "int64": ProtoType.CompatibilityKind.VARIANT, + "uint32": ProtoType.CompatibilityKind.VARIANT, + "uint64": ProtoType.CompatibilityKind.VARIANT, + "bool": ProtoType.CompatibilityKind.VARIANT, + "sint32": ProtoType.CompatibilityKind.SVARIANT, + "sint64": ProtoType.CompatibilityKind.SVARIANT, + "double": ProtoType.CompatibilityKind.DOUBLE, # it is compatible by size with FIXED64 + "fixed64": ProtoType.CompatibilityKind.FIXED64, + "sfixed64": ProtoType.CompatibilityKind.FIXED64, + "float": ProtoType.CompatibilityKind.FLOAT, # it is compatible by size with FIXED32 + "fixed32": ProtoType.CompatibilityKind.FIXED32, + "sfixed32": ProtoType.CompatibilityKind.FIXED32, + "string": ProtoType.CompatibilityKind.LENGTH_DELIMITED, + "bytes": ProtoType.CompatibilityKind.LENGTH_DELIMITED, + }.get(self.simple_name) + + if result: + return result + + raise IllegalArgumentException(f"undefined type: {self.simple_name}") diff --git a/karapace/protobuf/protobuf_to_dict.py b/karapace/protobuf/protobuf_to_dict.py new file mode 100644 index 000000000..2fd0d372d --- /dev/null +++ b/karapace/protobuf/protobuf_to_dict.py @@ -0,0 +1,335 @@ +""" +This module provide a small Python library for creating dicts from protocol buffers +Module based on code : +https://github.com/wearefair/protobuf-to-dict +LICENSE: https://github.com/wearefair/protobuf-to-dict/blob/master/LICENSE +""" +from dateutil.parser import parse as date_parser +from google.protobuf.descriptor import FieldDescriptor +from google.protobuf.message import Message +from google.protobuf.timestamp_pb2 import Timestamp +from types import MappingProxyType + +import datetime + +__all__ = ["protobuf_to_dict", "TYPE_CALLABLE_MAP", "dict_to_protobuf", "REVERSE_TYPE_CALLABLE_MAP"] + +Timestamp_type_name = 'Timestamp' + +# pylint: disable=no-member + + +def datetime_to_timestamp(dt): + ts = Timestamp() + + ts.FromDatetime(dt) + return ts + + +def timestamp_to_datetime(ts): + dt = ts.ToDatetime() + return dt + + +# pylint: enable=no-member + +EXTENSION_CONTAINER = '___X' + +TYPE_CALLABLE_MAP = MappingProxyType({ + FieldDescriptor.TYPE_DOUBLE: float, + FieldDescriptor.TYPE_FLOAT: float, + FieldDescriptor.TYPE_INT32: int, + FieldDescriptor.TYPE_INT64: int, + FieldDescriptor.TYPE_UINT32: int, + FieldDescriptor.TYPE_UINT64: int, + FieldDescriptor.TYPE_SINT32: int, + FieldDescriptor.TYPE_SINT64: int, + FieldDescriptor.TYPE_FIXED32: int, + FieldDescriptor.TYPE_FIXED64: int, + FieldDescriptor.TYPE_SFIXED32: int, + FieldDescriptor.TYPE_SFIXED64: int, + FieldDescriptor.TYPE_BOOL: bool, + FieldDescriptor.TYPE_STRING: str, + FieldDescriptor.TYPE_BYTES: bytes, + FieldDescriptor.TYPE_ENUM: int, +}) + + +def repeated(type_callable): + return lambda value_list: [type_callable(value) for value in value_list] + + +def enum_label_name(field, value, lowercase_enum_lables=False) -> str: + label = field.enum_type.values_by_number[int(value)].name + label = label.lower() if lowercase_enum_lables else label + return label + + +def _is_map_entry(field) -> bool: + return ( + field.type == FieldDescriptor.TYPE_MESSAGE and field.message_type.has_options + and field.message_type.GetOptions().map_entry + ) + + +def protobuf_to_dict(pb, use_enum_labels=True, including_default_value_fields=True, lowercase_enum_lables=False) -> dict: + type_callable_map = TYPE_CALLABLE_MAP + result_dict = {} + extensions = {} + for field, value in pb.ListFields(): + if field.message_type and field.message_type.has_options and field.message_type.GetOptions().map_entry: + result_dict[field.name] = {} + value_field = field.message_type.fields_by_name['value'] + type_callable = _get_field_value_adaptor( + pb, value_field, type_callable_map, use_enum_labels, including_default_value_fields, lowercase_enum_lables + ) + for k, v in value.items(): + result_dict[field.name][k] = type_callable(v) + continue + type_callable = _get_field_value_adaptor( + pb, field, type_callable_map, use_enum_labels, including_default_value_fields, lowercase_enum_lables + ) + if field.label == FieldDescriptor.LABEL_REPEATED: + type_callable = repeated(type_callable) + + if field.is_extension: + extensions[str(field.number)] = type_callable(value) + continue + + result_dict[field.name] = type_callable(value) + + # Serialize default value if including_default_value_fields is True. + if including_default_value_fields: + for field in pb.DESCRIPTOR.fields: + # Singular message fields and oneof fields will not be affected. + if field.label != FieldDescriptor.LABEL_REPEATED and field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + continue + if field.containing_oneof: + continue + if field.name in result_dict: + # Skip the field which has been serailized already. + continue + if _is_map_entry(field): + result_dict[field.name] = {} + elif field.label == FieldDescriptor.LABEL_REPEATED: + result_dict[field.name] = [] + elif field.type == FieldDescriptor.TYPE_ENUM and use_enum_labels: + result_dict[field.name] = enum_label_name(field, field.default_value, lowercase_enum_lables) + else: + result_dict[field.name] = field.default_value + + if extensions: + result_dict[EXTENSION_CONTAINER] = extensions + return result_dict + + +def _get_field_value_adaptor( + pb, + field, + type_callable_map=TYPE_CALLABLE_MAP, + use_enum_labels=False, + including_default_value_fields=False, + lowercase_enum_lables=False +): + if field.message_type and field.message_type.name == Timestamp_type_name: + return timestamp_to_datetime + if field.type == FieldDescriptor.TYPE_MESSAGE: + # recursively encode protobuf sub-message + return lambda pb: protobuf_to_dict( + pb, + use_enum_labels=use_enum_labels, + including_default_value_fields=including_default_value_fields, + lowercase_enum_lables=lowercase_enum_lables, + ) + + if use_enum_labels and field.type == FieldDescriptor.TYPE_ENUM: + return lambda value: enum_label_name(field, value, lowercase_enum_lables) + + if field.type in type_callable_map: + return type_callable_map[field.type] + + raise TypeError("Field %s.%s has unrecognised type id %d" % (pb.__class__.__name__, field.name, field.type)) + + +REVERSE_TYPE_CALLABLE_MAP = MappingProxyType({}) + + +def dict_to_protobuf( + pb_klass_or_instance, + values, + type_callable_map=REVERSE_TYPE_CALLABLE_MAP, + strict=True, + ignore_none=False, + use_date_parser_for_fields=None +) -> object: + """Populates a protobuf model from a dictionary. + + :param ignore_none: + :param pb_klass_or_instance: a protobuf message class, or an protobuf instance + :type pb_klass_or_instance: a type or instance of a subclass of google.protobuf.message.Message + :param dict values: a dictionary of values. Repeated and nested values are + fully supported. + :param dict type_callable_map: a mapping of protobuf types to callables for setting + values on the target instance. + :param bool strict: complain if keys in the map are not fields on the message. + :param bool strict: ignore None-values of fields, treat them as empty field + :param bool strict: when false: accept enums both in lowercase and uppercase + :param list use_date_parser_for_fields: a list of fields that need to use date_parser + """ + if isinstance(pb_klass_or_instance, Message): + instance = pb_klass_or_instance + else: + instance = pb_klass_or_instance() + return _dict_to_protobuf(instance, values, type_callable_map, strict, ignore_none, use_date_parser_for_fields) + + +def _get_field_mapping(pb, dict_value, strict): + field_mapping = [] + key: str = "" + for key, value in dict_value.items(): + if key == EXTENSION_CONTAINER: + continue + if key not in pb.DESCRIPTOR.fields_by_name: + if strict: + raise KeyError("%s does not have a field called %s" % (type(pb), key)) + continue + field_mapping.append((pb.DESCRIPTOR.fields_by_name[key], value, getattr(pb, key, None))) + + for ext_num, ext_val in dict_value.get(EXTENSION_CONTAINER, {}).items(): + try: + ext_num = int(ext_num) + except ValueError: + raise ValueError("Extension keys must be integers.") + # pylint: disable=protected-access + if ext_num not in pb._extensions_by_number: + if strict: + raise KeyError("%s does not have a extension with number %s. Perhaps you forgot to import it?" % (pb, key)) + continue + # pylint: disable=protected-access + + ext_field = pb._extensions_by_number[ext_num] + # noinspection PyUnusedLocal + pb_val = None + pb_val = pb.Extensions[ext_field] + field_mapping.append((ext_field, ext_val, pb_val)) + + return field_mapping + + +def _dict_to_protobuf(pb, value_, type_callable_map, strict, ignore_none, use_date_parser_for_fields): + fields = _get_field_mapping(pb, value_, strict) + value = value_ + for field, input_value, pb_value in fields: + if ignore_none and input_value is None: + continue + + if field.label == FieldDescriptor.LABEL_REPEATED: + if field.message_type and field.message_type.has_options and field.message_type.GetOptions().map_entry: + key_field = field.message_type.fields_by_name['key'] + value_field = field.message_type.fields_by_name['value'] + for key, value in input_value.items(): + if value_field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _dict_to_protobuf( + getattr(pb, field.name)[key], value, type_callable_map, strict, ignore_none, + use_date_parser_for_fields + ) + else: + if ignore_none and value is None: + continue + try: + if key_field.type in type_callable_map: + key = type_callable_map[key_field.type](key) + if value_field.type in type_callable_map: + value = type_callable_map[value_field.type](value) + getattr(pb, field.name)[key] = value + except Exception as exc: + raise RuntimeError(f"type: {type(pb)}, field: {field.name}, value: {value}") from exc + continue + for item in input_value: + if field.type == FieldDescriptor.TYPE_MESSAGE: + m = pb_value.add() + _dict_to_protobuf(m, item, type_callable_map, strict, ignore_none, use_date_parser_for_fields) + elif field.type == FieldDescriptor.TYPE_ENUM and isinstance(item, str): + pb_value.append(_string_to_enum(field, item, strict)) + else: + pb_value.append(item) + continue + + if isinstance(input_value, datetime.datetime): + input_value = datetime_to_timestamp(input_value) + # Instead of setattr we need to use CopyFrom for composite fields + # Otherwise we will get AttributeError: + # Assignment not allowed to composite field “field name” in protocol message object + getattr(pb, field.name).CopyFrom(input_value) + continue + + if use_date_parser_for_fields and field.name in use_date_parser_for_fields: + input_value = datetime_to_timestamp(date_parser(input_value)) + getattr(pb, field.name).CopyFrom(input_value) + continue + + if field.type == FieldDescriptor.TYPE_MESSAGE: + _dict_to_protobuf(pb_value, input_value, type_callable_map, strict, ignore_none, use_date_parser_for_fields) + continue + + if field.type in type_callable_map: + input_value = type_callable_map[field.type](input_value) + + if field.is_extension: + pb.Extensions[field] = input_value + continue + + if field.type == FieldDescriptor.TYPE_ENUM and isinstance(input_value, str): + input_value = _string_to_enum(field, input_value, strict) + + try: + setattr(pb, field.name, input_value) + except Exception as exc: + raise RuntimeError(f"type: {type(pb)}, field: {field.name}, value: {value}") from exc + + return pb + + +def _string_to_enum(field, input_value, strict=False): + try: + input_value = field.enum_type.values_by_name[input_value].number + except KeyError: + if strict: + raise KeyError("`%s` is not a valid value for field `%s`" % (input_value, field.name)) + return _string_to_enum(field, input_value.upper(), strict=True) + return input_value + + +def get_field_names_and_options(pb): + """ + Return a tuple of field names and options. + """ + desc = pb.DESCRIPTOR + + for field in desc.fields: + field_name = field.name + options_dict = {} + if field.has_options: + options = field.GetOptions() + for subfield, value in options.ListFields(): + options_dict[subfield.name] = value + yield field, field_name, options_dict + + +class FieldsMissing(ValueError): + pass + + +def validate_dict_for_required_pb_fields(pb, dic): + """ + Validate that the dictionary has all the required fields for creating a protobuffer object + from pb class. If a field is missing, raise FieldsMissing. + In order to mark a field as optional, add [(is_optional) = true] to the field. + Take a look at the tests for an example. + """ + missing_fields = [] + for field, field_name, field_options in get_field_names_and_options(pb): + if not field_options.get('is_optional', False) and field_name not in dic: + missing_fields.append(field_name) + if missing_fields: + raise FieldsMissing('Missing fields: {}'.format(', '.join(missing_fields))) diff --git a/karapace/protobuf/reserved_element.py b/karapace/protobuf/reserved_element.py new file mode 100644 index 000000000..ccb33e01b --- /dev/null +++ b/karapace/protobuf/reserved_element.py @@ -0,0 +1,35 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/ReservedElement.kt +from dataclasses import dataclass +from karapace.protobuf.kotlin_wrapper import KotlinRange +from karapace.protobuf.location import Location +from karapace.protobuf.utils import append_documentation + + +@dataclass +class ReservedElement: + location: Location + documentation: str = "" + """ A [String] name or [Int] or [IntRange] tag. """ + values: list = None + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + result.append("reserved ") + + for index in range(len(self.values)): + value = self.values[index] + if index > 0: + result.append(", ") + + if isinstance(value, str): + result.append(f"\"{value}\"") + elif isinstance(value, int): + result.append(f"{value}") + elif isinstance(value, KotlinRange): + result.append(f"{value.minimum} to {value.maximum}") + else: + raise AssertionError() + result.append(";\n") + return "".join(result) diff --git a/karapace/protobuf/rpc_element.py b/karapace/protobuf/rpc_element.py new file mode 100644 index 000000000..51f4ebb35 --- /dev/null +++ b/karapace/protobuf/rpc_element.py @@ -0,0 +1,48 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/RpcElement.kt + +from karapace.protobuf.location import Location +from karapace.protobuf.utils import append_documentation, append_indented + + +class RpcElement: + def __init__( + self, + location: Location, + name: str, + documentation: str = "", + request_type: str = "", + response_type: str = "", + request_streaming: bool = False, + response_streaming: bool = False, + options: list = None + ) -> None: + self.location = location + self.name = name + self.documentation = documentation + self.request_type = request_type + self.response_type = response_type + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.options = options or [] + + def to_schema(self) -> str: + result = [] + append_documentation(result, self.documentation) + result.append(f"rpc {self.name} (") + + if self.request_streaming: + result.append("stream ") + result.append(f"{self.request_type}) returns (") + + if self.response_streaming: + result.append("stream ") + result.append(f"{self.response_type})") + + if self.options: + result.append(" {\n") + for option in self.options: + append_indented(result, option.to_schema_declaration()) + result.append("}") + result.append(";\n") + return "".join(result) diff --git a/karapace/protobuf/schema.py b/karapace/protobuf/schema.py new file mode 100644 index 000000000..d539bfa84 --- /dev/null +++ b/karapace/protobuf/schema.py @@ -0,0 +1,162 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Schema.kt +# Ported partially for required functionality. +from karapace.protobuf.compare_result import CompareResult +from karapace.protobuf.enum_element import EnumElement +from karapace.protobuf.exception import IllegalArgumentException +from karapace.protobuf.location import Location +from karapace.protobuf.message_element import MessageElement +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.proto_file_element import ProtoFileElement +from karapace.protobuf.proto_parser import ProtoParser +from karapace.protobuf.utils import append_documentation, append_indented + + +def add_slashes(text: str) -> str: + escape_dict = { + '\a': '\\a', + '\b': '\\b', + '\f': '\\f', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', + '\v': '\\v', + '\'': "\\'", + '\"': '\\"', + '\\': '\\\\' + } + trans_table = str.maketrans(escape_dict) + return text.translate(trans_table) + + +def message_element_string(element: MessageElement) -> str: + result = [] + append_documentation(result, element.documentation) + result.append(f"message {element.name} {{") + if element.reserveds: + result.append("\n") + for reserved in element.reserveds: + append_indented(result, reserved.to_schema()) + + if element.options: + result.append("\n") + for option in element.options: + append_indented(result, option_element_string(option)) + + if element.fields: + result.append("\n") + for field in element.fields: + append_indented(result, field.to_schema()) + + if element.one_ofs: + result.append("\n") + for one_of in element.one_ofs: + append_indented(result, one_of.to_schema()) + + if element.groups: + result.append("\n") + for group in element.groups: + append_indented(result, group.to_schema()) + + if element.extensions: + result.append("\n") + for extension in element.extensions: + append_indented(result, extension.to_schema()) + + if element.nested_types: + result.append("\n") + for nested_type in element.nested_types: + if isinstance(nested_type, MessageElement): + append_indented(result, message_element_string(nested_type)) + + for nested_type in element.nested_types: + if isinstance(nested_type, EnumElement): + append_indented(result, enum_element_string(nested_type)) + + result.append("}\n") + return "".join(result) + + +def enum_element_string(element: EnumElement) -> str: + return element.to_schema() + + +def option_element_string(option: OptionElement) -> str: + result: str + if option.kind == OptionElement.Kind.STRING: + name: str + if option.is_parenthesized: + name = f"({option.name})" + else: + name = option.name + value = add_slashes(str(option.value)) + result = f"{name} = \"{value}\"" + else: + result = option.to_schema() + + return f"option {result};\n" + + +class ProtobufSchema: + DEFAULT_LOCATION = Location.get("") + + def __init__(self, schema: str) -> None: + if type(schema).__name__ != 'str': + raise IllegalArgumentException("Non str type of schema string") + self.dirty = schema + self.cache_string = "" + self.proto_file_element = ProtoParser.parse(self.DEFAULT_LOCATION, schema) + + def __str__(self) -> str: + if not self.cache_string: + self.cache_string = self.to_schema() + return self.cache_string + + def to_schema(self) -> str: + strings = [] + shm: ProtoFileElement = self.proto_file_element + if shm.syntax: + strings.append("syntax = \"") + strings.append(str(shm.syntax)) + strings.append("\";\n") + + if shm.package_name: + strings.append("package " + str(shm.package_name) + ";\n") + + if shm.imports or shm.public_imports: + strings.append("\n") + + for file in shm.imports: + strings.append("import \"" + str(file) + "\";\n") + + for file in shm.public_imports: + strings.append("import public \"" + str(file) + "\";\n") + + if shm.options: + strings.append("\n") + for option in shm.options: + # strings.append(str(option.to_schema_declaration())) + strings.append(option_element_string(option)) + + if shm.types: + strings.append("\n") + for type_element in shm.types: + if isinstance(type_element, MessageElement): + strings.append(message_element_string(type_element)) + for type_element in shm.types: + if isinstance(type_element, EnumElement): + strings.append(enum_element_string(type_element)) + + if shm.extend_declarations: + strings.append("\n") + for extend_declaration in shm.extend_declarations: + strings.append(str(extend_declaration.to_schema())) + + if shm.services: + strings.append("\n") + for service in shm.services: + strings.append(str(service.to_schema())) + return "".join(strings) + + def compare(self, other: 'ProtobufSchema', result: CompareResult) -> CompareResult: + self.proto_file_element.compare(other.proto_file_element, result) diff --git a/karapace/protobuf/service_element.py b/karapace/protobuf/service_element.py new file mode 100644 index 000000000..e5131a922 --- /dev/null +++ b/karapace/protobuf/service_element.py @@ -0,0 +1,33 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/ServiceElement.kt +from dataclasses import dataclass +from karapace.protobuf.location import Location +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.rpc_element import RpcElement +from karapace.protobuf.utils import append_documentation, append_indented +from typing import List + + +@dataclass +class ServiceElement: + location: Location + name: str + documentation: str = "" + rpcs: List[RpcElement] = None + options: List[OptionElement] = None + + def to_schema(self) -> str: + result: List[str] = [] + append_documentation(result, self.documentation) + result.append(f"service {self.name} {{") + if self.options: + result.append("\n") + for option in self.options: + append_indented(result, option.to_schema_declaration()) + if self.rpcs: + result.append('\n') + for rpc in self.rpcs: + append_indented(result, rpc.to_schema()) + + result.append("}\n") + return "".join(result) diff --git a/karapace/protobuf/syntax.py b/karapace/protobuf/syntax.py new file mode 100644 index 000000000..a7d80e045 --- /dev/null +++ b/karapace/protobuf/syntax.py @@ -0,0 +1,20 @@ +# Ported from square/wire: +# wire-library/wire-runtime/src/commonMain/kotlin/com/squareup/wire/Syntax.kt + +from enum import Enum +from karapace.protobuf.exception import IllegalArgumentException + + +class Syntax(Enum): + PROTO_2 = "proto2" + PROTO_3 = "proto3" + + @classmethod + def _missing_(cls, string) -> None: + raise IllegalArgumentException(f"unexpected syntax: {string}") + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return self.value diff --git a/karapace/protobuf/syntax_reader.py b/karapace/protobuf/syntax_reader.py new file mode 100644 index 000000000..9f2fb0c4a --- /dev/null +++ b/karapace/protobuf/syntax_reader.py @@ -0,0 +1,364 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/SyntaxReader.kt +from karapace.protobuf.exception import IllegalStateException +from karapace.protobuf.location import Location +from typing import NoReturn, Union + + +class SyntaxReader: + def __init__(self, data: str, location: Location) -> None: + """ Next character to be read """ + self.pos = 0 + """ The number of newline characters """ + self.line = 0 + """ The index of the most recent newline character. """ + self.line_start = 0 + self.data = data + self._location = location + + def exhausted(self) -> bool: + return self.pos == len(self.data) + + def read_char(self) -> str: + """ Reads a non-whitespace character """ + + char = self.peek_char() + self.pos += 1 + return char + + def require(self, c: str) -> None: + """ Reads a non-whitespace character 'c' """ + self.expect(self.read_char() == c, f"expected '{c}'") + + def peek_char(self, ch: str = None) -> Union[bool, str]: + """ Peeks a non-whitespace character and returns it. The only difference between this and + [read_char] is that this doesn't consume the char. + """ + + if ch: + if self.peek_char() == ch: + self.pos += 1 + return True + return False + self.skip_whitespace(True) + self.expect(self.pos < len(self.data), "unexpected end of file") + return self.data[self.pos] + + def push_back(self, ch: str) -> None: + """ Push back the most recently read character. """ + if self.data[self.pos - 1] == ch: + self.pos -= 1 + + def read_string(self) -> str: + """ Reads a quoted or unquoted string and returns it. """ + self.skip_whitespace(True) + if self.peek_char() in ['"', "'"]: + return self.read_quoted_string() + return self.read_word() + + def read_quoted_string(self) -> str: + start_quote = self.read_char() + if start_quote not in ('"', "'"): + raise IllegalStateException(" quote expected") + + result = [] + + while self.pos < len(self.data): + c = self.data[self.pos] + self.pos += 1 + if c == start_quote: + if self.peek_char() == '"' or self.peek_char() == "'": + # Adjacent strings are concatenated. Consume new quote and continue reading. + start_quote = self.read_char() + continue + return "".join(result) + if c == "\\": + self.expect(self.pos < len(self.data), "unexpected end of file") + c = self.data[self.pos] + self.pos += 1 + d: Union[str, None] = { + 'a': "\u0007", # Alert. + 'b': "\b", # Backspace. + 'f': "\u000c", # Form feed. + 'n': "\n", # Newline. + 'r': "\r", # Carriage return. + 't': "\t", # Horizontal tab. + 'v': "\u000b", # Vertical tab. + }.get(c) + if d: + c = d + else: + if c in ['x', 'X']: + c = self.read_numeric_escape(16, 2) + elif ord(c) in range(ord('0'), ord('7') + 1): + self.pos -= 1 + c = self.read_numeric_escape(8, 3) + + result.append(c) + if c == "\n": + self.newline() + + self.unexpected("unterminated string") + return "" + + def read_numeric_escape(self, radix: int, length: int) -> str: + value = -1 + end_pos = min(self.pos + length, len(self.data)) + + while self.pos < end_pos: + try: + digit = int(self.data[self.pos], radix) + except ValueError: + digit = -1 + + if digit == -1 or digit >= radix: + break + + if value < 0: + value = digit + else: + value = value * radix + digit + self.pos += 1 + + self.expect(value >= 0, "expected a digit after \\x or \\X") + return chr(value) + + def read_name(self) -> str: + """ Reads a (paren-wrapped), [square-wrapped] or naked symbol name. """ + + c = self.peek_char() + if c == '(': + self.pos += 1 + result = self.read_word() + self.expect(self.read_char() == ')', "expected ')'") + return result + if c == '[': + self.pos += 1 + result = self.read_word() + self.expect(self.read_char() == ']', "expected ']'") + return result + return self.read_word() + + def read_data_type(self) -> str: + """ Reads a scalar, map, or type name. """ + + name = self.read_word() + return self.read_data_type_by_name(name) + + def read_data_type_by_name(self, name: str) -> str: + """ Reads a scalar, map, or type name with `name` as a prefix word. """ + if name == "map": + self.expect(self.read_char() == '<', "expected '<'") + key_type = self.read_data_type() + + self.expect(self.read_char() == ',', "expected ','") + value_type = self.read_data_type() + + self.expect(self.read_char() == '>', "expected '>'") + return f"map<{key_type}, {value_type}>" + return name + + def read_word(self) -> str: + """ Reads a non-empty word and returns it. """ + self.skip_whitespace(True) + start = self.pos + while self.pos < len(self.data): + c = self.data[self.pos] + if ord(c) in range(ord('a'), ord('z') + 1) \ + or ord(c) in range(ord('A'), ord('Z') + 1) \ + or ord(c) in range(ord('0'), ord('9') + 1) or c in ['_', '-', '.']: + self.pos += 1 + else: + break + self.expect(start < self.pos, "expected a word") + return self.data[start:self.pos] + + def read_int(self) -> int: + """ Reads an integer and returns it. """ + tag: str = self.read_word() + try: + radix = 10 + if tag.startswith("0x") or tag.startswith("0X"): + radix = 16 + return int(tag, radix) + + except ValueError: + self.unexpected(f"expected an integer but was {tag}") + + def read_documentation(self) -> str: + """ Like skip_whitespace(), but this returns a string containing all comment text. By convention, + comments before a declaration document that declaration. """ + + result = None + while True: + self.skip_whitespace(False) + if self.pos == len(self.data) or self.data[self.pos] != '/': + if result: + return result + return "" + comment = self.read_comment() + if result: + result = f"{result}\n{comment}" + else: + result = f"{comment}" + + def read_comment(self) -> str: + """ Reads a comment and returns its body. """ + if self.pos == len(self.data) or self.data[self.pos] != '/': + raise IllegalStateException() + + self.pos += 1 + tval = -1 + if self.pos < len(self.data): + tval = ord(self.data[self.pos]) + self.pos += 1 + result: str = "" + if tval == ord('*'): + buffer = [] + start_of_line = True + while self.pos + 1 < len(self.data): + # pylint: disable=no-else-break + c: str = self.data[self.pos] + if c == '*' and self.data[self.pos + 1] == '/': + self.pos += 2 + result = "".join(buffer).strip() + break + elif c == "\n": + buffer.append("\n") + self.newline() + start_of_line = True + elif not start_of_line: + buffer.append(c) + elif c == "*": + if self.data[self.pos + 1] == ' ': + self.pos += 1 # Skip a single leading space, if present. + start_of_line = False + elif not c.isspace(): + buffer.append(c) + start_of_line = False + self.pos += 1 + if not result: + self.unexpected("unterminated comment") + elif tval == ord('/'): + if self.pos < len(self.data) and self.data[self.pos] == ' ': + self.pos += 1 # Skip a single leading space, if present. + start = self.pos + while self.pos < len(self.data): + c = self.data[self.pos] + self.pos += 1 + if c == "\n": + self.newline() + break + result = self.data[start:self.pos - 1] + if not result: + self.unexpected("unexpected '/'") + return result + + def try_append_trailing_documentation(self, documentation: str) -> str: + """ Search for a '/' character ignoring spaces and tabs.""" + while self.pos < len(self.data): + if self.data[self.pos] in [' ', "\t"]: + self.pos += 1 + elif self.data[self.pos] == '/': + self.pos += 1 + break + else: + # Not a whitespace or comment-starting character. Return original documentation. + return documentation + bval = (self.pos < len(self.data) and (self.data[self.pos] == '/' or self.data[self.pos] == '*')) + if not bval: + # Backtrack to start of comment. + self.pos -= 1 + self.expect(bval, "expected '//' or '/*'") + is_star = self.data[self.pos] == '*' + + self.pos += 1 + + # Skip a single leading space, if present. + if self.pos < len(self.data) and self.data[self.pos] == ' ': + self.pos += 1 + + start = self.pos + end: int + + if is_star: + # Consume star comment until it closes on the same line. + while True: + self.expect(self.pos < len(self.data), "trailing comment must be closed") + if self.data[self.pos] == '*' and self.pos + 1 < len(self.data) and self.data[self.pos + 1] == '/': + end = self.pos - 1 # The character before '*'. + self.pos += 2 # Skip to the character after '/'. + break + self.pos += 1 + # Ensure nothing follows a trailing star comment. + while self.pos < len(self.data): + c = self.data[self.pos] + self.pos += 1 + if c == "\n": + self.newline() + break + + self.expect(c in [" ", "\t"], "no syntax may follow trailing comment") + + else: + # Consume comment until newline. + while True: + if self.pos == len(self.data): + end = self.pos - 1 + break + c = self.data[self.pos] + self.pos += 1 + if c == "\n": + self.newline() + end = self.pos - 2 # Account for stepping past the newline. + break + + # Remove trailing whitespace. + while end > start and (self.data[end] == " " or self.data[end] == "\t"): + end -= 1 + + if end == start: + return documentation + + trailing_documentation = self.data[start:end + 1] + if not documentation.strip(): + return trailing_documentation + return f"{documentation}\n{trailing_documentation}" + + def skip_whitespace(self, skip_comments: bool) -> None: + """ Skips whitespace characters and optionally comments. When this returns, either + self.pos == self.data.length or a non-whitespace character. + """ + while self.pos < len(self.data): + c = self.data[self.pos] + if c in [" ", "\t", "\r", "\n"]: + self.pos += 1 + if c == "\n": + self.newline() + elif skip_comments and c == "/": + self.read_comment() + else: + return None + + def newline(self) -> None: + """ Call this every time a '\n' is encountered. """ + self.line += 1 + self.line_start = self.pos + + def location(self) -> Location: + return self._location.at(self.line + 1, self.pos - self.line_start + 1) + + def expect(self, condition: bool, message: str) -> None: + location = self.location() + if not condition: + self.unexpected(message, location) + + def expect_with_location(self, condition: bool, location: Location, message: str) -> None: + if not condition: + self.unexpected(message, location) + + def unexpected(self, message: str, location: Location = None) -> NoReturn: + if not location: + location = self.location() + w = f"Syntax error in {str(location)}: {message}" + raise IllegalStateException(w) diff --git a/karapace/protobuf/type_element.py b/karapace/protobuf/type_element.py new file mode 100644 index 000000000..10908e4ab --- /dev/null +++ b/karapace/protobuf/type_element.py @@ -0,0 +1,32 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/parser/TypeElement.kt +from dataclasses import dataclass +from karapace.protobuf.location import Location +from typing import List, TYPE_CHECKING + +if TYPE_CHECKING: + from karapace.protobuf.option_element import OptionElement + + +@dataclass +class TypeElement: + location: Location + name: str + documentation: str + options: List['OptionElement'] + nested_types: List['TypeElement'] + + def to_schema(self) -> str: + """Convert the object to valid protobuf syntax. + + This must be implemented by subclasses. + """ + raise NotImplementedError() + + def __repr__(self) -> str: + mytype = type(self) + return f"{mytype}({self.to_schema()})" + + def __str__(self) -> str: + mytype = type(self) + return f"{mytype}({self.to_schema()})" diff --git a/karapace/protobuf/utils.py b/karapace/protobuf/utils.py new file mode 100644 index 000000000..a9a6336c2 --- /dev/null +++ b/karapace/protobuf/utils.py @@ -0,0 +1,67 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/internal/Util.kt +from typing import List, TYPE_CHECKING + +if TYPE_CHECKING: + from karapace.protobuf.option_element import OptionElement + + +def append_documentation(data: List[str], documentation: str) -> None: + if not documentation: + return + + lines: list = documentation.split("\n") + + if len(lines) > 1 and not lines[-1]: + lines.pop() + + for line in lines: + data.append("// ") + data.append(line) + data.append("\n") + + +def append_options(data: List[str], options: List['OptionElement']) -> None: + count = len(options) + if count == 1: + data.append('[') + data.append(try_to_schema(options[0])) + data.append(']') + return + + data.append("[\n") + for i in range(0, count): + if i < count - 1: + endl = "," + else: + endl = "" + append_indented(data, try_to_schema(options[i]) + endl) + data.append(']') + + +def try_to_schema(obj: 'OptionElement') -> str: + try: + return obj.to_schema() + except AttributeError: + if isinstance(obj, str): + return obj + raise AttributeError + + +def append_indented(data: List[str], value: str) -> None: + lines = value.split("\n") + if len(lines) > 1 and not lines[-1]: + del lines[-1] + + for line in lines: + data.append(" ") + data.append(line) + data.append("\n") + + +MIN_TAG_VALUE = 1 +MAX_TAG_VALUE = ((1 << 29) & 0xffffffffffffffff) - 1 # 536,870,911 + +RESERVED_TAG_VALUE_START = 19000 +RESERVED_TAG_VALUE_END = 19999 +""" True if the supplied value is in the valid tag range and not reserved. """ diff --git a/karapace/rapu.py b/karapace/rapu.py index bb9101532..a552248ff 100644 --- a/karapace/rapu.py +++ b/karapace/rapu.py @@ -44,11 +44,13 @@ # TODO -> accept more general values as well REST_CONTENT_TYPE_RE = re.compile( - r"application/((vnd\.kafka(\.(?Pavro|json|binary|jsonschema))?(\.(?Pv[12]))?" + r"application/((vnd\.kafka(\.(?Pavro|json|protobuf|binary|jsonschema))?(\.(?Pv[12]))?" r"\+(?Pjson))|(?Pjson|octet-stream))" ) + REST_ACCEPT_RE = re.compile( - r"(application|\*)/((vnd\.kafka(\.(?Pavro|json|binary|jsonschema))?(\.(?Pv[12]))?\+" + r"(application|\*)/((vnd\.kafka(\.(?Pavro|json|" + r"protobuf|binary|jsonschema))?(\.(?Pv[12]))?\+" r"(?Pjson))|(?Pjson|\*))" ) @@ -195,7 +197,7 @@ def check_rest_headers(self, request: HTTPRequest) -> dict: # pylint:disable=in method = request.method default_content = "application/vnd.kafka.json.v2+json" default_accept = "*/*" - result: dict = {"content_type": default_content} + result = {"content_type": default_content} content_matcher = REST_CONTENT_TYPE_RE.search( cgi.parse_header(request.get_header("Content-Type", default_content))[0] ) diff --git a/karapace/schema_reader.py b/karapace/schema_reader.py index 3dd2245dd..154aefe52 100644 --- a/karapace/schema_reader.py +++ b/karapace/schema_reader.py @@ -14,11 +14,16 @@ from kafka.errors import NoBrokersAvailable, NodeNotReadyError, TopicAlreadyExistsError from karapace import constants from karapace.avro_compatibility import parse_avro_schema_definition +from karapace.protobuf.exception import ( + Error as ProtobufError, IllegalArgumentException, IllegalStateException, ProtobufException, + ProtobufParserRuntimeException, SchemaParseException as ProtobufSchemaParseException +) +from karapace.protobuf.schema import ProtobufSchema from karapace.statsd import StatsClient from karapace.utils import json_encode, KarapaceKafkaClient from queue import Queue from threading import Lock, Thread -from typing import Dict +from typing import Dict, Optional import json import logging @@ -38,6 +43,17 @@ def parse_jsonschema_definition(schema_definition: str) -> Draft7Validator: return Draft7Validator(schema) +def parse_protobuf_schema_definition(schema_definition: str) -> ProtobufSchema: + """ Parses and validates `schema_definition`. + + Raises: + Nothing yet. + + """ + + return ProtobufSchema(schema_definition) + + class InvalidSchema(Exception): pass @@ -71,12 +87,26 @@ def parse_avro(schema_str: str): # pylint: disable=inconsistent-return-statemen except (SchemaParseException, JSONDecodeError, TypeError) as e: raise InvalidSchema from e + @staticmethod + def parse_protobuf(schema_str: str) -> Optional['TypedSchema']: + try: + ts = TypedSchema(parse_protobuf_schema_definition(schema_str), SchemaType.PROTOBUF, schema_str) + return ts + except ( + TypeError, SchemaError, AssertionError, ProtobufParserRuntimeException, IllegalStateException, + IllegalArgumentException, ProtobufError, ProtobufException, ProtobufSchemaParseException + ) as e: + log.exception("Unexpected error: %s \n schema:[%s]", e, schema_str) + raise InvalidSchema from e + @staticmethod def parse(schema_type: SchemaType, schema_str: str): # pylint: disable=inconsistent-return-statements if schema_type is SchemaType.AVRO: return TypedSchema.parse_avro(schema_str) if schema_type is SchemaType.JSONSCHEMA: return TypedSchema.parse_json(schema_str) + if schema_type is SchemaType.PROTOBUF: + return TypedSchema.parse_protobuf(schema_str) raise InvalidSchema(f"Unknown parser {schema_type} for {schema_str}") def to_json(self): @@ -84,12 +114,18 @@ def to_json(self): return self.schema.schema if isinstance(self.schema, AvroSchema): return self.schema.to_json(names=None) + if isinstance(self.schema, ProtobufSchema): + raise InvalidSchema("Protobuf do not support to_json serialization") return self.schema def __str__(self) -> str: + if isinstance(self.schema, ProtobufSchema): + return str(self.schema) return json_encode(self.to_json(), compact=True) def __repr__(self): + if isinstance(self.schema, ProtobufSchema): + return f"TypedSchema(type={self.schema_type}, schema={str(self)})" return f"TypedSchema(type={self.schema_type}, schema={json_encode(self.to_json())})" def __eq__(self, other): diff --git a/karapace/schema_registry_apis.py b/karapace/schema_registry_apis.py index 787cf4c36..d6c4875a5 100644 --- a/karapace/schema_registry_apis.py +++ b/karapace/schema_registry_apis.py @@ -276,6 +276,7 @@ def send_delete_subject_message(self, subject, version): value = '{{"subject":"{}","version":{}}}'.format(subject, version) return self.send_kafka_message(key, value) + # protobuf compatibility_check async def compatibility_check(self, content_type, *, subject, version, request): """Check for schema compatibility""" body = request.json @@ -377,7 +378,7 @@ async def schemas_get_versions(self, content_type, *, schema_id): self.r(subject_versions, content_type) async def schemas_types(self, content_type): - self.r(["JSON", "AVRO"], content_type) + self.r(["JSON", "AVRO", "PROTOBUF"], content_type) async def config_get(self, content_type): # Note: The format sent by the user differs from the return value, this @@ -669,7 +670,7 @@ def _validate_schema_request_body(self, content_type, body) -> None: def _validate_schema_type(self, content_type, body) -> None: schema_type = SchemaType(body.get("schemaType", SchemaType.AVRO.value)) - if schema_type not in {SchemaType.JSONSCHEMA, SchemaType.AVRO}: + if schema_type not in {SchemaType.JSONSCHEMA, SchemaType.AVRO, SchemaType.PROTOBUF}: self.r( body={ "error_code": SchemaErrorCodes.HTTP_UNPROCESSABLE_ENTITY.value, @@ -845,10 +846,17 @@ def write_new_schema_local(self, subject, body, content_type): # We didn't find an existing schema and the schema is compatible so go and create one schema_id = self.ksr.get_schema_id(new_schema) version = max(self.ksr.subjects[subject]["schemas"]) + 1 - self.log.info( - "Registering subject: %r, id: %r new version: %r with schema %r, schema_id: %r", subject, schema_id, version, - new_schema.to_json(), schema_id - ) + if new_schema.schema_type is SchemaType.PROTOBUF: + self.log.info( + "Registering subject: %r, id: %r new version: %r with schema %r, schema_id: %r", subject, schema_id, + version, new_schema.__str__(), schema_id + ) + else: + self.log.info( + "Registering subject: %r, id: %r new version: %r with schema %r, schema_id: %r", subject, schema_id, + version, new_schema.to_json(), schema_id + ) + self.send_schema_message( subject=subject, schema=new_schema, diff --git a/karapace/serialization.py b/karapace/serialization.py index bff12edc5..dbec23fdd 100644 --- a/karapace/serialization.py +++ b/karapace/serialization.py @@ -1,6 +1,9 @@ from avro.io import BinaryDecoder, BinaryEncoder, DatumReader, DatumWriter +from google.protobuf.message import DecodeError from json import load from jsonschema import ValidationError +from karapace.protobuf.exception import ProtobufTypeException +from karapace.protobuf.io import ProtobufDatumReader, ProtobufDatumWriter from karapace.schema_reader import InvalidSchema, SchemaType, TypedSchema from karapace.utils import Client, json_encode from typing import Dict, Optional @@ -70,7 +73,10 @@ def __init__(self, schema_registry_url: str = "http://localhost:8081"): self.base_url = schema_registry_url async def post_new_schema(self, subject: str, schema: TypedSchema) -> int: - payload = {"schema": json_encode(schema.to_json()), "schemaType": schema.schema_type.value} + if schema.schema_type is SchemaType.PROTOBUF: + payload = {"schema": str(schema), "schemaType": schema.schema_type.value} + else: + payload = {"schema": json_encode(schema.to_json()), "schemaType": schema.schema_type.value} result = await self.client.post(f"subjects/{quote(subject)}/versions", json=payload) if not result.ok: raise SchemaRetrievalError(result.json()) @@ -134,6 +140,9 @@ def get_subject_name(self, topic_name: str, schema: str, subject_type: str, sche namespace = schema_typed.schema.namespace if schema_type is SchemaType.JSONSCHEMA: namespace = schema_typed.to_json().get("namespace", "dummy") + # Protobuf does not use namespaces in terms of AVRO + if schema_type is SchemaType.PROTOBUF: + namespace = "" return f"{self.subject_name_strategy(topic_name, namespace)}-{subject_type}" async def get_schema_for_subject(self, subject: str) -> TypedSchema: @@ -183,10 +192,18 @@ def read_value(schema: TypedSchema, bio: io.BytesIO): except ValidationError as e: raise InvalidPayload from e return value + + if schema.schema_type is SchemaType.PROTOBUF: + try: + reader = ProtobufDatumReader(schema.schema) + return reader.read(bio) + except DecodeError as e: + raise InvalidPayload from e + raise ValueError("Unknown schema type") -def write_value(schema: TypedSchema, bio: io.BytesIO, value: dict): +def write_value(schema: TypedSchema, bio: io.BytesIO, value: dict) -> None: if schema.schema_type is SchemaType.AVRO: writer = DatumWriter(schema.schema) writer.write(value, BinaryEncoder(bio)) @@ -196,6 +213,13 @@ def write_value(schema: TypedSchema, bio: io.BytesIO, value: dict): except ValidationError as e: raise InvalidPayload from e bio.write(json_encode(value, binary=True)) + + elif schema.schema_type is SchemaType.PROTOBUF: + # TODO: PROTOBUF* we need use protobuf validator there + writer = ProtobufDatumWriter(schema.schema) + writer.write_index(bio) + writer.write(value, bio) + else: raise ValueError("Unknown schema type") @@ -208,6 +232,8 @@ async def serialize(self, schema: TypedSchema, value: dict) -> bytes: try: write_value(schema, bio, value) return bio.getvalue() + except ProtobufTypeException as e: + raise InvalidMessageSchema("Object does not fit to stored schema") from e except avro.io.AvroTypeException as e: raise InvalidMessageSchema("Object does not fit to stored schema") from e @@ -222,9 +248,11 @@ async def deserialize(self, bytes_: bytes) -> dict: raise InvalidMessageHeader("Start byte is %x and should be %x" % (start_byte, START_BYTE)) try: schema = await self.get_schema_for_id(schema_id) + if schema is None: + raise InvalidPayload("No schema with ID from payload") ret_val = read_value(schema, bio) return ret_val except AssertionError as e: - raise InvalidPayload(f"Data does not contain a valid {schema.schema_type} message") from e + raise InvalidPayload("Data does not contain a valid message") from e except avro.io.SchemaResolutionException as e: raise InvalidPayload("Data cannot be decoded with provided schema") from e diff --git a/requirements.txt b/requirements.txt index d4791526c..6732e0940 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,8 @@ jsonschema==3.2.0 lz4==3.0.2 requests==2.23.0 networkx==2.5 +python-dateutil==2.8.2 +protobuf~=3.14.0 # Patched dependencies # @@ -16,4 +18,4 @@ networkx==2.5 # images and forces a new image generation. # git+https://github.com/aiven/avro.git@513b153bac5040af6bba5847aef202adb680b67b#subdirectory=lang/py3/ -git+git://github.com/aiven/kafka-python.git@b9f2f78377d56392f61cba8856dc6c02ae841b79 +git+https://github.com/aiven/kafka-python.git@b9f2f78377d56392f61cba8856dc6c02ae841b79 diff --git a/runtime/.gitignore b/runtime/.gitignore new file mode 100644 index 000000000..5e7d2734c --- /dev/null +++ b/runtime/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/tests/integration/test_client_protobuf.py b/tests/integration/test_client_protobuf.py new file mode 100644 index 000000000..e000defa3 --- /dev/null +++ b/tests/integration/test_client_protobuf.py @@ -0,0 +1,34 @@ +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.schema_reader import SchemaType, TypedSchema +from karapace.serialization import SchemaRegistryClient +from tests.schemas.protobuf import schema_protobuf_order_after, schema_protobuf_order_before, schema_protobuf_plain +from tests.utils import new_random_name + + +async def test_remote_client_protobuf(registry_async_client): + schema_protobuf = TypedSchema.parse(SchemaType.PROTOBUF, schema_protobuf_plain) + reg_cli = SchemaRegistryClient() + reg_cli.client = registry_async_client + subject = new_random_name("subject") + sc_id = await reg_cli.post_new_schema(subject, schema_protobuf) + assert sc_id >= 0 + stored_schema = await reg_cli.get_schema_for_id(sc_id) + assert stored_schema == schema_protobuf, f"stored schema {stored_schema} is not {schema_protobuf}" + stored_id, stored_schema = await reg_cli.get_latest_schema(subject) + assert stored_id == sc_id + assert stored_schema == schema_protobuf + + +async def test_remote_client_protobuf2(registry_async_client): + schema_protobuf = TypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf_order_before)) + schema_protobuf_after = TypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf_order_after)) + reg_cli = SchemaRegistryClient() + reg_cli.client = registry_async_client + subject = new_random_name("subject") + sc_id = await reg_cli.post_new_schema(subject, schema_protobuf) + assert sc_id >= 0 + stored_schema = await reg_cli.get_schema_for_id(sc_id) + assert stored_schema == schema_protobuf, f"stored schema {stored_schema} is not {schema_protobuf}" + stored_id, stored_schema = await reg_cli.get_latest_schema(subject) + assert stored_id == sc_id + assert stored_schema == schema_protobuf_after diff --git a/tests/integration/test_rest_consumer_protobuf.py b/tests/integration/test_rest_consumer_protobuf.py new file mode 100644 index 000000000..94695e952 --- /dev/null +++ b/tests/integration/test_rest_consumer_protobuf.py @@ -0,0 +1,77 @@ +from tests.utils import ( + new_consumer, new_topic, repeat_until_successful_request, REST_HEADERS, schema_data, schema_data_second +) + +import pytest + + +@pytest.mark.parametrize("schema_type", ["protobuf"]) +@pytest.mark.parametrize("trail", ["", "/"]) +async def test_publish_consume_protobuf(rest_async_client, admin_client, trail, schema_type): + header = REST_HEADERS[schema_type] + group_name = "e2e_protobuf_group" + instance_id = await new_consumer(rest_async_client, group_name, fmt=schema_type, trail=trail) + assign_path = f"/consumers/{group_name}/instances/{instance_id}/assignments{trail}" + consume_path = f"/consumers/{group_name}/instances/{instance_id}/records{trail}?timeout=1000" + tn = new_topic(admin_client) + assign_payload = {"partitions": [{"topic": tn, "partition": 0}]} + res = await rest_async_client.post(assign_path, json=assign_payload, headers=header) + assert res.ok + publish_payload = schema_data[schema_type][1] + await repeat_until_successful_request( + rest_async_client.post, + f"topics/{tn}{trail}", + json_data={ + "value_schema": schema_data[schema_type][0], + "records": [{ + "value": o + } for o in publish_payload] + }, + headers=header, + error_msg="Unexpected response status for offset commit", + timeout=10, + sleep=1, + ) + resp = await rest_async_client.get(consume_path, headers=header) + assert resp.ok, f"Expected a successful response: {resp}" + data = resp.json() + assert len(data) == len(publish_payload), f"Expected to read test_objects from fetch request but got {data}" + data_values = [x["value"] for x in data] + for expected, actual in zip(publish_payload, data_values): + assert expected == actual, f"Expecting {actual} to be {expected}" + + +@pytest.mark.parametrize("schema_type", ["protobuf"]) +@pytest.mark.parametrize("trail", ["", "/"]) +async def test_publish_consume_protobuf_second(rest_async_client, admin_client, trail, schema_type): + header = REST_HEADERS[schema_type] + group_name = "e2e_proto_second" + instance_id = await new_consumer(rest_async_client, group_name, fmt=schema_type, trail=trail) + assign_path = f"/consumers/{group_name}/instances/{instance_id}/assignments{trail}" + consume_path = f"/consumers/{group_name}/instances/{instance_id}/records{trail}?timeout=1000" + tn = new_topic(admin_client) + assign_payload = {"partitions": [{"topic": tn, "partition": 0}]} + res = await rest_async_client.post(assign_path, json=assign_payload, headers=header) + assert res.ok + publish_payload = schema_data_second[schema_type][1] + await repeat_until_successful_request( + rest_async_client.post, + f"topics/{tn}{trail}", + json_data={ + "value_schema": schema_data_second[schema_type][0], + "records": [{ + "value": o + } for o in publish_payload] + }, + headers=header, + error_msg="Unexpected response status for offset commit", + timeout=10, + sleep=1, + ) + resp = await rest_async_client.get(consume_path, headers=header) + assert resp.ok, f"Expected a successful response: {resp}" + data = resp.json() + assert len(data) == len(publish_payload), f"Expected to read test_objects from fetch request but got {data}" + data_values = [x["value"] for x in data] + for expected, actual in zip(publish_payload, data_values): + assert expected == actual, f"Expecting {actual} to be {expected}" diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 70827ef38..5f78cb7a5 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -1132,9 +1132,10 @@ async def test_schema_types(registry_async_client: Client, trail: str) -> None: res = await registry_async_client.get(f"/schemas/types{trail}") assert res.status_code == 200 json = res.json() - assert len(json) == 2 + assert len(json) == 3 assert "AVRO" in json assert "JSON" in json + assert "PROTOBUF" in json @pytest.mark.parametrize("trail", ["", "/"]) diff --git a/tests/integration/test_schema_protobuf.py b/tests/integration/test_schema_protobuf.py new file mode 100644 index 000000000..327449acb --- /dev/null +++ b/tests/integration/test_schema_protobuf.py @@ -0,0 +1,193 @@ +""" +karapace - schema tests + +Copyright (c) 2019 Aiven Ltd +See LICENSE for details +""" +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.utils import Client +from tests.utils import create_subject_name_factory + +import json +import logging +import pytest +import requests + +baseurl = "http://localhost:8081" + +compatibility_test_url = "https://raw.githubusercontent.com/confluentinc/schema-registry/" + \ + "0530b0107749512b997f49cc79fe423f21b43b87/" + \ + "protobuf-provider/src/test/resources/diff-schema-examples.json" + + +def add_slashes(text: str) -> str: + escape_dict = { + '\a': '\\a', + '\b': '\\b', + '\f': '\\f', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', + '\v': '\\v', + '\'': "\\'", + '\"': '\\"', + '\\': '\\\\' + } + trans_table = str.maketrans(escape_dict) + return text.translate(trans_table) + + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize("trail", ["", "/"]) +async def test_protobuf_schema_compatibility(registry_async_client: Client, trail: str) -> None: + subject = create_subject_name_factory(f"test_protobuf_schema_compatibility-{trail}")() + + res = await registry_async_client.put(f"config/{subject}{trail}", json={"compatibility": "BACKWARD"}) + assert res.status == 200 + + original_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str2 = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + original_schema = trim_margin(original_schema) + + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", json={ + "schemaType": "PROTOBUF", + "schema": original_schema + } + ) + assert res.status == 200 + assert "id" in res.json() + + evolved_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str2 = 1; + | Enu x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + | enum Enu { + | A = 0; + | B = 1; + | } + |} + |""" + evolved_schema = trim_margin(evolved_schema) + + res = await registry_async_client.post( + f"compatibility/subjects/{subject}/versions/latest{trail}", + json={ + "schemaType": "PROTOBUF", + "schema": evolved_schema + }, + ) + assert res.status == 200 + assert res.json() == {"is_compatible": True} + + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", json={ + "schemaType": "PROTOBUF", + "schema": evolved_schema + } + ) + assert res.status == 200 + assert "id" in res.json() + + res = await registry_async_client.post( + f"compatibility/subjects/{subject}/versions/latest{trail}", + json={ + "schemaType": "PROTOBUF", + "schema": original_schema + }, + ) + assert res.json() == {"is_compatible": True} + assert res.status == 200 + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", json={ + "schemaType": "PROTOBUF", + "schema": original_schema + } + ) + assert res.status == 200 + assert "id" in res.json() + + +class Schemas: + url = requests.get(compatibility_test_url) + sch = json.loads(url.text) + schemas = {} + descriptions = [] + max_count = 120 + count = 0 + for a in sch: + # if a["description"] == "Detect compatible add field to oneof": + descriptions.append(a["description"]) + schemas[a["description"]] = dict(a) + count += 1 + if a["description"] == 'Detect incompatible message index change': + break + if count == max_count: + break + + +@pytest.mark.parametrize("trail", ["", "/"]) +@pytest.mark.parametrize("desc", Schemas.descriptions) +async def test_schema_registry_examples(registry_async_client: Client, trail: str, desc) -> None: + subject = create_subject_name_factory(f"test_protobuf_schema_compatibility-{trail}")() + + res = await registry_async_client.put(f"config/{subject}{trail}", json={"compatibility": "BACKWARD"}) + assert res.status == 200 + + description = desc + + schema = Schemas.schemas[description] + original_schema = schema["original_schema"] + evolved_schema = schema["update_schema"] + compatible = schema["compatible"] + + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", json={ + "schemaType": "PROTOBUF", + "schema": original_schema + } + ) + assert res.status == 200 + assert "id" in res.json() + + res = await registry_async_client.post( + f"compatibility/subjects/{subject}/versions/latest{trail}", + json={ + "schemaType": "PROTOBUF", + "schema": evolved_schema + }, + ) + assert res.status == 200 + assert res.json() == {"is_compatible": compatible} + + res = await registry_async_client.post( + f"subjects/{subject}/versions{trail}", json={ + "schemaType": "PROTOBUF", + "schema": evolved_schema + } + ) + + if compatible: + assert res.status == 200 + assert "id" in res.json() + else: + assert res.status == 409 diff --git a/tests/schemas/protobuf.py b/tests/schemas/protobuf.py new file mode 100644 index 000000000..02f5d6a72 --- /dev/null +++ b/tests/schemas/protobuf.py @@ -0,0 +1,73 @@ +schema_protobuf_plain = """syntax = "proto3"; +package com.codingharbour.protobuf; + +option java_outer_classname = "SimpleMessageProtos"; +message SimpleMessage { + string content = 1; + string date_time = 2; + string content2 = 3; +} +""" + +schema_protobuf_schema_registry1 = """ +|syntax = "proto3"; +|package com.codingharbour.protobuf; +| +|message SimpleMessage { +| string content = 1; +| string my_string = 2; +| int32 my_int = 3; +|} +| +""" + +schema_protobuf_order_before = """ +|syntax = "proto3"; +| +|option java_package = "com.codingharbour.protobuf"; +|option java_outer_classname = "TestEnumOrder"; +| +|enum Enum { +| HIGH = 0; +| MIDDLE = 1; +| LOW = 2; +|} +|message Message { +| int32 query = 1; +|} +""" + +schema_protobuf_order_after = """ +|syntax = "proto3"; +| +|option java_package = "com.codingharbour.protobuf"; +|option java_outer_classname = "TestEnumOrder"; +| +|message Message { +| int32 query = 1; +|} +|enum Enum { +| HIGH = 0; +| MIDDLE = 1; +| LOW = 2; +|} +| +""" + +schema_protobuf_compare_one = """ +|syntax = "proto3"; +| +|option java_package = "com.codingharbour.protobuf"; +|option java_outer_classname = "TestEnumOrder"; +| +|message Message { +| int32 query = 1; +| string content = 2; +|} +|enum Enum { +| HIGH = 0; +| MIDDLE = 1; +| LOW = 2; +|} +| +""" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 70c8034de..ff800660e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,5 +1,6 @@ +from karapace.protobuf.kotlin_wrapper import trim_margin from karapace.schema_reader import SchemaType, TypedSchema -from tests.utils import schema_avro_json +from tests.utils import schema_avro_json, schema_protobuf, schema_protobuf2 import pytest @@ -19,6 +20,31 @@ async def post_new_schema(self, *args, **kwargs): return 1 +class MockProtobufClient: + # pylint: disable=unused-argument + def __init__(self, *args, **kwargs): + pass + + async def get_schema_for_id2(self, *args, **kwargs): + return TypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf2)) + + async def get_schema_for_id(self, *args, **kwargs): + if args[0] != 1: + return None + return TypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)) + + async def get_latest_schema(self, *args, **kwargs): + return 1, TypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)) + + async def post_new_schema(self, *args, **kwargs): + return 1 + + @pytest.fixture(name="mock_registry_client") def create_basic_registry_client() -> MockClient: return MockClient() + + +@pytest.fixture(name="mock_protobuf_registry_client") +def create_basic_protobuf_registry_client() -> MockProtobufClient: + return MockProtobufClient() diff --git a/tests/unit/test_any_tool.py b/tests/unit/test_any_tool.py new file mode 100644 index 000000000..04a7a60da --- /dev/null +++ b/tests/unit/test_any_tool.py @@ -0,0 +1,70 @@ +from karapace import config +from karapace.protobuf.io import calculate_class_name +from karapace.protobuf.kotlin_wrapper import trim_margin +from subprocess import PIPE, Popen, TimeoutExpired + +import importlib +import importlib.util +import logging + +log = logging.getLogger("KarapaceTests") + + +def test_protoc(): + + proto: str = """ + |syntax = "proto3"; + |package com.instaclustr.protobuf; + |option java_outer_classname = "SimpleMessageProtos"; + |message SimpleMessage { + | string content = 1; + | string date_time = 2; + | string content2 = 3; + |} + | + """ + proto = trim_margin(proto) + + directory = config.DEFAULTS["protobuf_runtime_directory"] + proto_name = calculate_class_name(str(proto)) + proto_path = f"{directory}/{proto_name}.proto" + class_path = f"{directory}/{proto_name}_pb2.py" + + log.info(proto_name) + try: + with open(proto_path, "w") as proto_text: + proto_text.write(str(proto)) + proto_text.close() + + except Exception as e: # pylint: disable=broad-except + log.error("Unexpected exception in statsd send: %s: %s", e.__class__.__name__, e) + assert False, f"Cannot write Proto File. Unexpected exception in statsd send: {e.__class__.__name__} + {e}" + + args = ["protoc", "--python_out=./", proto_path] + try: + proc = Popen(args, stdout=PIPE, stderr=PIPE, shell=False) + except FileNotFoundError as e: + assert False, f"Protoc not found. {e}" + except Exception as e: # pylint: disable=broad-except + log.error("Unexpected exception in statsd send: %s: %s", e.__class__.__name__, e) + assert False, f"Cannot execute protoc. Unexpected exception in statsd send: {e.__class__.__name__} + {e}" + try: + out, err = proc.communicate(timeout=10) + assert out == b'' + assert err == b'' + except TimeoutExpired: + proc.kill() + assert False, "Timeout expired" + module_content = "" + try: + with open(class_path, "r") as proto_text: + module_content = proto_text.read() + proto_text.close() + print(module_content) + except Exception as e: # pylint: disable=broad-except + log.error("Unexpected exception in statsd send: %s: %s", e.__class__.__name__, e) + assert False, f"Cannot read Proto File. Unexpected exception in statsd send: {e.__class__.__name__} + {e}" + + spec = importlib.util.spec_from_file_location(f"{proto_name}_pb2", class_path) + tmp_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(tmp_module) diff --git a/tests/unit/test_compare_elements.py b/tests/unit/test_compare_elements.py new file mode 100644 index 000000000..6c236229a --- /dev/null +++ b/tests/unit/test_compare_elements.py @@ -0,0 +1,76 @@ +from karapace.protobuf.compare_result import CompareResult, Modification +from karapace.protobuf.compare_type_storage import CompareTypes +from karapace.protobuf.field import Field +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.location import Location +from karapace.protobuf.one_of_element import OneOfElement +from karapace.protobuf.option_element import OptionElement + +location: Location = Location.get("some/folder", "file.proto") + + +def test_compare_oneof(): + self_one_of = OneOfElement( + name="page_info", + fields=[ + FieldElement(location=location.at(4, 5), element_type="int32", name="page_number", tag=2), + FieldElement(location=location.at(5, 5), element_type="int32", name="result_per_page", tag=3) + ], + ) + + other_one_of = OneOfElement( + name="info", + fields=[ + FieldElement(location=location.at(4, 5), element_type="int32", name="page_number", tag=2), + FieldElement(location=location.at(5, 5), element_type="int32", name="result_per_page", tag=3), + FieldElement(location=location.at(6, 5), element_type="int32", name="view", tag=4) + ], + ) + + result = CompareResult() + types = CompareTypes('', '', result) + self_one_of.compare(other_one_of, result, types) + assert result.is_compatible() + assert len(result.result) == 1 + result2: list = [] + for e in result.result: + result2.append(e.modification) + assert Modification.ONE_OF_FIELD_ADD in result2 + + +def test_compare_field(): + self_field = FieldElement( + location=location.at(4, 3), + label=Field.Label.OPTIONAL, + element_type="bool", + name="test", + tag=3, + options=[ + OptionElement("old_default", OptionElement.Kind.BOOLEAN, "true"), + OptionElement("delay", OptionElement.Kind.NUMBER, "200", True) + ] + ) + + other_field = FieldElement( + location=location.at(4, 3), + label=Field.Label.OPTIONAL, + element_type="bool", + name="best", + tag=3, + options=[ + OptionElement("old_default", OptionElement.Kind.BOOLEAN, "true"), + OptionElement("delay", OptionElement.Kind.NUMBER, "200", True) + ] + ) + + result = CompareResult() + types = CompareTypes('', '', result) + self_field.compare(other_field, result, types) + + assert result.is_compatible() + assert len(result.result) == 1 + result2: list = [] + for e in result.result: + result2.append(e.modification) + + assert Modification.FIELD_NAME_ALTER in result2 diff --git a/tests/unit/test_compatibility.py b/tests/unit/test_compatibility.py new file mode 100644 index 000000000..8348a1b51 --- /dev/null +++ b/tests/unit/test_compatibility.py @@ -0,0 +1,186 @@ +from karapace.protobuf.compare_result import CompareResult +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.proto_file_element import ProtoFileElement +from karapace.protobuf.proto_parser import ProtoParser + +location: Location = Location.get("some/folder", "file.proto") + + +def test_compatibility_package(): + self_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str = 1; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + other_schema = """ + |syntax = "proto3"; + |package a2; + |message TestMessage { + | message Value { + | string str = 1; + | } + | string test = 1; + | .a2.TestMessage.Value val = 2; + |} + |""" + + self_schema = trim_margin(self_schema) + other_schema = trim_margin(other_schema) + self_parsed: ProtoFileElement = ProtoParser.parse(location, self_schema) + other_parsed: ProtoFileElement = ProtoParser.parse(location, other_schema) + result = CompareResult() + self_parsed.compare(other_parsed, result) + assert result.is_compatible() + + +def test_compatibility_field_add(): + self_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str = 1; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + other_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str = 1; + | string str2 = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + self_schema = trim_margin(self_schema) + other_schema = trim_margin(other_schema) + self_parsed: ProtoFileElement = ProtoParser.parse(location, self_schema) + other_parsed: ProtoFileElement = ProtoParser.parse(location, other_schema) + result = CompareResult() + self_parsed.compare(other_parsed, result) + assert result.is_compatible() + + +def test_compatibility_field_drop(): + self_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str = 1; + | string str2 = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + other_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str = 1; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + self_schema = trim_margin(self_schema) + other_schema = trim_margin(other_schema) + self_parsed: ProtoFileElement = ProtoParser.parse(location, self_schema) + other_parsed: ProtoFileElement = ProtoParser.parse(location, other_schema) + result = CompareResult() + self_parsed.compare(other_parsed, result) + assert result.is_compatible() + + +def test_compatibility_field_add_drop(): + self_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str2 = 1; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + other_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str = 1; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + self_schema = trim_margin(self_schema) + other_schema = trim_margin(other_schema) + self_parsed: ProtoFileElement = ProtoParser.parse(location, self_schema) + other_parsed: ProtoFileElement = ProtoParser.parse(location, other_schema) + result = CompareResult() + self_parsed.compare(other_parsed, result) + assert result.is_compatible() + + +def test_compatibility_enum_add(): + self_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str2 = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + other_schema = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str2 = 1; + | Enu x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + | enum Enu { + | A = 0; + | B = 1; + | } + |} + |""" + + self_schema = trim_margin(self_schema) + other_schema = trim_margin(other_schema) + self_parsed: ProtoFileElement = ProtoParser.parse(location, self_schema) + other_parsed: ProtoFileElement = ProtoParser.parse(location, other_schema) + + result = CompareResult() + self_parsed.compare(other_parsed, result) + assert result.is_compatible() diff --git a/tests/unit/test_enum_element.py b/tests/unit/test_enum_element.py new file mode 100644 index 000000000..f098f2d63 --- /dev/null +++ b/tests/unit/test_enum_element.py @@ -0,0 +1,139 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/EnumElementTest.kt + +from karapace.protobuf.enum_constant_element import EnumConstantElement +from karapace.protobuf.enum_element import EnumElement +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.option_element import OptionElement + +location: Location = Location.get("file.proto") + + +def test_empty_to_schema(): + element = EnumElement(location=location, name="Enum") + expected = "enum Enum {}\n" + assert element.to_schema() == expected + + +def test_simple_to_schema(): + element = EnumElement( + location=location, + name="Enum", + constants=[ + EnumConstantElement(location=location, name="ONE", tag=1), + EnumConstantElement(location=location, name="TWO", tag=2), + EnumConstantElement(location=location, name="SIX", tag=6) + ] + ) + expected = """ + |enum Enum { + | ONE = 1; + | TWO = 2; + | SIX = 6; + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_add_multiple_constants(): + one = EnumConstantElement(location=location, name="ONE", tag=1) + two = EnumConstantElement(location=location, name="TWO", tag=2) + six = EnumConstantElement(location=location, name="SIX", tag=6) + element = EnumElement(location=location, name="Enum", constants=[one, two, six]) + assert len(element.constants) == 3 + + +def test_simple_with_options_to_schema(): + element = EnumElement( + location=location, + name="Enum", + options=[OptionElement("kit", OptionElement.Kind.STRING, "kat")], + constants=[ + EnumConstantElement(location=location, name="ONE", tag=1), + EnumConstantElement(location=location, name="TWO", tag=2), + EnumConstantElement(location=location, name="SIX", tag=6) + ] + ) + expected = """ + |enum Enum { + | option kit = "kat"; + | ONE = 1; + | TWO = 2; + | SIX = 6; + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_add_multiple_options(): + kit_kat = OptionElement("kit", OptionElement.Kind.STRING, "kat") + foo_bar = OptionElement("foo", OptionElement.Kind.STRING, "bar") + element = EnumElement( + location=location, + name="Enum", + options=[kit_kat, foo_bar], + constants=[EnumConstantElement(location=location, name="ONE", tag=1)] + ) + assert len(element.options) == 2 + + +def test_simple_with_documentation_to_schema(): + element = EnumElement( + location=location, + name="Enum", + documentation="Hello", + constants=[ + EnumConstantElement(location=location, name="ONE", tag=1), + EnumConstantElement(location=location, name="TWO", tag=2), + EnumConstantElement(location=location, name="SIX", tag=6) + ] + ) + expected = """ + |// Hello + |enum Enum { + | ONE = 1; + | TWO = 2; + | SIX = 6; + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_field_to_schema(): + value = EnumConstantElement(location=location, name="NAME", tag=1) + expected = "NAME = 1;\n" + assert value.to_schema() == expected + + +def test_field_with_documentation_to_schema(): + value = EnumConstantElement(location=location, name="NAME", tag=1, documentation="Hello") + expected = """ + |// Hello + |NAME = 1; + |""" + expected = trim_margin(expected) + assert value.to_schema() == expected + + +def test_field_with_options_to_schema(): + value = EnumConstantElement( + location=location, + name="NAME", + tag=1, + options=[ + OptionElement("kit", OptionElement.Kind.STRING, "kat", True), + OptionElement("tit", OptionElement.Kind.STRING, "tat") + ] + ) + expected = """ + |NAME = 1 [ + | (kit) = "kat", + | tit = "tat" + |]; + |""" + expected = trim_margin(expected) + assert value.to_schema() == expected diff --git a/tests/unit/test_extend_element.py b/tests/unit/test_extend_element.py new file mode 100644 index 000000000..7c718a8b0 --- /dev/null +++ b/tests/unit/test_extend_element.py @@ -0,0 +1,105 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/ExtendElementTest.kt + +from karapace.protobuf.extend_element import ExtendElement +from karapace.protobuf.field import Field +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.protobuf.location import Location + +location = Location.get("file.proto") + + +def test_empty_to_schema(): + extend = ExtendElement(location=location, name="Name") + expected = "extend Name {}\n" + assert extend.to_schema() == expected + + +def test_simple_to_schema(): + extend = ExtendElement( + location=location, + name="Name", + fields=[FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1)] + ) + expected = """ + |extend Name { + | required string name = 1; + |} + |""" + expected = trim_margin(expected) + assert extend.to_schema() == expected + + +def test_add_multiple_fields(): + first_name = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="first_name", tag=1) + last_name = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="last_name", tag=2) + extend = ExtendElement(location=location, name="Name", fields=[first_name, last_name]) + assert len(extend.fields) == 2 + + +def test_simple_with_documentation_to_schema(): + extend = ExtendElement( + location=location, + name="Name", + documentation="Hello", + fields=[FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1)] + ) + expected = """ + |// Hello + |extend Name { + | required string name = 1; + |} + |""" + expected = trim_margin(expected) + assert extend.to_schema() == expected + + +def test_json_name_to_schema(): + extend = ExtendElement( + location=location, + name="Name", + fields=[ + FieldElement( + location=location, + label=Field.Label.REQUIRED, + element_type="string", + name="name", + json_name="my_json", + tag=1 + ) + ] + ) + expected = """ + |extend Name { + | required string name = 1 [json_name = "my_json"]; + |} + |""" + expected = trim_margin(expected) + assert extend.to_schema() == expected + + +def test_default_is_set_in_proto2_file(): + extend = ExtendElement( + location=location, + name="Name", + documentation="Hello", + fields=[ + FieldElement( + location=location, + label=Field.Label.REQUIRED, + element_type="string", + name="name", + tag=1, + default_value="defaultValue" + ) + ] + ) + expected = """ + |// Hello + |extend Name { + | required string name = 1 [default = "defaultValue"]; + |} + |""" + expected = trim_margin(expected) + assert extend.to_schema() == expected diff --git a/tests/unit/test_extensions_element.py b/tests/unit/test_extensions_element.py new file mode 100644 index 000000000..0b40c9251 --- /dev/null +++ b/tests/unit/test_extensions_element.py @@ -0,0 +1,37 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/ExtensionsElementTest.kt + +from karapace.protobuf.extensions_element import ExtensionsElement +from karapace.protobuf.kotlin_wrapper import KotlinRange, trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.utils import MAX_TAG_VALUE + +location = Location.get("file.proto") + + +def test_single_value_to_schema(): + actual = ExtensionsElement(location=location, values=[500]) + expected = "extensions 500;\n" + assert actual.to_schema() == expected + + +def test_range_to_schema(): + actual = ExtensionsElement(location=location, values=[KotlinRange(500, 505)]) + expected = "extensions 500 to 505;\n" + assert actual.to_schema() == expected + + +def test_max_range_to_schema(): + actual = ExtensionsElement(location=location, values=[KotlinRange(500, MAX_TAG_VALUE)]) + expected = "extensions 500 to max;\n" + assert actual.to_schema() == expected + + +def test_with_documentation_to_schema(): + actual = ExtensionsElement(location=location, documentation="Hello", values=[500]) + expected = """ + |// Hello + |extensions 500; + |""" + expected = trim_margin(expected) + assert actual.to_schema() == expected diff --git a/tests/unit/test_field_element.py b/tests/unit/test_field_element.py new file mode 100644 index 000000000..9f38093d6 --- /dev/null +++ b/tests/unit/test_field_element.py @@ -0,0 +1,86 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/FieldElementTest.kt + +from karapace.protobuf.field import Field +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.option_element import OptionElement + +location = Location.get("file.proto") + + +def test_field(): + field = FieldElement( + location=location, + label=Field.Label.OPTIONAL, + element_type="CType", + name="ctype", + tag=1, + options=[ + OptionElement("default", OptionElement.Kind.ENUM, "TEST"), + OptionElement("deprecated", OptionElement.Kind.BOOLEAN, "true") + ] + ) + + assert len(field.options) == 2 + assert OptionElement("default", OptionElement.Kind.ENUM, "TEST") in field.options + assert OptionElement("deprecated", OptionElement.Kind.BOOLEAN, "true") in field.options + + +def test_add_multiple_options(): + kit_kat = OptionElement("kit", OptionElement.Kind.STRING, "kat") + foo_bar = OptionElement("foo", OptionElement.Kind.STRING, "bar") + field = FieldElement( + location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1, options=[kit_kat, foo_bar] + ) + + assert len(field.options) == 2 + + +def test_default_is_set(): + field = FieldElement( + location=location, + label=Field.Label.REQUIRED, + element_type="string", + name="name", + tag=1, + default_value="defaultValue" + ) + + assert field.to_schema( + ) == trim_margin(""" + |required string name = 1 [default = "defaultValue"]; + |""") + + +def test_json_name_and_default_value(): + field = FieldElement( + location=location, + label=Field.Label.REQUIRED, + element_type="string", + name="name", + default_value="defaultValue", + json_name="my_json", + tag=1 + ) + + assert field.to_schema() == trim_margin( + """ + |required string name = 1 [ + | default = "defaultValue", + | json_name = "my_json" + |]; + |""" + ) + + +def test_json_name(): + field = FieldElement( + location=location, label=Field.Label.REQUIRED, element_type="string", name="name", json_name="my_json", tag=1 + ) + + assert field.to_schema( + ) == trim_margin(""" + |required string name = 1 [json_name = "my_json"]; + |""") diff --git a/tests/unit/test_message_element.py b/tests/unit/test_message_element.py new file mode 100644 index 000000000..324b434b6 --- /dev/null +++ b/tests/unit/test_message_element.py @@ -0,0 +1,448 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/MessageElementTest.kt + +from karapace.protobuf.extensions_element import ExtensionsElement +from karapace.protobuf.field import Field +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.group_element import GroupElement +from karapace.protobuf.kotlin_wrapper import KotlinRange, trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.message_element import MessageElement +from karapace.protobuf.one_of_element import OneOfElement +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.reserved_element import ReservedElement + +location: Location = Location.get("file.proto") + + +def test_empty_to_schema(): + element = MessageElement(location=location, name="Message") + expected = "message Message {}\n" + assert element.to_schema() == expected + + +def test_simple_to_schema(): + element = MessageElement( + location=location, + name="Message", + fields=[FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1)] + ) + expected = """ + |message Message { + | required string name = 1; + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_add_multiple_fields(): + first_name = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="first_name", tag=1) + last_name = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="last_name", tag=2) + element = MessageElement(location=location, name="Message", fields=[first_name, last_name]) + assert len(element.fields) == 2 + + +def test_simple_with_documentation_to_schema(): + element = MessageElement( + location=location, + name="Message", + documentation="Hello", + fields=[FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1)] + ) + expected = """ + |// Hello + |message Message { + | required string name = 1; + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_simple_with_options_to_schema(): + field = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1) + element = MessageElement( + location=location, name="Message", fields=[field], options=[OptionElement("kit", OptionElement.Kind.STRING, "kat")] + ) + expected = """message Message { + | option kit = "kat"; + | + | required string name = 1; + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_add_multiple_options(): + field = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1) + kit_kat = OptionElement("kit", OptionElement.Kind.STRING, "kat") + foo_bar = OptionElement("foo", OptionElement.Kind.STRING, "bar") + element = MessageElement(location=location, name="Message", fields=[field], options=[kit_kat, foo_bar]) + assert len(element.options) == 2 + + +def test_simple_with_nested_elements_to_schema(): + element = MessageElement( + location=location, + name="Message", + fields=[FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1)], + nested_types=[ + MessageElement( + location=location, + name="Nested", + fields=[ + FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1) + ] + ) + ] + ) + expected = """ + |message Message { + | required string name = 1; + | + | message Nested { + | required string name = 1; + | } + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_add_multiple_types(): + nested1 = MessageElement(location=location, name="Nested1") + nested2 = MessageElement(location=location, name="Nested2") + element = MessageElement( + location=location, + name="Message", + fields=[FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1)], + nested_types=[nested1, nested2] + ) + assert len(element.nested_types) == 2 + + +def test_simple_with_extensions_to_schema(): + element = MessageElement( + location=location, + name="Message", + fields=[FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1)], + extensions=[ExtensionsElement(location=location, values=[KotlinRange(500, 501)])] + ) + expected = """ + |message Message { + | required string name = 1; + | + | extensions 500 to 501; + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_add_multiple_extensions(): + fives = ExtensionsElement(location=location, values=[KotlinRange(500, 501)]) + sixes = ExtensionsElement(location=location, values=[KotlinRange(600, 601)]) + element = MessageElement( + location=location, + name="Message", + fields=[FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1)], + extensions=[fives, sixes] + ) + assert len(element.extensions) == 2 + + +def test_one_of_to_schema(): + element = MessageElement( + location=location, + name="Message", + one_ofs=[ + OneOfElement(name="hi", fields=[FieldElement(location=location, element_type="string", name="name", tag=1)]) + ] + ) + expected = """ + |message Message { + | oneof hi { + | string name = 1; + | } + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_one_of_with_group_to_schema(): + element = MessageElement( + location=location, + name="Message", + one_ofs=[ + OneOfElement( + name="hi", + fields=[FieldElement(location=location, element_type="string", name="name", tag=1)], + groups=[ + GroupElement( + location=location.at(5, 5), + name="Stuff", + tag=3, + label=None, + fields=[ + FieldElement( + location=location.at(6, 7), + label=Field.Label.OPTIONAL, + element_type="int32", + name="result_per_page", + tag=4 + ), + FieldElement( + location=location.at(7, 7), + label=Field.Label.OPTIONAL, + element_type="int32", + name="page_count", + tag=5 + ) + ] + ) + ] + ) + ] + ) + + expected = """ + |message Message { + | oneof hi { + | string name = 1; + | """ + \ + """ + | group Stuff = 3 { + | optional int32 result_per_page = 4; + | optional int32 page_count = 5; + | } + | } + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_add_multiple_one_ofs(): + hi = OneOfElement(name="hi", fields=[FieldElement(location=location, element_type="string", name="name", tag=1)]) + hey = OneOfElement(name="hey", fields=[FieldElement(location=location, element_type="string", name="city", tag=2)]) + element = MessageElement(location=location, name="Message", one_ofs=[hi, hey]) + assert len(element.one_ofs) == 2 + + +def test_reserved_to_schema(): + element = MessageElement( + location=location, + name="Message", + reserveds=[ + ReservedElement(location=location, values=[10, KotlinRange(12, 14), "foo"]), + ReservedElement(location=location, values=[10]), + ReservedElement(location=location, values=[KotlinRange(12, 14)]), + ReservedElement(location=location, values=["foo"]) + ] + ) + expected = """ + |message Message { + | reserved 10, 12 to 14, "foo"; + | reserved 10; + | reserved 12 to 14; + | reserved "foo"; + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_group_to_schema(): + element = MessageElement( + location=location.at(1, 1), + name="SearchResponse", + groups=[ + GroupElement( + location=location.at(2, 3), + label=Field.Label.REPEATED, + name="Result", + tag=1, + fields=[ + FieldElement( + location=location.at(3, 5), label=Field.Label.REQUIRED, element_type="string", name="url", tag=2 + ), + FieldElement( + location=location.at(4, 5), label=Field.Label.OPTIONAL, element_type="string", name="title", tag=3 + ), + FieldElement( + location=location.at(5, 5), + label=Field.Label.REPEATED, + element_type="string", + name="snippets", + tag=4 + ) + ] + ) + ] + ) + expected = """ + |message SearchResponse { + | repeated group Result = 1 { + | required string url = 2; + | optional string title = 3; + | repeated string snippets = 4; + | } + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_multiple_everything_to_schema(): + field1 = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1) + field2 = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="bool", name="other_name", tag=2) + one_off_1_field = FieldElement(location=location, element_type="string", name="namey", tag=3) + one_of_1 = OneOfElement(name="thingy", fields=[one_off_1_field]) + one_off_2_field = FieldElement(location=location, element_type="string", name="namer", tag=4) + one_of_2 = OneOfElement(name="thinger", fields=[one_off_2_field]) + extensions1 = ExtensionsElement(location=location, values=[KotlinRange(500, 501)]) + extensions2 = ExtensionsElement(location=location, values=[503]) + nested = MessageElement(location=location, name="Nested", fields=[field1]) + option = OptionElement("kit", OptionElement.Kind.STRING, "kat") + element = MessageElement( + location=location, + name="Message", + fields=[field1, field2], + one_ofs=[one_of_1, one_of_2], + nested_types=[nested], + extensions=[extensions1, extensions2], + options=[option] + ) + expected = """ + |message Message { + | option kit = "kat"; + | + | required string name = 1; + | + | required bool other_name = 2; + | + | oneof thingy { + | string namey = 3; + | } + | + | oneof thinger { + | string namer = 4; + | } + | + | extensions 500 to 501; + | extensions 503; + | + | message Nested { + | required string name = 1; + | } + |} + |""" + expected = trim_margin(expected) + assert element.to_schema() == expected + + +def test_field_to_schema(): + field = FieldElement(location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1) + expected = "required string name = 1;\n" + assert field.to_schema() == expected + + +def test_field_with_default_string_to_schema_in_proto2(): + field = FieldElement( + location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1, default_value="benoît" + ) + expected = "required string name = 1 [default = \"benoît\"];\n" + assert field.to_schema() == expected + + +def test_field_with_default_number_to_schema(): + field = FieldElement( + location=location, label=Field.Label.REQUIRED, element_type="int32", name="age", tag=1, default_value="34" + ) + expected = "required int32 age = 1 [default = 34];\n" + assert field.to_schema() == expected + + +def test_field_with_default_bool_to_schema(): + field = FieldElement( + location=location, label=Field.Label.REQUIRED, element_type="bool", name="human", tag=1, default_value="true" + ) + expected = "required bool human = 1 [default = true];\n" + assert field.to_schema() == expected + + +def test_one_of_field_to_schema(): + field = FieldElement(location=location, element_type="string", name="name", tag=1) + expected = "string name = 1;\n" + assert field.to_schema() == expected + + +def test_field_with_documentation_to_schema(): + field = FieldElement( + location=location, label=Field.Label.REQUIRED, element_type="string", name="name", tag=1, documentation="Hello" + ) + expected = """// Hello + |required string name = 1; + |""" + expected = trim_margin(expected) + assert field.to_schema() == expected + + +def test_field_with_one_option_to_schema(): + field = FieldElement( + location=location, + label=Field.Label.REQUIRED, + element_type="string", + name="name", + tag=1, + options=[OptionElement("kit", OptionElement.Kind.STRING, "kat")] + ) + expected = """required string name = 1 [kit = "kat"]; + |""" + expected = trim_margin(expected) + assert field.to_schema() == expected + + +def test_field_with_more_than_one_option_to_schema(): + field = FieldElement( + location=location, + label=Field.Label.REQUIRED, + element_type="string", + name="name", + tag=1, + options=[ + OptionElement("kit", OptionElement.Kind.STRING, "kat"), + OptionElement("dup", OptionElement.Kind.STRING, "lo") + ] + ) + expected = """required string name = 1 [ + | kit = "kat", + | dup = "lo" + |]; + |""" + expected = trim_margin(expected) + assert field.to_schema() == expected + + +def test_one_of_with_options(): + expected = """ + |oneof page_info { + | option (my_option) = true; + | + | int32 page_number = 2; + | int32 result_per_page = 3; + |} + |""" + expected = trim_margin(expected) + one_of = OneOfElement( + name="page_info", + fields=[ + FieldElement(location=location.at(4, 5), element_type="int32", name="page_number", tag=2), + FieldElement(location=location.at(5, 5), element_type="int32", name="result_per_page", tag=3) + ], + options=[OptionElement("my_option", OptionElement.Kind.BOOLEAN, "true", True)] + ) + assert one_of.to_schema() == expected diff --git a/tests/unit/test_option_element.py b/tests/unit/test_option_element.py new file mode 100644 index 000000000..782789c98 --- /dev/null +++ b/tests/unit/test_option_element.py @@ -0,0 +1,57 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/OptionElementTest.kt + +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.protobuf.option_element import OptionElement + + +def test_simple_to_schema(): + option = OptionElement("foo", OptionElement.Kind.STRING, "bar") + expected = """foo = \"bar\"""" + assert option.to_schema() == expected + + +def test_nested_to_schema(): + option = OptionElement( + "foo.boo", OptionElement.Kind.OPTION, OptionElement("bar", OptionElement.Kind.STRING, "baz"), True + ) + expected = """(foo.boo).bar = \"baz\"""" + assert option.to_schema() == expected + + +def test_list_to_schema(): + option = OptionElement( + "foo", OptionElement.Kind.LIST, [ + OptionElement("ping", OptionElement.Kind.STRING, "pong", True), + OptionElement("kit", OptionElement.Kind.STRING, "kat") + ], True + ) + expected = """ + |(foo) = [ + | (ping) = "pong", + | kit = "kat" + |] + """ + expected = trim_margin(expected) + assert option.to_schema() == expected + + +def test_map_to_schema(): + option = OptionElement("foo", OptionElement.Kind.MAP, {"ping": "pong", "kit": ["kat", "kot"]}) + expected = """ + |foo = { + | ping: "pong", + | kit: [ + | "kat", + | "kot" + | ] + |} + """ + expected = trim_margin(expected) + assert option.to_schema() == expected + + +def test_boolean_to_schema(): + option = OptionElement("foo", OptionElement.Kind.BOOLEAN, "false") + expected = "foo = false" + assert option.to_schema() == expected diff --git a/tests/unit/test_parsing_tester.py b/tests/unit/test_parsing_tester.py new file mode 100644 index 000000000..cedd494c5 --- /dev/null +++ b/tests/unit/test_parsing_tester.py @@ -0,0 +1,31 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/ParsingTester.kt + +from karapace.protobuf.location import Location +from karapace.protobuf.proto_parser import ProtoParser + +import fnmatch +import os + +# Recursively traverse a directory and attempt to parse all of its proto files. + +# Directory under which to search for protos. Change as needed. +src = "test" + + +def test_multi_files(): + total = 0 + failed = 0 + + for root, dirnames, filenames in os.walk(src): # pylint: disable=W0612 + for filename in fnmatch.filter(filenames, '*.proto'): + fn = os.path.join(root, filename) + print(f"Parsing {fn}") + total += 1 + try: + data = open(fn).read() + ProtoParser.parse(Location.get(fn), data) + except Exception as e: # pylint: disable=broad-except + print(e) + failed += 1 + print(f"\nTotal: {total} Failed: {failed}") diff --git a/tests/unit/test_proto_file_element.py b/tests/unit/test_proto_file_element.py new file mode 100644 index 000000000..cdb4ee5c9 --- /dev/null +++ b/tests/unit/test_proto_file_element.py @@ -0,0 +1,469 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/ProtoFileElementTest.kt +from karapace.protobuf.extend_element import ExtendElement +from karapace.protobuf.field import Field +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.message_element import MessageElement +from karapace.protobuf.option_element import OptionElement, PACKED_OPTION_ELEMENT +from karapace.protobuf.proto_file_element import ProtoFileElement +from karapace.protobuf.proto_parser import ProtoParser +from karapace.protobuf.service_element import ServiceElement +from karapace.protobuf.syntax import Syntax + +import copy + +location: Location = Location.get("some/folder", "file.proto") + + +def test_empty_to_schema(): + file = ProtoFileElement(location=location) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_empty_with_package_to_schema(): + file = ProtoFileElement(location=location, package_name="example.simple") + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |package example.simple; + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_simple_to_schema(): + element = MessageElement(location=location, name="Message") + file = ProtoFileElement(location=location, types=[element]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |message Message {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_simple_with_imports_to_schema(): + element = MessageElement(location=location, name="Message") + file = ProtoFileElement(location=location, imports=["example.other"], types=[element]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |import "example.other"; + | + |message Message {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_add_multiple_dependencies(): + element = MessageElement(location=location, name="Message") + file = ProtoFileElement(location=location, imports=["example.other", "example.another"], types=[element]) + assert len(file.imports) == 2 + + +def test_simple_with_public_imports_to_schema(): + element = MessageElement(location=location, name="Message") + file = ProtoFileElement(location=location, public_imports=["example.other"], types=[element]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |import public "example.other"; + | + |message Message {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_add_multiple_public_dependencies(): + element = MessageElement(location=location, name="Message") + file = ProtoFileElement(location=location, public_imports=["example.other", "example.another"], types=[element]) + + assert len(file.public_imports) == 2 + + +def test_simple_with_both_imports_to_schema(): + element = MessageElement(location=location, name="Message") + file = ProtoFileElement(location=location, imports=["example.thing"], public_imports=["example.other"], types=[element]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |import "example.thing"; + |import public "example.other"; + | + |message Message {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_simple_with_services_to_schema(): + element = MessageElement(location=location, name="Message") + service = ServiceElement(location=location, name="Service") + file = ProtoFileElement(location=location, types=[element], services=[service]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |message Message {} + | + |service Service {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_add_multiple_services(): + service1 = ServiceElement(location=location, name="Service1") + service2 = ServiceElement(location=location, name="Service2") + file = ProtoFileElement(location=location, services=[service1, service2]) + assert len(file.services) == 2 + + +def test_simple_with_options_to_schema(): + element = MessageElement(location=location, name="Message") + option = OptionElement("kit", OptionElement.Kind.STRING, "kat") + file = ProtoFileElement(location=location, options=[option], types=[element]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |option kit = "kat"; + | + |message Message {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_add_multiple_options(): + element = MessageElement(location=location, name="Message") + kit_kat = OptionElement("kit", OptionElement.Kind.STRING, "kat") + foo_bar = OptionElement("foo", OptionElement.Kind.STRING, "bar") + file = ProtoFileElement(location=location, options=[kit_kat, foo_bar], types=[element]) + assert len(file.options) == 2 + + +def test_simple_with_extends_to_schema(): + file = ProtoFileElement( + location=location, + extend_declarations=[ExtendElement(location=location.at(5, 1), name="Extend")], + types=[MessageElement(location=location, name="Message")] + ) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |message Message {} + | + |extend Extend {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_add_multiple_extends(): + extend1 = ExtendElement(location=location, name="Extend1") + extend2 = ExtendElement(location=location, name="Extend2") + file = ProtoFileElement(location=location, extend_declarations=[extend1, extend2]) + assert len(file.extend_declarations) == 2 + + +def test_multiple_everything_to_schema(): + element1 = MessageElement(location=location.at(12, 1), name="Message1") + element2 = MessageElement(location=location.at(14, 1), name="Message2") + extend1 = ExtendElement(location=location.at(16, 1), name="Extend1") + extend2 = ExtendElement(location=location.at(18, 1), name="Extend2") + option1 = OptionElement("kit", OptionElement.Kind.STRING, "kat") + option2 = OptionElement("foo", OptionElement.Kind.STRING, "bar") + service1 = ServiceElement(location=location.at(20, 1), name="Service1") + service2 = ServiceElement(location=location.at(22, 1), name="Service2") + file = ProtoFileElement( + location=location, + package_name="example.simple", + imports=["example.thing"], + public_imports=["example.other"], + types=[element1, element2], + services=[service1, service2], + extend_declarations=[extend1, extend2], + options=[option1, option2] + ) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |package example.simple; + | + |import "example.thing"; + |import public "example.other"; + | + |option kit = "kat"; + |option foo = "bar"; + | + |message Message1 {} + | + |message Message2 {} + | + |extend Extend1 {} + | + |extend Extend2 {} + | + |service Service1 {} + | + |service Service2 {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + # Re-parse the expected string into a ProtoFile and ensure they're equal. + parsed = ProtoParser.parse(location, expected) + assert parsed == file + + +def test_syntax_to_schema(): + element = MessageElement(location=location, name="Message") + file = ProtoFileElement(location=location, syntax=Syntax.PROTO_2, types=[element]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |syntax = "proto2"; + | + |message Message {} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + +def test_default_is_set_in_proto2(): + field = FieldElement( + location=location.at(12, 3), + label=Field.Label.REQUIRED, + element_type="string", + name="name", + tag=1, + default_value="defaultValue" + ) + message = MessageElement(location=location.at(11, 1), name="Message", fields=[field]) + file = ProtoFileElement( + syntax=Syntax.PROTO_2, + location=location, + package_name="example.simple", + imports=["example.thing"], + public_imports=["example.other"], + types=[message] + ) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |syntax = "proto2"; + | + |package example.simple; + | + |import "example.thing"; + |import public "example.other"; + | + |message Message { + | required string name = 1 [default = "defaultValue"]; + |} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + # Re-parse the expected string into a ProtoFile and ensure they're equal. + parsed = ProtoParser.parse(location, expected) + assert parsed == file + + +def test_convert_packed_option_from_wire_schema_in_proto2(): + field_numeric = FieldElement( + location=location.at(9, 3), + label=Field.Label.REPEATED, + element_type="int32", + name="numeric_without_packed_option", + tag=1 + ) + field_numeric_packed_true = FieldElement( + location=location.at(11, 3), + label=Field.Label.REPEATED, + element_type="int32", + name="numeric_packed_true", + tag=2, + options=[PACKED_OPTION_ELEMENT] + ) + el = copy.copy(PACKED_OPTION_ELEMENT) + el.value = "false" + field_numeric_packed_false = FieldElement( + location=location.at(13, 3), + label=Field.Label.REPEATED, + element_type="int32", + name="numeric_packed_false", + tag=3, + options=[el] + ) + field_string = FieldElement( + location=location.at(15, 3), + label=Field.Label.REPEATED, + element_type="string", + name="string_without_packed_option", + tag=4 + ) + field_string_packed_true = FieldElement( + location=location.at(17, 3), + label=Field.Label.REPEATED, + element_type="string", + name="string_packed_true", + tag=5, + options=[PACKED_OPTION_ELEMENT] + ) + field_string_packed_false = FieldElement( + location=location.at(19, 3), + label=Field.Label.REPEATED, + element_type="string", + name="string_packed_false", + tag=6, + options=[el] + ) + + message = MessageElement( + location=location.at(8, 1), + name="Message", + fields=[ + field_numeric, field_numeric_packed_true, field_numeric_packed_false, field_string, field_string_packed_true, + field_string_packed_false + ] + ) + file = ProtoFileElement(syntax=Syntax.PROTO_2, location=location, package_name="example.simple", types=[message]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |syntax = "proto2"; + | + |package example.simple; + | + |message Message { + | repeated int32 numeric_without_packed_option = 1; + | + | repeated int32 numeric_packed_true = 2 [packed = true]; + | + | repeated int32 numeric_packed_false = 3 [packed = false]; + | + | repeated string string_without_packed_option = 4; + | + | repeated string string_packed_true = 5 [packed = true]; + | + | repeated string string_packed_false = 6 [packed = false]; + |} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + # Re-parse the expected string into a ProtoFile and ensure they're equal. + parsed = ProtoParser.parse(location, expected) + assert parsed == file + + +def test_convert_packed_option_from_wire_schema_in_proto3(): + field_numeric = FieldElement( + location=location.at(9, 3), + label=Field.Label.REPEATED, + element_type="int32", + name="numeric_without_packed_option", + tag=1 + ) + field_numeric_packed_true = FieldElement( + location=location.at(11, 3), + label=Field.Label.REPEATED, + element_type="int32", + name="numeric_packed_true", + tag=2, + options=[PACKED_OPTION_ELEMENT] + ) + el = copy.copy(PACKED_OPTION_ELEMENT) + el.value = "false" + field_numeric_packed_false = FieldElement( + location=location.at(13, 3), + label=Field.Label.REPEATED, + element_type="int32", + name="numeric_packed_false", + tag=3, + options=[el] + ) + field_string = FieldElement( + location=location.at(15, 3), + label=Field.Label.REPEATED, + element_type="string", + name="string_without_packed_option", + tag=4 + ) + field_string_packed_true = FieldElement( + location=location.at(17, 3), + label=Field.Label.REPEATED, + element_type="string", + name="string_packed_true", + tag=5, + options=[PACKED_OPTION_ELEMENT] + ) + + field_string_packed_false = FieldElement( + location=location.at(19, 3), + label=Field.Label.REPEATED, + element_type="string", + name="string_packed_false", + tag=6, + options=[el] + ) + + message = MessageElement( + location=location.at(8, 1), + name="Message", + fields=[ + field_numeric, field_numeric_packed_true, field_numeric_packed_false, field_string, field_string_packed_true, + field_string_packed_false + ] + ) + file = ProtoFileElement(syntax=Syntax.PROTO_3, location=location, package_name="example.simple", types=[message]) + expected = """ + |// Proto schema formatted by Wire, do not edit. + |// Source: file.proto + | + |syntax = "proto3"; + | + |package example.simple; + | + |message Message { + | repeated int32 numeric_without_packed_option = 1; + | + | repeated int32 numeric_packed_true = 2 [packed = true]; + | + | repeated int32 numeric_packed_false = 3 [packed = false]; + | + | repeated string string_without_packed_option = 4; + | + | repeated string string_packed_true = 5 [packed = true]; + | + | repeated string string_packed_false = 6 [packed = false]; + |} + |""" + expected = trim_margin(expected) + assert file.to_schema() == expected + + # Re-parse the expected string into a ProtoFile and ensure they're equal. + parsed = ProtoParser.parse(location, expected) + assert parsed == file diff --git a/tests/unit/test_proto_parser.py b/tests/unit/test_proto_parser.py new file mode 100644 index 000000000..6774ef91c --- /dev/null +++ b/tests/unit/test_proto_parser.py @@ -0,0 +1,2610 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/ProtoParserTest.kt + +from karapace.protobuf.enum_constant_element import EnumConstantElement +from karapace.protobuf.enum_element import EnumElement +from karapace.protobuf.exception import IllegalStateException +from karapace.protobuf.extend_element import ExtendElement +from karapace.protobuf.extensions_element import ExtensionsElement +from karapace.protobuf.field import Field +from karapace.protobuf.field_element import FieldElement +from karapace.protobuf.group_element import GroupElement +from karapace.protobuf.kotlin_wrapper import KotlinRange, trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.message_element import MessageElement +from karapace.protobuf.one_of_element import OneOfElement +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.proto_file_element import ProtoFileElement +from karapace.protobuf.proto_parser import ProtoParser +from karapace.protobuf.reserved_element import ReservedElement +from karapace.protobuf.rpc_element import RpcElement +from karapace.protobuf.service_element import ServiceElement +from karapace.protobuf.syntax import Syntax +from karapace.protobuf.utils import MAX_TAG_VALUE + +import pytest + +location: Location = Location.get("file.proto") + + +def test_type_parsing(): + proto = """ + |message Types { + | required any f1 = 1; + | required bool f2 = 2; + | required bytes f3 = 3; + | required double f4 = 4; + | required float f5 = 5; + | required fixed32 f6 = 6; + | required fixed64 f7 = 7; + | required int32 f8 = 8; + | required int64 f9 = 9; + | required sfixed32 f10 = 10; + | required sfixed64 f11 = 11; + | required sint32 f12 = 12; + | required sint64 f13 = 13; + | required string f14 = 14; + | required uint32 f15 = 15; + | required uint64 f16 = 16; + | map f17 = 17; + | map f18 = 18; + | required arbitrary f19 = 19; + | required nested.nested f20 = 20; + |} + """ + proto = trim_margin(proto) + + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="Types", + fields=[ + FieldElement( + location=location.at(2, 3), label=Field.Label.REQUIRED, element_type="any", name="f1", tag=1 + ), + FieldElement( + location=location.at(3, 3), label=Field.Label.REQUIRED, element_type="bool", name="f2", tag=2 + ), + FieldElement( + location=location.at(4, 3), label=Field.Label.REQUIRED, element_type="bytes", name="f3", tag=3 + ), + FieldElement( + location=location.at(5, 3), label=Field.Label.REQUIRED, element_type="double", name="f4", tag=4 + ), + FieldElement( + location=location.at(6, 3), label=Field.Label.REQUIRED, element_type="float", name="f5", tag=5 + ), + FieldElement( + location=location.at(7, 3), label=Field.Label.REQUIRED, element_type="fixed32", name="f6", tag=6 + ), + FieldElement( + location=location.at(8, 3), label=Field.Label.REQUIRED, element_type="fixed64", name="f7", tag=7 + ), + FieldElement( + location=location.at(9, 3), label=Field.Label.REQUIRED, element_type="int32", name="f8", tag=8 + ), + FieldElement( + location=location.at(10, 3), label=Field.Label.REQUIRED, element_type="int64", name="f9", tag=9 + ), + FieldElement( + location=location.at(11, 3), label=Field.Label.REQUIRED, element_type="sfixed32", name="f10", tag=10 + ), + FieldElement( + location=location.at(12, 3), label=Field.Label.REQUIRED, element_type="sfixed64", name="f11", tag=11 + ), + FieldElement( + location=location.at(13, 3), label=Field.Label.REQUIRED, element_type="sint32", name="f12", tag=12 + ), + FieldElement( + location=location.at(14, 3), label=Field.Label.REQUIRED, element_type="sint64", name="f13", tag=13 + ), + FieldElement( + location=location.at(15, 3), label=Field.Label.REQUIRED, element_type="string", name="f14", tag=14 + ), + FieldElement( + location=location.at(16, 3), label=Field.Label.REQUIRED, element_type="uint32", name="f15", tag=15 + ), + FieldElement( + location=location.at(17, 3), label=Field.Label.REQUIRED, element_type="uint64", name="f16", tag=16 + ), + FieldElement(location=location.at(18, 3), element_type="map", name="f17", tag=17), + FieldElement( + location=location.at(19, 3), element_type="map", name="f18", tag=18 + ), + FieldElement( + location=location.at(20, 3), + label=Field.Label.REQUIRED, + element_type="arbitrary", + name="f19", + tag=19 + ), + FieldElement( + location=location.at(21, 3), + label=Field.Label.REQUIRED, + element_type="nested.nested", + name="f20", + tag=20 + ) + ] + ) + ] + ) + my = ProtoParser.parse(location, proto) + assert my == expected + + +def test_map_with_label_throws(): + with pytest.raises(IllegalStateException, match="Syntax error in file.proto:1:15: 'map' type cannot have label"): + ProtoParser.parse(location, "message Hey { required map a = 1; }") + pytest.fail("") + + with pytest.raises(IllegalStateException, match="Syntax error in file.proto:1:15: 'map' type cannot have label"): + ProtoParser.parse(location, "message Hey { optional map a = 1; }") + pytest.fail("") + + with pytest.raises(IllegalStateException, match="Syntax error in file.proto:1:15: 'map' type cannot have label"): + ProtoParser.parse(location, "message Hey { repeated map a = 1; }") + pytest.fail("") + + +def test_default_field_option_is_special(): + """ It looks like an option, but 'default' is special. It's not defined as an option. + """ + proto = """ + |message Message { + | required string a = 1 [default = "b", faulted = "c"]; + |} + |""" + + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="Message", + fields=[ + FieldElement( + location=location.at(2, 3), + label=Field.Label.REQUIRED, + element_type="string", + name="a", + default_value="b", + options=[OptionElement("faulted", OptionElement.Kind.STRING, "c")], + tag=1 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_json_name_option_is_special(): + """ It looks like an option, but 'json_name' is special. It's not defined as an option. + """ + proto = """ + |message Message { + | required string a = 1 [json_name = "b", faulted = "c"]; + |} + |""" + proto = trim_margin(proto) + + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="Message", + fields=[ + FieldElement( + location=location.at(2, 3), + label=Field.Label.REQUIRED, + element_type="string", + name="a", + json_name="b", + tag=1, + options=[OptionElement("faulted", OptionElement.Kind.STRING, "c")] + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_single_line_comment(): + proto = """ + |// Test all the things! + |message Test {} + """ + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + element_type = parsed.types[0] + assert element_type.documentation == "Test all the things!" + + +def test_multiple_single_line_comments(): + proto = """ + |// Test all + |// the things! + |message Test {} + """ + proto = trim_margin(proto) + expected = """ + |Test all + |the things! + """ + expected = trim_margin(expected) + + parsed = ProtoParser.parse(location, proto) + element_type = parsed.types[0] + assert element_type.documentation == expected + + +def test_single_line_javadoc_comment(): + proto = """ + |/** Test */ + |message Test {} + |""" + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + element_type = parsed.types[0] + assert element_type.documentation == "Test" + + +def test_multiline_javadoc_comment(): + proto = """ + |/** + | * Test + | * + | * Foo + | */ + |message Test {} + |""" + proto = trim_margin(proto) + expected = """ + |Test + | + |Foo + """ + expected = trim_margin(expected) + parsed = ProtoParser.parse(location, proto) + element_type = parsed.types[0] + assert element_type.documentation == expected + + +def test_multiple_single_line_comments_with_leading_whitespace(): + proto = """ + |// Test + |// All + |// The + |// Things! + |message Test {} + """ + proto = trim_margin(proto) + expected = """ + |Test + | All + | The + | Things! + """ + expected = trim_margin(expected) + parsed = ProtoParser.parse(location, proto) + element_type = parsed.types[0] + assert element_type.documentation == expected + + +def test_multiline_javadoc_comment_with_leading_whitespace(): + proto = """ + |/** + | * Test + | * All + | * The + | * Things! + | */ + |message Test {} + """ + proto = trim_margin(proto) + expected = """ + |Test + | All + | The + | Things! + """ + expected = trim_margin(expected) + parsed = ProtoParser.parse(location, proto) + element_type = parsed.types[0] + assert element_type.documentation == expected + + +def test_multiline_javadoc_comment_without_leading_asterisks(): + # We do not honor leading whitespace when the comment lacks leading asterisks. + proto = """ + |/** + | Test + | All + | The + | Things! + | */ + |message Test {} + """ + proto = trim_margin(proto) + expected = """ + |Test + |All + |The + |Things! + """ + expected = trim_margin(expected) + parsed = ProtoParser.parse(location, proto) + element_type = parsed.types[0] + assert element_type.documentation == expected + + +def test_message_field_trailing_comment(): + # Trailing message field comment. + proto = """ + |message Test { + | optional string name = 1; // Test all the things! + |} + """ + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + message: MessageElement = parsed.types[0] + field = message.fields[0] + assert field.documentation == "Test all the things!" + + +def test_message_field_leading_and_trailing_comment_are_combined(): + proto = """ + |message Test { + | // Test all... + | optional string name = 1; // ...the things! + |} + """ + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + message: MessageElement = parsed.types[0] + field = message.fields[0] + assert field.documentation == "Test all...\n...the things!" + + +def test_trailing_comment_not_assigned_to_following_field(): + proto = """ + |message Test { + | optional string first_name = 1; // Testing! + | optional string last_name = 2; + |} + """ + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + message: MessageElement = parsed.types[0] + field1 = message.fields[0] + assert field1.documentation == "Testing!" + field2 = message.fields[1] + assert field2.documentation == "" + + +def test_enum_value_trailing_comment(): + proto = """ + |enum Test { + | FOO = 1; // Test all the things! + |} + """ + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + enum_element: EnumElement = parsed.types[0] + value = enum_element.constants[0] + assert value.documentation == "Test all the things!" + + +def test_trailing_singleline_comment(): + proto = """ + |enum Test { + | FOO = 1; /* Test all the things! */ + | BAR = 2;/*Test all the things!*/ + |} + """ + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + enum_element: EnumElement = parsed.types[0] + c_foo = enum_element.constants[0] + assert c_foo.documentation == "Test all the things!" + c_bar = enum_element.constants[1] + assert c_bar.documentation == "Test all the things!" + + +def test_trailing_multiline_comment(): + proto = """ + |enum Test { + | FOO = 1; /* Test all the + |things! */ + |} + """ + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + enum_element: EnumElement = parsed.types[0] + value = enum_element.constants[0] + assert value.documentation == "Test all the\nthings!" + + +def test_trailing_multiline_comment_must_be_last_on_line_throws(): + proto = """ + |enum Test { + | FOO = 1; /* Test all the things! */ BAR = 2; + |} + """ + proto = trim_margin(proto) + with pytest.raises( + IllegalStateException, match="Syntax error in file.proto:2:40: no syntax may follow trailing comment" + ): + ProtoParser.parse(location, proto) + pytest.fail("") + + +def test_invalid_trailing_comment(): + proto = """ + |enum Test { + | FOO = 1; / + |} + """ + proto = trim_margin(proto) + # try : + # ProtoParser.parse(location, proto) + # except IllegalStateException as e : + # if e.message != "Syntax error in file.proto:2:12: expected '//' or '/*'" : + # pytest.fail("") + + with pytest.raises(IllegalStateException) as re: + # TODO: this test in Kotlin source contains "2:13:" Need compile square.wire and check how it can be? + + ProtoParser.parse(location, proto) + pytest.fail("") + assert re.value.message == "Syntax error in file.proto:2:12: expected '//' or '/*'" + + +def test_enum_value_leading_and_trailing_comments_are_combined(): + proto = """ + |enum Test { + | // Test all... + | FOO = 1; // ...the things! + |} + """ + proto = trim_margin(proto) + parsed = ProtoParser.parse(location, proto) + enum_element: EnumElement = parsed.types[0] + value = enum_element.constants[0] + assert value.documentation == "Test all...\n...the things!" + + +def test_trailing_comment_not_combined_when_empty(): + """ (Kotlin) Can't use raw strings here; otherwise, the formatter removes the trailing whitespace on line 3. """ + proto = "enum Test {\n" \ + " // Test all...\n" \ + " FOO = 1; // \n" \ + "}" + parsed = ProtoParser.parse(location, proto) + enum_element: EnumElement = parsed.types[0] + value = enum_element.constants[0] + assert value.documentation == "Test all..." + + +def test_syntax_not_required(): + proto = "message Foo {}" + parsed = ProtoParser.parse(location, proto) + assert parsed.syntax is None + + +def test_syntax_specified(): + proto = """ + |syntax = "proto3"; + |message Foo {} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, syntax=Syntax.PROTO_3, types=[MessageElement(location=location.at(2, 1), name="Foo")] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_invalid_syntax_value_throws(): + proto = """ + |syntax = "proto4"; + |message Foo {} + """ + proto = trim_margin(proto) + with pytest.raises(IllegalStateException, match="Syntax error in file.proto:1:1: unexpected syntax: proto4"): + ProtoParser.parse(location, proto) + pytest.fail("") + + +def test_syntax_not_first_declaration_throws(): + proto = """ + |message Foo {} + |syntax = "proto3"; + """ + proto = trim_margin(proto) + with pytest.raises( + IllegalStateException, + match="Syntax error in file.proto:2:1: 'syntax' element must be the first declaration " + "in a file" + ): + ProtoParser.parse(location, proto) + pytest.fail("") + + +def test_syntax_may_follow_comments_and_empty_lines(): + proto = """ + |/* comment 1 */ + |// comment 2 + | + |syntax = "proto3"; + |message Foo {} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, syntax=Syntax.PROTO_3, types=[MessageElement(location=location.at(5, 1), name="Foo")] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto3_message_fields_do_not_require_labels(): + proto = """ + |syntax = "proto3"; + |message Message { + | string a = 1; + | int32 b = 2; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + syntax=Syntax.PROTO_3, + types=[ + MessageElement( + location=location.at(2, 1), + name="Message", + fields=[ + FieldElement(location=location.at(3, 3), element_type="string", name="a", tag=1), + FieldElement(location=location.at(4, 3), element_type="int32", name="b", tag=2) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto3_extension_fields_do_not_require_labels(): + proto = """ + |syntax = "proto3"; + |message Message { + |} + |extend Message { + | string a = 1; + | int32 b = 2; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + syntax=Syntax.PROTO_3, + types=[MessageElement(location=location.at(2, 1), name="Message")], + extend_declarations=[ + ExtendElement( + location=location.at(4, 1), + name="Message", + fields=[ + FieldElement(location=location.at(5, 3), element_type="string", name="a", tag=1), + FieldElement(location=location.at(6, 3), element_type="int32", name="b", tag=2) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto3_message_fields_allow_optional(): + proto = """ + |syntax = "proto3"; + |message Message { + | optional string a = 1; + |} + """ + proto = trim_margin(proto) + + expected = ProtoFileElement( + location=location, + syntax=Syntax.PROTO_3, + types=[ + MessageElement( + location=location.at(2, 1), + name="Message", + fields=[ + FieldElement( + location=location.at(3, 3), element_type="string", name="a", tag=1, label=Field.Label.OPTIONAL + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto3_message_fields_forbid_required(): + proto = """ + |syntax = "proto3"; + |message Message { + | required string a = 1; + |} + """ + proto = trim_margin(proto) + with pytest.raises( + IllegalStateException, + match="Syntax error in file.proto:3:3: 'required' label forbidden in proto3 field " + "declarations" + ): + ProtoParser.parse(location, proto) + pytest.fail("") + + +def test_proto3_extension_fields_allow_optional(): + proto = """ + |syntax = "proto3"; + |message Message { + |} + |extend Message { + | optional string a = 1; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + syntax=Syntax.PROTO_3, + types=[MessageElement(location=location.at(2, 1), name="Message")], + extend_declarations=[ + ExtendElement( + location=location.at(4, 1), + name="Message", + fields=[ + FieldElement( + location=location.at(5, 3), element_type="string", name="a", tag=1, label=Field.Label.OPTIONAL + ) + ], + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto3_extension_fields_forbids_required(): + proto = """ + |syntax = "proto3"; + |message Message { + |} + |extend Message { + | required string a = 1; + |} + """ + proto = trim_margin(proto) + with pytest.raises( + IllegalStateException, + match="Syntax error in file.proto:5:3: 'required' label forbidden in proto3 field " + "declarations" + ): + ProtoParser.parse(location, proto) + pytest.fail("") + + +def test_proto3_message_fields_permit_repeated(): + proto = """ + |syntax = "proto3"; + |message Message { + | repeated string a = 1; + |} + """ + proto = trim_margin(proto) + + expected = ProtoFileElement( + location=location, + syntax=Syntax.PROTO_3, + types=[ + MessageElement( + location=location.at(2, 1), + name="Message", + fields=[ + FieldElement( + location=location.at(3, 3), label=Field.Label.REPEATED, element_type="string", name="a", tag=1 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto3_extension_fields_permit_repeated(): + proto = """ + |syntax = "proto3"; + |message Message { + |} + |extend Message { + | repeated string a = 1; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + syntax=Syntax.PROTO_3, + types=[MessageElement(location=location.at(2, 1), name="Message")], + extend_declarations=[ + ExtendElement( + location=location.at(4, 1), + name="Message", + fields=[ + FieldElement( + location=location.at(5, 3), label=Field.Label.REPEATED, element_type="string", name="a", tag=1 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_parse_message_and_fields(): + proto = """ + |message SearchRequest { + | required string query = 1; + | optional int32 page_number = 2; + | optional int32 result_per_page = 3; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="SearchRequest", + fields=[ + FieldElement( + location=location.at(2, 3), label=Field.Label.REQUIRED, element_type="string", name="query", tag=1 + ), + FieldElement( + location=location.at(3, 3), + label=Field.Label.OPTIONAL, + element_type="int32", + name="page_number", + tag=2 + ), + FieldElement( + location=location.at(4, 3), + label=Field.Label.OPTIONAL, + element_type="int32", + name="result_per_page", + tag=3 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_group(): + proto = """ + |message SearchResponse { + | repeated group Result = 1 { + | required string url = 2; + | optional string title = 3; + | repeated string snippets = 4; + | } + |} + """ + proto = trim_margin(proto) + message = MessageElement( + location=location.at(1, 1), + name="SearchResponse", + groups=[ + GroupElement( + location=location.at(2, 3), + label=Field.Label.REPEATED, + name="Result", + tag=1, + fields=[ + FieldElement( + location=location.at(3, 5), label=Field.Label.REQUIRED, element_type="string", name="url", tag=2 + ), + FieldElement( + location=location.at(4, 5), label=Field.Label.OPTIONAL, element_type="string", name="title", tag=3 + ), + FieldElement( + location=location.at(5, 5), + label=Field.Label.REPEATED, + element_type="string", + name="snippets", + tag=4 + ) + ] + ) + ] + ) + expected = ProtoFileElement(location=location, types=[message]) + assert ProtoParser.parse(location, proto) == expected + + +def test_parse_message_and_one_of(): + proto = """ + |message SearchRequest { + | required string query = 1; + | oneof page_info { + | int32 page_number = 2; + | int32 result_per_page = 3; + | } + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="SearchRequest", + fields=[ + FieldElement( + location=location.at(2, 3), label=Field.Label.REQUIRED, element_type="string", name="query", tag=1 + ) + ], + one_ofs=[ + OneOfElement( + name="page_info", + fields=[ + FieldElement(location=location.at(4, 5), element_type="int32", name="page_number", tag=2), + FieldElement(location=location.at(5, 5), element_type="int32", name="result_per_page", tag=3) + ], + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_parse_message_and_one_of_with_group(): + proto = """ + |message SearchRequest { + | required string query = 1; + | oneof page_info { + | int32 page_number = 2; + | group Stuff = 3 { + | optional int32 result_per_page = 4; + | optional int32 page_count = 5; + | } + | } + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="SearchRequest", + fields=[ + FieldElement( + location=location.at(2, 3), label=Field.Label.REQUIRED, element_type="string", name="query", tag=1 + ) + ], + one_ofs=[ + OneOfElement( + name="page_info", + fields=[FieldElement(location=location.at(4, 5), element_type="int32", name="page_number", tag=2)], + groups=[ + GroupElement( + label=None, + location=location.at(5, 5), + name="Stuff", + tag=3, + fields=[ + FieldElement( + location=location.at(6, 7), + label=Field.Label.OPTIONAL, + element_type="int32", + name="result_per_page", + tag=4 + ), + FieldElement( + location=location.at(7, 7), + label=Field.Label.OPTIONAL, + element_type="int32", + name="page_count", + tag=5 + ) + ] + ) + ], + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_parse_enum(): + proto = """ + |/** + | * What's on my waffles. + | * Also works on pancakes. + | */ + |enum Topping { + | FRUIT = 1; + | /** Yummy, yummy cream. */ + | CREAM = 2; + | + | // Quebec Maple syrup + | SYRUP = 3; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + EnumElement( + location=location.at(5, 1), + name="Topping", + documentation="What's on my waffles.\nAlso works on pancakes.", + constants=[ + EnumConstantElement(location=location.at(6, 3), name="FRUIT", tag=1), + EnumConstantElement( + location=location.at(8, 3), + name="CREAM", + tag=2, + documentation="Yummy, yummy cream.", + ), + EnumConstantElement( + location=location.at(11, 3), + name="SYRUP", + tag=3, + documentation="Quebec Maple syrup", + ) + ], + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_parse_enum_with_options(): + proto = """ + |/** + | * What's on my waffles. + | * Also works on pancakes. + | */ + |enum Topping { + | option(max_choices) = 2; + | + | FRUIT = 1[(healthy) = true]; + | /** Yummy, yummy cream. */ + | CREAM = 2; + | + | // Quebec Maple syrup + | SYRUP = 3; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + EnumElement( + location=location.at(5, 1), + name="Topping", + documentation="What's on my waffles.\nAlso works on pancakes.", + options=[OptionElement("max_choices", OptionElement.Kind.NUMBER, "2", True)], + constants=[ + EnumConstantElement( + location=location.at(8, 3), + name="FRUIT", + tag=1, + options=[OptionElement("healthy", OptionElement.Kind.BOOLEAN, "true", True)] + ), + EnumConstantElement( + location=location.at(10, 3), + name="CREAM", + tag=2, + documentation="Yummy, yummy cream.", + ), + EnumConstantElement( + location=location.at(13, 3), + name="SYRUP", + tag=3, + documentation="Quebec Maple syrup", + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_package_declaration(): + proto = """ + |package google.protobuf; + |option java_package = "com.google.protobuf"; + | + |// The protocol compiler can output a FileDescriptorSet containing the .proto + |// files it parses. + |message FileDescriptorSet { + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + package_name="google.protobuf", + types=[ + MessageElement( + location=location.at(6, 1), + name="FileDescriptorSet", + documentation="The protocol compiler can output a FileDescriptorSet containing the .proto\nfiles " + "it parses." + ) + ], + options=[OptionElement("java_package", OptionElement.Kind.STRING, "com.google.protobuf")] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_nesting_in_message(): + proto = """ + |message FieldOptions { + | optional CType ctype = 1[old_default = STRING, deprecated = true]; + | enum CType { + | STRING = 0[(opt_a) = 1, (opt_b) = 2]; + | }; + | // Clients can define custom options in extensions of this message. See above. + | extensions 500; + | extensions 1000 to max; + |} + """ + proto = trim_margin(proto) + enum_element = EnumElement( + location=location.at(3, 3), + name="CType", + constants=[ + EnumConstantElement( + location=location.at(4, 5), + name="STRING", + tag=0, + options=[ + OptionElement("opt_a", OptionElement.Kind.NUMBER, "1", True), + OptionElement("opt_b", OptionElement.Kind.NUMBER, "2", True) + ] + ) + ], + ) + field = FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="CType", + name="ctype", + tag=1, + options=[ + OptionElement("old_default", OptionElement.Kind.ENUM, "STRING"), + OptionElement("deprecated", OptionElement.Kind.BOOLEAN, "true") + ] + ) + + assert len(field.options) == 2 + assert OptionElement("old_default", OptionElement.Kind.ENUM, "STRING") in field.options + assert OptionElement("deprecated", OptionElement.Kind.BOOLEAN, "true") in field.options + + message_element = MessageElement( + location=location.at(1, 1), + name="FieldOptions", + fields=[field], + nested_types=[enum_element], + extensions=[ + ExtensionsElement( + location=location.at(7, 3), + documentation="Clients can define custom options in extensions of this message. See above.", + values=[500] + ), + ExtensionsElement(location.at(8, 3), "", [KotlinRange(1000, MAX_TAG_VALUE)]) + ] + ) + expected = ProtoFileElement(location=location, types=[message_element]) + actual = ProtoParser.parse(location, proto) + assert actual == expected + + +def test_multi_ranges_extensions(): + proto = """ + |message MeGustaExtensions { + | extensions 1, 5 to 200, 500, 1000 to max; + |} + """ + proto = trim_margin(proto) + message_element = MessageElement( + location=location.at(1, 1), + name="MeGustaExtensions", + extensions=[ + ExtensionsElement( + location=location.at(2, 3), values=[1] + [KotlinRange(5, 200)] + [500] + [KotlinRange(1000, MAX_TAG_VALUE)] + ) + ] + ) + expected = ProtoFileElement(location=location, types=[message_element]) + actual = ProtoParser.parse(location, proto) + assert actual == expected + + +def test_option_parentheses(): + proto = """ + |message Chickens { + | optional bool koka_ko_koka_ko = 1[old_default = true]; + | optional bool coodle_doodle_do = 2[(delay) = 100, old_default = false]; + | optional bool coo_coo_ca_cha = 3[old_default = true, (delay) = 200]; + | optional bool cha_chee_cha = 4; + |} + """ + proto = trim_margin(proto) + + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="Chickens", + fields=[ + FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="bool", + name="koka_ko_koka_ko", + tag=1, + options=[OptionElement("old_default", OptionElement.Kind.BOOLEAN, "true")] + ), + FieldElement( + location=location.at(3, 3), + label=Field.Label.OPTIONAL, + element_type="bool", + name="coodle_doodle_do", + tag=2, + options=[ + OptionElement("delay", OptionElement.Kind.NUMBER, "100", True), + OptionElement("old_default", OptionElement.Kind.BOOLEAN, "false") + ] + ), + FieldElement( + location=location.at(4, 3), + label=Field.Label.OPTIONAL, + element_type="bool", + name="coo_coo_ca_cha", + tag=3, + options=[ + OptionElement("old_default", OptionElement.Kind.BOOLEAN, "true"), + OptionElement("delay", OptionElement.Kind.NUMBER, "200", True) + ] + ), + FieldElement( + location=location.at(5, 3), + label=Field.Label.OPTIONAL, + element_type="bool", + name="cha_chee_cha", + tag=4 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_imports(): + proto = "import \"src/test/resources/unittest_import.proto\";\n" + expected = ProtoFileElement(location=location, imports=["src/test/resources/unittest_import.proto"]) + assert ProtoParser.parse(location, proto) == expected + + +def test_public_imports(): + proto = "import public \"src/test/resources/unittest_import.proto\";\n" + expected = ProtoFileElement(location=location, public_imports=["src/test/resources/unittest_import.proto"]) + assert ProtoParser.parse(location, proto) == expected + + +def test_extend(): + proto = """ + |// Extends Foo + |extend Foo { + | optional int32 bar = 126; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + extend_declarations=[ + ExtendElement( + location=location.at(2, 1), + name="Foo", + documentation="Extends Foo", + fields=[ + FieldElement( + location=location.at(3, 3), label=Field.Label.OPTIONAL, element_type="int32", name="bar", tag=126 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_extend_in_message(): + proto = """ + |message Bar { + | extend Foo { + | optional Bar bar = 126; + | } + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[MessageElement(location=location.at(1, 1), name="Bar")], + extend_declarations=[ + ExtendElement( + location=location.at(2, 3), + name="Foo", + fields=[ + FieldElement( + location=location.at(3, 5), label=Field.Label.OPTIONAL, element_type="Bar", name="bar", tag=126 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_extend_in_message_with_package(): + proto = """ + |package kit.kat; + | + |message Bar { + | extend Foo { + | optional Bar bar = 126; + | } + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + package_name="kit.kat", + types=[MessageElement(location=location.at(3, 1), name="Bar")], + extend_declarations=[ + ExtendElement( + location=location.at(4, 3), + name="Foo", + fields=[ + FieldElement( + location=location.at(5, 5), label=Field.Label.OPTIONAL, element_type="Bar", name="bar", tag=126 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_fqcn_extend_in_message(): + proto = """ + |message Bar { + | extend example.Foo { + | optional Bar bar = 126; + | } + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[MessageElement(location=location.at(1, 1), name="Bar")], + extend_declarations=[ + ExtendElement( + location=location.at(2, 3), + name="example.Foo", + fields=[ + FieldElement( + location=location.at(3, 5), label=Field.Label.OPTIONAL, element_type="Bar", name="bar", tag=126 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_fqcn_extend_in_message_with_package(): + proto = """ + |package kit.kat; + | + |message Bar { + | extend example.Foo { + | optional Bar bar = 126; + | } + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + package_name="kit.kat", + types=[MessageElement(location=location.at(3, 1), name="Bar")], + extend_declarations=[ + ExtendElement( + location=location.at(4, 3), + name="example.Foo", + fields=[ + FieldElement( + location=location.at(5, 5), label=Field.Label.OPTIONAL, element_type="Bar", name="bar", tag=126 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_default_field_with_paren(): + proto = """ + |message Foo { + | optional string claim_token = 2[(squareup.redacted) = true]; + |} + """ + proto = trim_margin(proto) + field = FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="string", + name="claim_token", + tag=2, + options=[OptionElement("squareup.redacted", OptionElement.Kind.BOOLEAN, "true", True)] + ) + assert len(field.options) == 1 + assert OptionElement("squareup.redacted", OptionElement.Kind.BOOLEAN, "true", True) in field.options + + message_element = MessageElement(location=location.at(1, 1), name="Foo", fields=[field]) + expected = ProtoFileElement(location=location, types=[message_element]) + assert ProtoParser.parse(location, proto) == expected + + +# Parse \a, \b, \f, \n, \r, \t, \v, \[0-7]{1-3}, and \[xX]{0-9a-fA-F]{1,2} +def test_default_field_with_string_escapes(): + proto = r""" + |message Foo { + | optional string name = 1 [ + | x = "\a\b\f\n\r\t\v\1f\01\001\11\011\111\xe\Xe\xE\xE\x41\x41" + | ]; + |} + """ + proto = trim_margin(proto) + field = FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="string", + name="name", + tag=1, + options=[ + OptionElement( + "x", OptionElement.Kind.STRING, + "\u0007\b\u000C\n\r\t\u000b\u0001f\u0001\u0001\u0009\u0009I\u000e\u000e\u000e\u000eAA" + ) + ] + ) + assert len(field.options) == 1 + assert OptionElement( + "x", OptionElement.Kind.STRING, + "\u0007\b\u000C\n\r\t\u000b\u0001f\u0001\u0001\u0009\u0009I\u000e\u000e\u000e\u000eAA" + ) in field.options + + message_element = MessageElement(location=location.at(1, 1), name="Foo", fields=[field]) + expected = ProtoFileElement(location=location, types=[message_element]) + assert ProtoParser.parse(location, proto) == expected + + +def test_string_with_single_quotes(): + proto = r""" + |message Foo { + | optional string name = 1[default = 'single\"quotes']; + |} + """ + proto = trim_margin(proto) + + field = FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="string", + name="name", + tag=1, + default_value="single\"quotes" + ) + message_element = MessageElement(location=location.at(1, 1), name="Foo", fields=[field]) + expected = ProtoFileElement(location=location, types=[message_element]) + assert ProtoParser.parse(location, proto) == expected + + +def test_adjacent_strings_concatenated(): + proto = """ + |message Foo { + | optional string name = 1 [ + | default = "concat " + | 'these ' + | "please" + | ]; + |} + """ + proto = trim_margin(proto) + + field = FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="string", + name="name", + tag=1, + default_value="concat these please" + ) + message_element = MessageElement(location=location.at(1, 1), name="Foo", fields=[field]) + expected = ProtoFileElement(location=location, types=[message_element]) + assert ProtoParser.parse(location, proto) == expected + + +def test_invalid_hex_string_escape(): + proto = r""" + |message Foo { + | optional string name = 1 [default = "\xW"]; + |} + """ + proto = trim_margin(proto) + with pytest.raises(IllegalStateException) as re: + ProtoParser.parse(location, proto) + pytest.fail("") + assert "expected a digit after \\x or \\X" in re.value.message + + +def test_service(): + proto = """ + |service SearchService { + | option (default_timeout) = 30; + | + | rpc Search (SearchRequest) returns (SearchResponse); + | rpc Purchase (PurchaseRequest) returns (PurchaseResponse) { + | option (squareup.sake.timeout) = 15; + | option (squareup.a.b) = { + | value: [ + | FOO, + | BAR + | ] + | }; + | } + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + services=[ + ServiceElement( + location=location.at(1, 1), + name="SearchService", + options=[OptionElement("default_timeout", OptionElement.Kind.NUMBER, "30", True)], + rpcs=[ + RpcElement( + location=location.at(4, 3), + name="Search", + request_type="SearchRequest", + response_type="SearchResponse", + response_streaming=False, + request_streaming=False + ), + RpcElement( + location=location.at(5, 3), + name="Purchase", + request_type="PurchaseRequest", + response_type="PurchaseResponse", + options=[ + OptionElement("squareup.sake.timeout", OptionElement.Kind.NUMBER, "15", True), + OptionElement("squareup.a.b", OptionElement.Kind.MAP, {"value": ["FOO", "BAR"]}, True) + ], + request_streaming=False, + response_streaming=False + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_streaming_service(): + proto = """ + |service RouteGuide { + | rpc GetFeature (Point) returns (Feature) {} + | rpc ListFeatures (Rectangle) returns (stream Feature) {} + | rpc RecordRoute (stream Point) returns (RouteSummary) {} + | rpc RouteChat (stream RouteNote) returns (stream RouteNote) {} + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + services=[ + ServiceElement( + location=location.at(1, 1), + name="RouteGuide", + rpcs=[ + RpcElement( + location=location.at(2, 3), + name="GetFeature", + request_type="Point", + response_type="Feature", + response_streaming=False, + request_streaming=False + ), + RpcElement( + location=location.at(3, 3), + name="ListFeatures", + request_type="Rectangle", + response_type="Feature", + response_streaming=True, + # TODO: Report Square.Wire there was mistake True instead of False! + request_streaming=False, + ), + RpcElement( + location=location.at(4, 3), + name="RecordRoute", + request_type="Point", + response_type="RouteSummary", + request_streaming=True, + response_streaming=False, + ), + RpcElement( + location=location.at(5, 3), + name="RouteChat", + request_type="RouteNote", + response_type="RouteNote", + request_streaming=True, + response_streaming=True, + ) + ], + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_hex_tag(): + proto = """ + |message HexTag { + | required string hex = 0x10; + | required string uppercase_x_hex = 0X11; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="HexTag", + fields=[ + FieldElement( + location=location.at(2, 3), label=Field.Label.REQUIRED, element_type="string", name="hex", tag=16 + ), + FieldElement( + location=location.at(3, 3), + label=Field.Label.REQUIRED, + element_type="string", + name="uppercase_x_hex", + tag=17 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_structured_option(): + proto = """ + |message ExoticOptions { + | option (squareup.one) = {name: "Name", class_name:"ClassName"}; + | option (squareup.two.a) = {[squareup.options.type]: EXOTIC}; + | option (squareup.two.b) = {names: ["Foo", "Bar"]}; + |} + """ + # TODO: we do not support it yet + # + # | option (squareup.three) = {x: {y: 1 y: 2 } }; // NOTE: Omitted optional comma + # | option (squareup.four) = {x: {y: {z: 1 }, y: {z: 2 }}}; + # + # + # + proto = trim_margin(proto) + + option_one_map = {"name": "Name", "class_name": "ClassName"} + + option_two_a_map = {"[squareup.options.type]": "EXOTIC"} + + option_two_b_map = {"names": ["Foo", "Bar"]} + + # TODO: we do not support it yet + # need create custom dictionary class to support multiple values for one key + # + # option_three_map = {"x": {"y": 1, "y": 2}} + # option_four_map = {"x": ["y": {"z": 1}, "y": {"z": 2}]} + + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="ExoticOptions", + options=[ + OptionElement("squareup.one", OptionElement.Kind.MAP, option_one_map, True), + OptionElement("squareup.two.a", OptionElement.Kind.MAP, option_two_a_map, True), + OptionElement("squareup.two.b", OptionElement.Kind.MAP, option_two_b_map, True), + # OptionElement("squareup.three", OptionElement.Kind.MAP, option_three_map, True), + # OptionElement("squareup.four", OptionElement.Kind.MAP, option_four_map, True) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_options_with_nested_maps_and_trailing_commas(): + proto = """ + |message StructuredOption { + | optional field.type has_options = 3 [ + | (option_map) = { + | nested_map: {key:"value", key2:["value2a","value2b"]}, + | }, + | (option_string) = ["string1","string2"] + | ]; + |} + """ + proto = trim_margin(proto) + field = FieldElement( + location=location.at(2, 5), + label=Field.Label.OPTIONAL, + element_type="field.type", + name="has_options", + tag=3, + options=[ + OptionElement( + "option_map", OptionElement.Kind.MAP, {"nested_map": { + "key": "value", + "key2": ["value2a", "value2b"] + }}, True + ), + OptionElement("option_string", OptionElement.Kind.LIST, ["string1", "string2"], True) + ] + ) + assert len(field.options) == 2 + assert OptionElement( + "option_map", OptionElement.Kind.MAP, {"nested_map": { + "key": "value", + "key2": ["value2a", "value2b"] + }}, True + ) in field.options + assert OptionElement("option_string", OptionElement.Kind.LIST, ["string1", "string2"], True) in field.options + + expected = MessageElement(location=location.at(1, 1), name="StructuredOption", fields=[field]) + proto_file = ProtoFileElement(location=location, types=[expected]) + assert ProtoParser.parse(location, proto) == proto_file + + +def test_option_numerical_bounds(): + proto = r""" + |message Test { + | optional int32 default_int32 = 401 [x = 2147483647]; + | optional uint32 default_uint32 = 402 [x = 4294967295]; + | optional sint32 default_sint32 = 403 [x = -2147483648]; + | optional fixed32 default_fixed32 = 404 [x = 4294967295]; + | optional sfixed32 default_sfixed32 = 405 [x = -2147483648]; + | optional int64 default_int64 = 406 [x = 9223372036854775807]; + | optional uint64 default_uint64 = 407 [x = 18446744073709551615]; + | optional sint64 default_sint64 = 408 [x = -9223372036854775808]; + | optional fixed64 default_fixed64 = 409 [x = 18446744073709551615]; + | optional sfixed64 default_sfixed64 = 410 [x = -9223372036854775808]; + | optional bool default_bool = 411 [x = true]; + | optional float default_float = 412 [x = 123.456e7]; + | optional double default_double = 413 [x = 123.456e78]; + | optional string default_string = 414 """ + \ + r"""[x = "çok\a\b\f\n\r\t\v\1\01\001\17\017\176\x1\x01\x11\X1\X01\X11güzel" ]; + | optional bytes default_bytes = 415 """ + \ + r"""[x = "çok\a\b\f\n\r\t\v\1\01\001\17\017\176\x1\x01\x11\X1\X01\X11güzel" ]; + | optional NestedEnum default_nested_enum = 416 [x = A ]; + |}""" + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="Test", + fields=[ + FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="int32", + name="default_int32", + tag=401, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "2147483647")] + ), + FieldElement( + location=location.at(3, 3), + label=Field.Label.OPTIONAL, + element_type="uint32", + name="default_uint32", + tag=402, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "4294967295")] + ), + FieldElement( + location=location.at(4, 3), + label=Field.Label.OPTIONAL, + element_type="sint32", + name="default_sint32", + tag=403, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "-2147483648")] + ), + FieldElement( + location=location.at(5, 3), + label=Field.Label.OPTIONAL, + element_type="fixed32", + name="default_fixed32", + tag=404, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "4294967295")] + ), + FieldElement( + location=location.at(6, 3), + label=Field.Label.OPTIONAL, + element_type="sfixed32", + name="default_sfixed32", + tag=405, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "-2147483648")] + ), + FieldElement( + location=location.at(7, 3), + label=Field.Label.OPTIONAL, + element_type="int64", + name="default_int64", + tag=406, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "9223372036854775807")] + ), + FieldElement( + location=location.at(8, 3), + label=Field.Label.OPTIONAL, + element_type="uint64", + name="default_uint64", + tag=407, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "18446744073709551615")] + ), + FieldElement( + location=location.at(9, 3), + label=Field.Label.OPTIONAL, + element_type="sint64", + name="default_sint64", + tag=408, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "-9223372036854775808")] + ), + FieldElement( + location=location.at(10, 3), + label=Field.Label.OPTIONAL, + element_type="fixed64", + name="default_fixed64", + tag=409, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "18446744073709551615")] + ), + FieldElement( + location=location.at(11, 3), + label=Field.Label.OPTIONAL, + element_type="sfixed64", + name="default_sfixed64", + tag=410, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "-9223372036854775808")] + ), + FieldElement( + location=location.at(12, 3), + label=Field.Label.OPTIONAL, + element_type="bool", + name="default_bool", + tag=411, + options=[OptionElement("x", OptionElement.Kind.BOOLEAN, "true")] + ), + FieldElement( + location=location.at(13, 3), + label=Field.Label.OPTIONAL, + element_type="float", + name="default_float", + tag=412, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "123.456e7")] + ), + FieldElement( + location=location.at(14, 3), + label=Field.Label.OPTIONAL, + element_type="double", + name="default_double", + tag=413, + options=[OptionElement("x", OptionElement.Kind.NUMBER, "123.456e78")] + ), + FieldElement( + location=location.at(15, 3), + label=Field.Label.OPTIONAL, + element_type="string", + name="default_string", + tag=414, + options=[ + OptionElement( + "x", OptionElement.Kind.STRING, + "çok\u0007\b\u000C\n\r\t\u000b\u0001\u0001\u0001\u000f\u000f~\u0001\u0001\u0011" + "\u0001\u0001\u0011güzel" + ) + ] + ), + FieldElement( + location=location.at(17, 3), + label=Field.Label.OPTIONAL, + element_type="bytes", + name="default_bytes", + tag=415, + options=[ + OptionElement( + "x", OptionElement.Kind.STRING, + "çok\u0007\b\u000C\n\r\t\u000b\u0001\u0001\u0001\u000f\u000f~\u0001\u0001\u0011" + "\u0001\u0001\u0011güzel" + ) + ] + ), + FieldElement( + location=location.at(19, 3), + label=Field.Label.OPTIONAL, + element_type="NestedEnum", + name="default_nested_enum", + tag=416, + options=[OptionElement("x", OptionElement.Kind.ENUM, "A")] + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_extension_with_nested_message(): + proto = """ + |message Foo { + | optional int32 bar = 1[ + | (validation.range).min = 1, + | (validation.range).max = 100, + | old_default = 20 + | ]; + |} + """ + proto = trim_margin(proto) + field = FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="int32", + name="bar", + tag=1, + options=[ + OptionElement( + "validation.range", OptionElement.Kind.OPTION, OptionElement("min", OptionElement.Kind.NUMBER, "1"), True + ), + OptionElement( + "validation.range", OptionElement.Kind.OPTION, OptionElement("max", OptionElement.Kind.NUMBER, "100"), True + ), + OptionElement("old_default", OptionElement.Kind.NUMBER, "20") + ] + ) + assert len(field.options) == 3 + assert OptionElement( + "validation.range", OptionElement.Kind.OPTION, OptionElement("min", OptionElement.Kind.NUMBER, "1"), True + ) in field.options + + assert OptionElement( + "validation.range", OptionElement.Kind.OPTION, OptionElement("max", OptionElement.Kind.NUMBER, "100"), True + ) in field.options + + assert OptionElement("old_default", OptionElement.Kind.NUMBER, "20") in field.options + + expected = MessageElement(location=location.at(1, 1), name="Foo", fields=[field]) + proto_file = ProtoFileElement(location=location, types=[expected]) + assert ProtoParser.parse(location, proto) == proto_file + + +def test_reserved(): + proto = """ + |message Foo { + | reserved 10, 12 to 14, 'foo'; + |} + """ + proto = trim_margin(proto) + message = MessageElement( + location=location.at(1, 1), + name="Foo", + reserveds=[ReservedElement(location=location.at(2, 3), values=[10, KotlinRange(12, 14), "foo"])] + ) + expected = ProtoFileElement(location=location, types=[message]) + assert ProtoParser.parse(location, proto) == expected + + +def test_reserved_with_comments(): + proto = """ + |message Foo { + | optional string a = 1; // This is A. + | reserved 2; // This is reserved. + | optional string c = 3; // This is C. + |} + """ + proto = trim_margin(proto) + message = MessageElement( + location=location.at(1, 1), + name="Foo", + fields=[ + FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="string", + name="a", + tag=1, + documentation="This is A." + ), + FieldElement( + location=location.at(4, 3), + label=Field.Label.OPTIONAL, + element_type="string", + name="c", + tag=3, + documentation="This is C." + ) + ], + reserveds=[ReservedElement(location=location.at(3, 3), values=[2], documentation="This is reserved.")] + ) + expected = ProtoFileElement(location=location, types=[message]) + assert ProtoParser.parse(location, proto) == expected + + +def test_no_whitespace(): + proto = "message C {optional A.B ab = 1;}" + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="C", + fields=[ + FieldElement( + location=location.at(1, 12), label=Field.Label.OPTIONAL, element_type="A.B", name="ab", tag=1 + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_deep_option_assignments(): + proto = """ + |message Foo { + | optional string a = 1 [(wire.my_field_option).baz.value = "a"]; + |} + |""" + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="Foo", + fields=[ + FieldElement( + location=location.at(2, 3), + label=Field.Label.OPTIONAL, + element_type="string", + name="a", + tag=1, + options=[ + OptionElement( + name="wire.my_field_option", + kind=OptionElement.Kind.OPTION, + is_parenthesized=True, + value=OptionElement( + name="baz", + kind=OptionElement.Kind.OPTION, + is_parenthesized=False, + value=OptionElement( + name="value", kind=OptionElement.Kind.STRING, is_parenthesized=False, value="a" + ) + ) + ) + ] + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto_keyword_as_enum_constants(): + # Note: this is consistent with protoc. + proto = """ + |enum Foo { + | syntax = 0; + | import = 1; + | package = 2; + | // option = 3; + | // reserved = 4; + | message = 5; + | enum = 6; + | service = 7; + | extend = 8; + | rpc = 9; + | oneof = 10; + | extensions = 11; + |} + |""" + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + EnumElement( + location=location.at(1, 1), + name="Foo", + constants=[ + EnumConstantElement(location.at(2, 3), "syntax", 0), + EnumConstantElement(location.at(3, 3), "import", 1), + EnumConstantElement(location.at(4, 3), "package", 2), + EnumConstantElement(location.at(7, 3), "message", 5, documentation="option = 3;\nreserved = 4;"), + EnumConstantElement(location.at(8, 3), "enum", 6), + EnumConstantElement(location.at(9, 3), "service", 7), + EnumConstantElement(location.at(10, 3), "extend", 8), + EnumConstantElement(location.at(11, 3), "rpc", 9), + EnumConstantElement(location.at(12, 3), "oneof", 10), + EnumConstantElement(location.at(13, 3), "extensions", 11), + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto_keyword_as_message_name_and_field_proto2(): + # Note: this is consistent with protoc. + proto = """ + |message syntax { + | optional syntax syntax = 1; + |} + |message import { + | optional import import = 1; + |} + |message package { + | optional package package = 1; + |} + |message option { + | optional option option = 1; + |} + |message reserved { + | optional reserved reserved = 1; + |} + |message message { + | optional message message = 1; + |} + |message enum { + | optional enum enum = 1; + |} + |message service { + | optional service service = 1; + |} + |message extend { + | optional extend extend = 1; + |} + |message rpc { + | optional rpc rpc = 1; + |} + |message oneof { + | optional oneof oneof = 1; + |} + |message extensions { + | optional extensions extensions = 1; + |} + |""" + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="syntax", + fields=[ + FieldElement(location.at(2, 3), label=Field.Label.OPTIONAL, element_type="syntax", name="syntax", tag=1) + ] + ), + MessageElement( + location=location.at(4, 1), + name="import", + fields=[ + FieldElement(location.at(5, 3), label=Field.Label.OPTIONAL, element_type="import", name="import", tag=1) + ] + ), + MessageElement( + location=location.at(7, 1), + name="package", + fields=[ + FieldElement( + location.at(8, 3), label=Field.Label.OPTIONAL, element_type="package", name="package", tag=1 + ) + ] + ), + MessageElement( + location=location.at(10, 1), + name="option", + fields=[ + FieldElement( + location.at(11, 3), label=Field.Label.OPTIONAL, element_type="option", name="option", tag=1 + ) + ] + ), + MessageElement( + location=location.at(13, 1), + name="reserved", + fields=[ + FieldElement( + location.at(14, 3), label=Field.Label.OPTIONAL, element_type="reserved", name="reserved", tag=1 + ) + ] + ), + MessageElement( + location=location.at(16, 1), + name="message", + fields=[ + FieldElement( + location.at(17, 3), label=Field.Label.OPTIONAL, element_type="message", name="message", tag=1 + ) + ] + ), + MessageElement( + location=location.at(19, 1), + name="enum", + fields=[ + FieldElement(location.at(20, 3), label=Field.Label.OPTIONAL, element_type="enum", name="enum", tag=1) + ] + ), + MessageElement( + location=location.at(22, 1), + name="service", + fields=[ + FieldElement( + location.at(23, 3), label=Field.Label.OPTIONAL, element_type="service", name="service", tag=1 + ) + ] + ), + MessageElement( + location=location.at(25, 1), + name="extend", + fields=[ + FieldElement( + location.at(26, 3), label=Field.Label.OPTIONAL, element_type="extend", name="extend", tag=1 + ) + ] + ), + MessageElement( + location=location.at(28, 1), + name="rpc", + fields=[FieldElement(location.at(29, 3), label=Field.Label.OPTIONAL, element_type="rpc", name="rpc", tag=1)] + ), + MessageElement( + location=location.at(31, 1), + name="oneof", + fields=[ + FieldElement(location.at(32, 3), label=Field.Label.OPTIONAL, element_type="oneof", name="oneof", tag=1) + ] + ), + MessageElement( + location=location.at(34, 1), + name="extensions", + fields=[ + FieldElement( + location.at(35, 3), label=Field.Label.OPTIONAL, element_type="extensions", name="extensions", tag=1 + ) + ] + ), + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto_keyword_as_message_name_and_field_proto3(): + # Note: this is consistent with protoc. + proto = """ + |syntax = "proto3"; + |message syntax { + | syntax syntax = 1; + |} + |message import { + | import import = 1; + |} + |message package { + | package package = 1; + |} + |message option { + | option option = 1; + |} + |message reserved { + | // reserved reserved = 1; + |} + |message message { + | // message message = 1; + |} + |message enum { + | // enum enum = 1; + |} + |message service { + | service service = 1; + |} + |message extend { + | // extend extend = 1; + |} + |message rpc { + | rpc rpc = 1; + |} + |message oneof { + | // oneof oneof = 1; + |} + |message extensions { + | // extensions extensions = 1; + |} + |""" + + proto = trim_margin(proto) + expected = ProtoFileElement( + syntax=Syntax.PROTO_3, + location=location, + types=[ + MessageElement( + location=location.at(2, 1), + name="syntax", + fields=[FieldElement(location.at(3, 3), element_type="syntax", name="syntax", tag=1)] + ), + MessageElement( + location=location.at(5, 1), + name="import", + fields=[FieldElement(location.at(6, 3), element_type="import", name="import", tag=1)] + ), + MessageElement( + location=location.at(8, 1), + name="package", + fields=[FieldElement(location.at(9, 3), element_type="package", name="package", tag=1)] + ), + MessageElement( + location=location.at(11, 1), + name="option", + options=[OptionElement(name="option", kind=OptionElement.Kind.NUMBER, value="1", is_parenthesized=False)], + ), + MessageElement( + location=location.at(14, 1), + name="reserved", + ), + MessageElement( + location=location.at(17, 1), + name="message", + ), + MessageElement( + location=location.at(20, 1), + name="enum", + ), + MessageElement( + location=location.at(23, 1), + name="service", + fields=[FieldElement(location.at(24, 3), element_type="service", name="service", tag=1)] + ), + MessageElement( + location=location.at(26, 1), + name="extend", + ), + MessageElement( + location=location.at(29, 1), + name="rpc", + fields=[FieldElement(location.at(30, 3), element_type="rpc", name="rpc", tag=1)] + ), + MessageElement( + location=location.at(32, 1), + name="oneof", + ), + MessageElement( + location=location.at(35, 1), + name="extensions", + ), + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_proto_keyword_as_service_name_and_rpc(): + # Note: this is consistent with protoc. + proto = """ + |service syntax { + | rpc syntax (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service import { + | rpc import (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service package { + | rpc package (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service option { + | rpc option (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service reserved { + | rpc reserved (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service message { + | rpc message (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service enum { + | rpc enum (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service service { + | rpc service (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service extend { + | rpc extend (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service rpc { + | rpc rpc (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service oneof { + | rpc oneof (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |service extensions { + | rpc extensions (google.protobuf.StringValue) returns (google.protobuf.StringValue); + |} + |""" + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + services=[ + ServiceElement( + location=location.at(1, 1), + name="syntax", + rpcs=[ + RpcElement( + location.at(2, 3), + name="syntax", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(4, 1), + name="import", + rpcs=[ + RpcElement( + location.at(5, 3), + name="import", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(7, 1), + name="package", + rpcs=[ + RpcElement( + location.at(8, 3), + name="package", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(10, 1), + name="option", + rpcs=[ + RpcElement( + location.at(11, 3), + name="option", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(13, 1), + name="reserved", + rpcs=[ + RpcElement( + location.at(14, 3), + name="reserved", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(16, 1), + name="message", + rpcs=[ + RpcElement( + location.at(17, 3), + name="message", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(19, 1), + name="enum", + rpcs=[ + RpcElement( + location.at(20, 3), + name="enum", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(22, 1), + name="service", + rpcs=[ + RpcElement( + location.at(23, 3), + name="service", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(25, 1), + name="extend", + rpcs=[ + RpcElement( + location.at(26, 3), + name="extend", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(28, 1), + name="rpc", + rpcs=[ + RpcElement( + location.at(29, 3), + name="rpc", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue", + ) + ] + ), + ServiceElement( + location=location.at(31, 1), + name="oneof", + rpcs=[ + RpcElement( + location.at(32, 3), + name="oneof", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue" + ) + ] + ), + ServiceElement( + location=location.at(34, 1), + name="extensions", + rpcs=[ + RpcElement( + location.at(35, 3), + name="extensions", + request_type="google.protobuf.StringValue", + response_type="google.protobuf.StringValue" + ) + ] + ), + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_forbid_multiple_syntax_definitions(): + proto = """ + | syntax = "proto2"; + | syntax = "proto2"; + """ + proto = trim_margin(proto) + with pytest.raises(IllegalStateException, match="Syntax error in file.proto:2:3: too many syntax definitions"): + # TODO: this test in Kotlin source contains "2:13:" Need compile square.wire and check how it can be? + ProtoParser.parse(location, proto) + pytest.fail("") + + +def test_one_of_options(): + proto = """ + |message SearchRequest { + | required string query = 1; + | oneof page_info { + | option (my_option) = true; + | int32 page_number = 2; + | int32 result_per_page = 3; + | } + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + types=[ + MessageElement( + location=location.at(1, 1), + name="SearchRequest", + fields=[ + FieldElement( + location=location.at(2, 3), label=Field.Label.REQUIRED, element_type="string", name="query", tag=1 + ) + ], + one_ofs=[ + OneOfElement( + name="page_info", + fields=[ + FieldElement(location=location.at(5, 5), element_type="int32", name="page_number", tag=2), + FieldElement(location=location.at(6, 5), element_type="int32", name="result_per_page", tag=3) + ], + options=[ + OptionElement("my_option", OptionElement.Kind.BOOLEAN, value="true", is_parenthesized=True) + ] + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected + + +def test_semi_colon_as_options_delimiters(): + proto = """ + |service MyService { + | option (custom_rule) = { + | my_string: "abc"; my_int: 3; + | my_list: ["a", "b", "c"]; + | }; + |} + """ + proto = trim_margin(proto) + expected = ProtoFileElement( + location=location, + services=[ + ServiceElement( + location=location.at(1, 1), + name="MyService", + options=[ + OptionElement( + "custom_rule", + OptionElement.Kind.MAP, { + "my_string": "abc", + "my_int": "3", + "my_list": ["a", "b", "c"] + }, + is_parenthesized=True + ) + ] + ) + ] + ) + assert ProtoParser.parse(location, proto) == expected diff --git a/tests/unit/test_protobuf_schema.py b/tests/unit/test_protobuf_schema.py new file mode 100644 index 000000000..bece9b241 --- /dev/null +++ b/tests/unit/test_protobuf_schema.py @@ -0,0 +1,288 @@ +from karapace.protobuf.compare_result import CompareResult +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.schema import ProtobufSchema +from karapace.schema_reader import SchemaType, TypedSchema +from tests.schemas.protobuf import ( + schema_protobuf_compare_one, schema_protobuf_order_after, schema_protobuf_order_before, schema_protobuf_schema_registry1 +) + +location: Location = Location.get("file.proto") + + +def test_protobuf_schema_simple(): + proto = trim_margin(schema_protobuf_schema_registry1) + protobuf_schema = TypedSchema.parse(SchemaType.PROTOBUF, proto) + result = str(protobuf_schema) + + assert result == proto + + +def test_protobuf_schema_sort(): + proto = trim_margin(schema_protobuf_order_before) + protobuf_schema = TypedSchema.parse(SchemaType.PROTOBUF, proto) + result = str(protobuf_schema) + proto2 = trim_margin(schema_protobuf_order_after) + assert result == proto2 + + +def test_protobuf_schema_compare(): + proto1 = trim_margin(schema_protobuf_order_after) + protobuf_schema1: TypedSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1) + proto2 = trim_margin(schema_protobuf_compare_one) + protobuf_schema2: TypedSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2) + result = CompareResult() + protobuf_schema1.schema.compare(protobuf_schema2.schema, result) + assert result.is_compatible() + + +def test_protobuf_schema_compare2(): + proto1 = trim_margin(schema_protobuf_order_after) + protobuf_schema1: TypedSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1) + proto2 = trim_margin(schema_protobuf_compare_one) + protobuf_schema2: TypedSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2) + result = CompareResult() + protobuf_schema2.schema.compare(protobuf_schema1.schema, result) + assert result.is_compatible() + + +def test_protobuf_schema_compare3(): + proto1 = """ + |syntax = "proto3"; + |package a1; + |message TestMessage { + | message Value { + | string str2 = 1; + | int32 x = 2; + | } + | string test = 1; + | .a1.TestMessage.Value val = 2; + |} + |""" + + proto1 = trim_margin(proto1) + + proto2 = """ + |syntax = "proto3"; + |package a1; + | + |message TestMessage { + | string test = 1; + | .a1.TestMessage.Value val = 2; + | + | message Value { + | string str2 = 1; + | Enu x = 2; + | } + | enum Enu { + | A = 0; + | B = 1; + | } + |} + |""" + + proto2 = trim_margin(proto2) + protobuf_schema1: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1).schema + protobuf_schema2: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2).schema + result = CompareResult() + + protobuf_schema1.compare(protobuf_schema2, result) + + assert result.is_compatible() + + +def test_protobuf_message_compatible_label_alter(): + proto1 = """ + |syntax = "proto3"; + |message Goods { + | optional Packet record = 1; + | string driver = 2; + | message Packet { + | bytes order = 1; + | } + |} + |""" + proto1 = trim_margin(proto1) + + proto2 = """ + |syntax = "proto3"; + |message Goods { + | repeated Packet record = 1; + | string driver = 2; + | message Packet { + | bytes order = 1; + | } + |} + |""" + + proto2 = trim_margin(proto2) + + protobuf_schema1: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1).schema + protobuf_schema2: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2).schema + result = CompareResult() + + protobuf_schema1.compare(protobuf_schema2, result) + + assert result.is_compatible() + + +def test_protobuf_field_type_incompatible_alter(): + proto1 = """ + |syntax = "proto3"; + |message Goods { + | string order = 1; + | map items_int32 = 2; + |} + |""" + proto1 = trim_margin(proto1) + + proto2 = """ + |syntax = "proto3"; + |message Goods { + | string order = 1; + | map items_string = 2; + |} + |""" + + proto2 = trim_margin(proto2) + + protobuf_schema1: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1).schema + protobuf_schema2: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2).schema + result = CompareResult() + + protobuf_schema1.compare(protobuf_schema2, result) + + assert not result.is_compatible() + + +def test_protobuf_field_label_compatible_alter(): + proto1 = """ + |syntax = "proto3"; + |message Goods { + | optional string driver = 1; + | Order order = 2; + | message Order { + | string item = 1; + | } + |} + |""" + + proto1 = trim_margin(proto1) + proto2 = """ + |syntax = "proto3"; + |message Goods { + | repeated string driver = 1; + | Order order = 2; + | message Order { + | string item = 1; + | } + |} + |""" + + proto2 = trim_margin(proto2) + + protobuf_schema1: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1).schema + protobuf_schema2: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2).schema + result = CompareResult() + + protobuf_schema1.compare(protobuf_schema2, result) + + assert result.is_compatible() + + +def test_protobuf_field_incompatible_drop_from_oneof(): + proto1 = """ + |syntax = "proto3"; + |message Goods { + | oneof item { + | string name_a = 1; + | string name_b = 2; + | int32 id = 3; + | } + |} + |""" + + proto1 = trim_margin(proto1) + proto2 = """ + |syntax = "proto3"; + |message Goods { + | oneof item { + | string name_a = 1; + | string name_b = 2; + | } + |} + |""" + + proto2 = trim_margin(proto2) + + protobuf_schema1: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1).schema + protobuf_schema2: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2).schema + result = CompareResult() + + protobuf_schema1.compare(protobuf_schema2, result) + + assert not result.is_compatible() + + +def test_protobuf_field_incompatible_alter_to_oneof(): + proto1 = """ + |syntax = "proto3"; + |message Goods { + | string name = 1; + | string reg_name = 2; + |} + |""" + + proto1 = trim_margin(proto1) + proto2 = """ + |syntax = "proto3"; + |message Goods { + | oneof reg_data { + | string name = 1; + | string reg_name = 2; + | int32 id = 3; + | } + |} + |""" + + proto2 = trim_margin(proto2) + + protobuf_schema1: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1).schema + protobuf_schema2: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2).schema + result = CompareResult() + + protobuf_schema1.compare(protobuf_schema2, result) + + assert not result.is_compatible() + + +def test_protobuf_field_compatible_alter_to_oneof(): + proto1 = """ + |syntax = "proto3"; + |message Goods { + | string name = 1; + | string foo = 2; + |} + |""" + + proto1 = trim_margin(proto1) + proto2 = """ + |syntax = "proto3"; + |message Goods { + | string name = 1; + | oneof new_oneof { + | string foo = 2; + | int32 bar = 3; + | } + |} + |""" + + proto2 = trim_margin(proto2) + + protobuf_schema1: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto1).schema + protobuf_schema2: ProtobufSchema = TypedSchema.parse(SchemaType.PROTOBUF, proto2).schema + result = CompareResult() + + protobuf_schema1.compare(protobuf_schema2, result) + + assert result.is_compatible() diff --git a/tests/unit/test_protobuf_serialization.py b/tests/unit/test_protobuf_serialization.py new file mode 100644 index 000000000..b8183878f --- /dev/null +++ b/tests/unit/test_protobuf_serialization.py @@ -0,0 +1,72 @@ +from karapace.config import read_config +from karapace.serialization import ( + InvalidMessageHeader, InvalidMessageSchema, InvalidPayload, SchemaRegistryDeserializer, SchemaRegistrySerializer, + START_BYTE +) +from tests.utils import test_fail_objects_protobuf, test_objects_protobuf + +import logging +import pytest +import struct + +log = logging.getLogger(__name__) + + +async def make_ser_deser(config_path, mock_client): + with open(config_path) as handler: + config = read_config(handler) + serializer = SchemaRegistrySerializer(config_path=config_path, config=config) + deserializer = SchemaRegistryDeserializer(config_path=config_path, config=config) + await serializer.registry_client.close() + await deserializer.registry_client.close() + serializer.registry_client = mock_client + deserializer.registry_client = mock_client + return serializer, deserializer + + +async def test_happy_flow(default_config_path, mock_protobuf_registry_client): + serializer, deserializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) + for o in serializer, deserializer: + assert len(o.ids_to_schemas) == 0 + schema = await serializer.get_schema_for_subject("top") + for o in test_objects_protobuf: + a = await serializer.serialize(schema, o) + u = await deserializer.deserialize(a) + assert o == u + for o in serializer, deserializer: + assert len(o.ids_to_schemas) == 1 + assert 1 in o.ids_to_schemas + + +async def test_serialization_fails(default_config_path, mock_protobuf_registry_client): + serializer, _ = await make_ser_deser(default_config_path, mock_protobuf_registry_client) + with pytest.raises(InvalidMessageSchema): + schema = await serializer.get_schema_for_subject("top") + await serializer.serialize(schema, test_fail_objects_protobuf[0]) + + with pytest.raises(InvalidMessageSchema): + schema = await serializer.get_schema_for_subject("top") + await serializer.serialize(schema, test_fail_objects_protobuf[1]) + + +async def test_deserialization_fails(default_config_path, mock_protobuf_registry_client): + _, deserializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) + invalid_header_payload = struct.pack(">bII", 1, 500, 500) + with pytest.raises(InvalidMessageHeader): + await deserializer.deserialize(invalid_header_payload) + + # wrong schema id (500) + invalid_data_payload = struct.pack(">bII", START_BYTE, 500, 500) + with pytest.raises(InvalidPayload): + await deserializer.deserialize(invalid_data_payload) + + +async def test_deserialization_fails2(default_config_path, mock_protobuf_registry_client): + _, deserializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) + invalid_header_payload = struct.pack(">bII", 1, 500, 500) + with pytest.raises(InvalidMessageHeader): + await deserializer.deserialize(invalid_header_payload) + + enc_bytes = b'\x00\x00\x00\x00\x01\x00\x02\x05\0x12' # wrong schema data (2) + with pytest.raises(InvalidPayload): + await deserializer.deserialize(enc_bytes) diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index 146a3c745..3ce4a19c0 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -9,9 +9,12 @@ import copy import io import json +import logging import pytest import struct +log = logging.getLogger(__name__) + async def make_ser_deser(config_path, mock_client): with open(config_path) as handler: diff --git a/tests/unit/test_service_element.py b/tests/unit/test_service_element.py new file mode 100644 index 000000000..9f4935f15 --- /dev/null +++ b/tests/unit/test_service_element.py @@ -0,0 +1,151 @@ +# Ported from square/wire: +# wire-library/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/internal/parser/ServiceElementTest.kt + +from karapace.protobuf.kotlin_wrapper import trim_margin +from karapace.protobuf.location import Location +from karapace.protobuf.option_element import OptionElement +from karapace.protobuf.rpc_element import RpcElement +from karapace.protobuf.service_element import ServiceElement + +location: Location = Location.get("file.proto") + + +def test_empty_to_schema(): + service = ServiceElement(location=location, name="Service") + expected = "service Service {}\n" + assert service.to_schema() == expected + + +def test_single_to_schema(): + service = ServiceElement( + location=location, + name="Service", + rpcs=[RpcElement(location=location, name="Name", request_type="RequestType", response_type="ResponseType")] + ) + expected = """ + |service Service { + | rpc Name (RequestType) returns (ResponseType); + |} + |""" + expected = trim_margin(expected) + assert service.to_schema() == expected + + +def test_add_multiple_rpcs(): + first_name = RpcElement(location=location, name="FirstName", request_type="RequestType", response_type="ResponseType") + last_name = RpcElement(location=location, name="LastName", request_type="RequestType", response_type="ResponseType") + service = ServiceElement(location=location, name="Service", rpcs=[first_name, last_name]) + assert len(service.rpcs) == 2 + + +def test_single_with_options_to_schema(): + service = ServiceElement( + location=location, + name="Service", + options=[OptionElement("foo", OptionElement.Kind.STRING, "bar")], + rpcs=[RpcElement(location=location, name="Name", request_type="RequestType", response_type="ResponseType")] + ) + expected = """ + |service Service { + | option foo = "bar"; + | + | rpc Name (RequestType) returns (ResponseType); + |} + |""" + expected = trim_margin(expected) + assert service.to_schema() == expected + + +def test_add_multiple_options(): + kit_kat = OptionElement("kit", OptionElement.Kind.STRING, "kat") + foo_bar = OptionElement("foo", OptionElement.Kind.STRING, "bar") + service = ServiceElement( + location=location, + name="Service", + options=[kit_kat, foo_bar], + rpcs=[RpcElement(location=location, name="Name", request_type="RequestType", response_type="ResponseType")] + ) + assert len(service.options) == 2 + + +def test_single_with_documentation_to_schema(): + service = ServiceElement( + location=location, + name="Service", + documentation="Hello", + rpcs=[RpcElement(location=location, name="Name", request_type="RequestType", response_type="ResponseType")] + ) + expected = """ + |// Hello + |service Service { + | rpc Name (RequestType) returns (ResponseType); + |} + |""" + expected = trim_margin(expected) + assert service.to_schema() == expected + + +def test_multiple_to_schema(): + rpc = RpcElement(location=location, name="Name", request_type="RequestType", response_type="ResponseType") + service = ServiceElement(location=location, name="Service", rpcs=[rpc, rpc]) + expected = """ + |service Service { + | rpc Name (RequestType) returns (ResponseType); + | rpc Name (RequestType) returns (ResponseType); + |} + |""" + expected = trim_margin(expected) + + assert service.to_schema() == expected + + +def test_rpc_to_schema(): + rpc = RpcElement(location=location, name="Name", request_type="RequestType", response_type="ResponseType") + expected = "rpc Name (RequestType) returns (ResponseType);\n" + assert rpc.to_schema() == expected + + +def test_rpc_with_documentation_to_schema(): + rpc = RpcElement( + location=location, name="Name", documentation="Hello", request_type="RequestType", response_type="ResponseType" + ) + expected = """ + |// Hello + |rpc Name (RequestType) returns (ResponseType); + |""" + expected = trim_margin(expected) + assert rpc.to_schema() == expected + + +def test_rpc_with_options_to_schema(): + rpc = RpcElement( + location=location, + name="Name", + request_type="RequestType", + response_type="ResponseType", + options=[OptionElement("foo", OptionElement.Kind.STRING, "bar")] + ) + + expected = """ + |rpc Name (RequestType) returns (ResponseType) { + | option foo = "bar"; + |}; + |""" + expected = trim_margin(expected) + assert rpc.to_schema() == expected + + +def test_rpc_with_request_streaming_to_schema(): + rpc = RpcElement( + location=location, name="Name", request_type="RequestType", response_type="ResponseType", request_streaming=True + ) + expected = "rpc Name (stream RequestType) returns (ResponseType);\n" + assert rpc.to_schema() == expected + + +def test_rpc_with_response_streaming_to_schema(): + rpc = RpcElement( + location=location, name="Name", request_type="RequestType", response_type="ResponseType", response_streaming=True + ) + expected = "rpc Name (RequestType) returns (stream ResponseType);\n" + assert rpc.to_schema() == expected diff --git a/tests/utils.py b/tests/utils.py index c504dfc66..f8577171a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,7 @@ from aiohttp.client_exceptions import ClientOSError, ServerDisconnectedError from dataclasses import dataclass from kafka.errors import TopicAlreadyExistsError +from karapace.protobuf.kotlin_wrapper import trim_margin from karapace.utils import Client from typing import Callable, List from urllib.parse import quote @@ -64,11 +65,126 @@ }, ] +# protobuf schemas in tests must be filtered by trim_margin() from kotlin_wrapper module + +schema_protobuf = """ +|syntax = "proto3"; +| +|option java_package = "com.codingharbour.protobuf"; +|option java_outer_classname = "TestEnumOrder"; +| +|message Message { +| int32 query = 1; +| Enum speed = 2; +|} +|enum Enum { +| HIGH = 0; +| MIDDLE = 1; +| LOW = 2; +|} +| +""" +schema_protobuf = trim_margin(schema_protobuf) + +schema_protobuf2 = """ +|syntax = "proto3"; +| +|option java_package = "com.codingharbour.protobuf"; +|option java_outer_classname = "TestEnumOrder"; +| +|message Message { +| int32 query = 1; +|} +|enum Enum { +| HIGH = 0; +| MIDDLE = 1; +| LOW = 2; +|} +| +""" +schema_protobuf2 = trim_margin(schema_protobuf2) + +test_objects_protobuf = [ + { + 'query': 5, + 'speed': 'HIGH' + }, + { + 'query': 10, + 'speed': 'MIDDLE' + }, +] + +test_fail_objects_protobuf = [ + { + 'query': 'STR', + 'speed': 99 + }, + { + 'xx': 10, + 'bb': 'MIDDLE' + }, +] + schema_data = { "avro": (schema_avro_json, test_objects_avro), - "jsonschema": (schema_jsonschema_json, test_objects_jsonschema) + "jsonschema": (schema_jsonschema_json, test_objects_jsonschema), + "protobuf": (schema_protobuf, test_objects_protobuf) } +schema_protobuf_second = """ +|syntax = "proto3"; +| +|option java_package = "com.codingharbour.protobuf"; +|option java_outer_classname = "TestEnumOrder"; +| +|message SensorInfo { +| int32 q = 1; +| Enu sensor_type = 2; +| repeated int32 nums = 3; +| Order order = 4; +| message Order { +| string item = 1; +| } +|} +|enum Enu { +| H1 = 0; +| M1 = 1; +| L1 = 2; +|} +| +""" +schema_protobuf_second = trim_margin(schema_protobuf_second) + +test_objects_protobuf_second = [ + { + 'q': 1, + 'sensor_type': 'H1', + 'nums': [3, 4], + 'order': { + 'item': 'ABC01223' + } + }, + { + 'q': 2, + 'sensor_type': 'M1', + 'nums': [2], + 'order': { + 'item': 'ABC01233' + } + }, + { + 'q': 3, + 'sensor_type': 'L1', + 'nums': [3, 4], + 'order': { + 'item': 'ABC01223' + } + }, +] + +schema_data_second = {"protobuf": (schema_protobuf_second, test_objects_protobuf_second)} + second_schema_json = json.dumps({ "namespace": "example.avro.other", "type": "record", @@ -98,6 +214,10 @@ "Content-Type": "application/vnd.kafka.avro.v2+json", "Accept": "application/vnd.kafka.avro.v2+json, application/vnd.kafka.v2+json, application/json, */*" }, + "protobuf": { + "Content-Type": "application/vnd.kafka.protobuf.v2+json", + "Accept": "application/vnd.kafka.protobuf.v2+json, application/vnd.kafka.v2+json, application/json, */*" + } }