diff --git a/src/karapace/schema_models.py b/src/karapace/schema_models.py index b56fcc123..7f57dc6fd 100644 --- a/src/karapace/schema_models.py +++ b/src/karapace/schema_models.py @@ -5,7 +5,8 @@ from __future__ import annotations from avro.errors import SchemaParseException -from avro.schema import parse as avro_parse, Schema as AvroSchema +from avro.name import Names as AvroNames +from avro.schema import make_avsc_object, parse as avro_parse, Schema as AvroSchema from collections.abc import Collection, Mapping, Sequence from dataclasses import dataclass from jsonschema import Draft7Validator @@ -29,8 +30,8 @@ from karapace.utils import assert_never, json_decode, json_encode, JSONDecodeError from typing import Any, cast, Final, final -import avro.schema import hashlib +import json import logging import re @@ -254,22 +255,31 @@ def parse( if schema_type is SchemaType.AVRO: try: if dependencies: - wrapped_schema_str = AvroMerge(schema_str, dependencies).wrap() + names = AvroNames(validate_names=True) + deps = list(dependencies.values()) + + merged_schema = None + for dep in deps: + # Merge dep with all previously merged ones + json_dep = dep.get_schema().to_dict() + merged_schema = make_avsc_object(json_dep, names) + + # TODO: recursively add the dependencies of this dependency, so that indirect dependencies are + # working (schema1 --> schema2 --> schema3) + + # Merge main schema with all dependencies + schema_json = json.loads(schema_str) + merged_schema = make_avsc_object(schema_json, names) + + merged_schema_str = str(merged_schema) else: - wrapped_schema_str = schema_str + merged_schema_str = schema_str parsed_schema = parse_avro_schema_definition( - wrapped_schema_str, + merged_schema_str, validate_enum_symbols=validate_avro_enum_symbols, validate_names=validate_avro_names, ) - if dependencies: - if isinstance(parsed_schema, avro.schema.UnionSchema): - parsed_schema_result = parsed_schema.schemas[-1].fields[0].type.schemas[-1] - - else: - raise InvalidSchema - else: - parsed_schema_result = parsed_schema + parsed_schema_result = parsed_schema return ParsedTypedSchema( schema_type=schema_type, schema_str=schema_str, diff --git a/tests/integration/test_schema_avro_references.py b/tests/integration/test_schema_avro_references.py index 7beea1f54..a3ccfdccf 100644 --- a/tests/integration/test_schema_avro_references.py +++ b/tests/integration/test_schema_avro_references.py @@ -49,11 +49,35 @@ "fields": [ {"name": "name", "type": "string"}, {"name": "age", "type": "int"}, - {"name": "address", "type": "Address"}, + # {"name": "address", "type": "Address"}, {"name": "job", "type": "Job"}, ], } +SCHEMA_PERSON_RECURSIVE = { + "type": "record", + "name": "PersonRecursive", + "namespace": "com.netapp", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int"}, + {"name": "job", "type": "Job"}, + {"name": "father", "type": "PersonRecursive"}, + ], +} + +SCHEMA_JOB_INDIRECT_RECURSIVE = { + "type": "record", + "name": "JobIndirectRecursive", + "namespace": "com.netapp", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "salary", "type": "double"}, + {"name": "consultant", "type": "Person"}, + ], +} + + SCHEMA_PERSON_AGE_INT_LONG = { "type": "record", "name": "Person", @@ -61,7 +85,7 @@ "fields": [ {"name": "name", "type": "string"}, {"name": "age", "type": "long"}, - {"name": "address", "type": "Address"}, + # {"name": "address", "type": "Address"}, {"name": "job", "type": "Job"}, ], } @@ -73,7 +97,7 @@ "fields": [ {"name": "name", "type": "string"}, {"name": "age", "type": "string"}, - {"name": "address", "type": "Address"}, + # {"name": "address", "type": "Address"}, {"name": "job", "type": "Job"}, ], } @@ -85,7 +109,7 @@ "fields": [ {"name": "name", "type": "string"}, {"name": "age", "type": "int"}, - {"name": "address", "type": "Address"}, + # {"name": "address", "type": "Address"}, {"name": "job", "type": "Job"}, { "name": "children", @@ -109,7 +133,7 @@ "fields": [ {"name": "name", "type": "string"}, {"name": "age", "type": "int"}, - {"name": "address", "type": "Address"}, + # {"name": "address", "type": "Address"}, {"name": "job", "type": "Job"}, ], }, @@ -120,7 +144,7 @@ "fields": [ {"name": "name", "type": "string"}, {"name": "age", "type": "int"}, - {"name": "address", "type": "Address"}, + # {"name": "address", "type": "Address"}, ], }, ] @@ -144,16 +168,22 @@ def address_references(subject_prefix: str) -> list: def person_references(subject_prefix: str) -> list: return [ - {"name": "address.avsc", "subject": f"{subject_prefix}address", "version": 1}, + # {"name": "address.avsc", "subject": f"{subject_prefix}address", "version": 1}, {"name": "job.avsc", "subject": f"{subject_prefix}job", "version": 1}, ] +def job_indirect_recursive_references(subject_prefix: str) -> list: + return [ + {"name": "person.avsc", "subject": f"{subject_prefix}person", "version": 1}, + ] + + def stored_person_subject(subject_prefix: str, subject_id: int) -> dict: return { "id": subject_id, "references": [ - {"name": "address.avsc", "subject": f"{subject_prefix}address", "version": 1}, + # {"name": "address.avsc", "subject": f"{subject_prefix}address", "version": 1}, {"name": "job.avsc", "subject": f"{subject_prefix}job", "version": 1}, ], "schema": json.dumps( @@ -161,7 +191,7 @@ def stored_person_subject(subject_prefix: str, subject_id: int) -> dict: "fields": [ {"name": "name", "type": "string"}, {"name": "age", "type": "int"}, - {"name": "address", "type": "Address"}, + # {"name": "address", "type": "Address"}, {"name": "job", "type": "Job"}, ], "name": "Person", @@ -206,12 +236,14 @@ async def basic_avro_references_fill_test(registry_async_client: Client, subject res = await registry_async_client.post(f"subjects/{subject_prefix}job/versions", json={"schema": json.dumps(SCHEMA_JOB)}) assert res.status_code == 200 assert "id" in res.json() + res = await registry_async_client.post( f"subjects/{subject_prefix}person/versions", json={"schemaType": "AVRO", "schema": json.dumps(SCHEMA_PERSON), "references": person_references(subject_prefix)}, ) assert res.status_code == 200 assert "id" in res.json() + return res @@ -293,3 +325,34 @@ async def test_avro_incompatible_name_references(registry_async_client: Client) assert res.status_code == 409 msg = "Incompatible schema, compatibility_mode=BACKWARD. Incompatibilities: expected: com.netapp.Address" assert res.json()["message"] == msg + + +async def test_recursive_reference(registry_async_client: Client) -> None: + subject_prefix = create_subject_name_factory("avro-recursive-reference")() + await basic_avro_references_fill_test(registry_async_client, subject_prefix) + res = await registry_async_client.post( + f"subjects/{subject_prefix}person-recursive/versions", + json={ + "schemaType": "AVRO", + "schema": json.dumps(SCHEMA_PERSON_RECURSIVE), + "references": person_references(subject_prefix), + }, + ) + assert res.status_code == 200 + assert "id" in res.json() + + +# This test fails because indirect references are not implemented +async def test_indirect_recursive_reference(registry_async_client: Client) -> None: + subject_prefix = create_subject_name_factory("avro-indirect-recursive-reference")() + await basic_avro_references_fill_test(registry_async_client, subject_prefix) + res = await registry_async_client.post( + f"subjects/{subject_prefix}person-indirect-recursive/versions", + json={ + "schemaType": "AVRO", + "schema": json.dumps(SCHEMA_JOB_INDIRECT_RECURSIVE), + "references": job_indirect_recursive_references(subject_prefix), + }, + ) + assert res.status_code == 200 + assert "id" in res.json() diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index a21d3bc00..cdb989992 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -4,7 +4,9 @@ """ from karapace.client import Path from karapace.config import DEFAULTS, read_config +from karapace.dependency import Dependency from karapace.schema_models import SchemaType, ValidatedTypedSchema, Versioner +from karapace.schema_references import Reference from karapace.serialization import ( flatten_unions, get_subject_name, @@ -16,7 +18,7 @@ START_BYTE, write_value, ) -from karapace.typing import NameStrategy, Subject, SubjectType +from karapace.typing import NameStrategy, Subject, SubjectType, Version from tests.utils import schema_avro_json, test_objects_avro from unittest.mock import call, Mock @@ -28,6 +30,7 @@ import logging import pytest import struct +import textwrap log = logging.getLogger(__name__) @@ -440,3 +443,89 @@ def test_name_strategy_for_protobuf(expected_subject: Subject, strategy: NameStr get_subject_name(topic_name="foo", schema=TYPED_PROTOBUF_SCHEMA, subject_type=subject_type, naming_strategy=strategy) == expected_subject ) + + +def test_avro_reference() -> None: + country_schema = ValidatedTypedSchema.parse( + schema_type=SchemaType.AVRO, + schema_str=textwrap.dedent( + """\ + { + "type": "record", + "name": "Country", + "namespace": "com.netapp", + "fields": [{"name": "name", "type": "string"}, {"name": "code", "type": "string"}] + } + """ + ), + ) + address_schema = ValidatedTypedSchema.parse( + schema_type=SchemaType.AVRO, + schema_str=textwrap.dedent( + """\ + { + "type": "record", + "name": "Address", + "namespace": "com.netapp", + "fields": [ + {"name": "street", "type": "string"}, + {"name": "city", "type": "string"}, + {"name": "postalCode", "type": "string"}, + {"name": "country", "type": "Country"} + ] + } + """ + ), + references=[Reference(name="country.avsc", subject=Subject("country"), version=Version(1))], + dependencies={ + "country": Dependency( + name="country", + subject=Subject("country"), + version=Version(1), + target_schema=country_schema, + ), + }, + ) + + # Check that the reference schema (Country) has been inlined + assert address_schema.schema_wrapped == textwrap.dedent( + """\ + { + "type": "record", + "name": "Address", + "namespace": "com.netapp", + "fields": [ + { + "type": "string", + "name": "street" + }, + { + "type": "string", + "name": "city" + }, + { + "type": "string", + "name": "postalCode" + }, + { + "type": { + "type": "record", + "name": "Country", + "namespace": "com.netapp", + "fields": [ + { + "type": "string", + "name": "name" + }, + { + "type": "string", + "name": "code" + } + ] + }, + "name": "country" + } + ] + } + """ + )