Skip to content

Commit

Permalink
feature: Support non-array common struct fields (#78)
Browse files Browse the repository at this point in the history
Building schema with upstream trunk was failing due to lacking support
for non-array common struct fields.
  • Loading branch information
aiven-anton authored Sep 10, 2023
1 parent eee5dc2 commit 8af236e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 16 deletions.
41 changes: 36 additions & 5 deletions codegen/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from .case import to_snake_case
from .header_schema import get_header_schema_import
from .parser import CommonStructArrayField
from .parser import CommonStructField
from .parser import CommonStructType
from .parser import DataSchema
from .parser import EntityArrayField
from .parser import EntityField
Expand Down Expand Up @@ -137,7 +139,7 @@ def format_default(


def format_dataclass_field(
field_type: Primitive | PrimitiveArrayType | EntityType,
field_type: Primitive | PrimitiveArrayType | EntityType | CommonStructType,
default: str | int | float | bool | None,
optional: bool,
custom_type: CustomTypeDef | None,
Expand Down Expand Up @@ -166,6 +168,7 @@ def format_dataclass_field(
field_kwargs["default"] = "()"
elif default is not None:
assert not isinstance(field_type, EntityType)
assert not isinstance(field_type, CommonStructType)
field_kwargs["default"] = format_default(
field_type, default, optional, custom_type
)
Expand Down Expand Up @@ -310,7 +313,10 @@ def generate_entity_array_field(
return f" {to_snake_case(field.name)}: tuple[{field.type}, ...]{field_call}\n"


def generate_entity_field(field: EntityField, version: int) -> str:
def generate_entity_field(
field: EntityField | CommonStructField,
version: int,
) -> str:
field_call = format_dataclass_field(
field_type=field.type,
default=None,
Expand All @@ -330,11 +336,28 @@ def generate_common_struct_array_field(
) -> str:
field_call = format_array_field_call(field, version)
return (
f" {to_snake_case(field.name)}: tuple[{field.type.item_type.name}, ...]"
f" {to_snake_case(field.name)}: tuple[{field.type.struct.name}, ...]"
f"{field_call}\n"
)


def generate_common_struct_field(
field: CommonStructField,
version: int,
) -> str:
field_call = format_dataclass_field(
field_type=field.type,
default=None,
optional=(
field.nullableVersions.matches(version) if field.nullableVersions else False
),
custom_type=None,
tag=field.get_tag(version),
ignorable=field.ignorable,
)
return f" {to_snake_case(field.name)}: {field.type.struct.name}{field_call}\n"


seen = set[tuple[str, int]]()


Expand Down Expand Up @@ -442,11 +465,19 @@ class {name}:
elif isinstance(field, CommonStructArrayField):
yield from generate_dataclass(
schema=schema,
name=field.type.item_type.name,
fields=field.type.item_type.fields,
name=field.type.struct.name,
fields=field.type.struct.fields,
version=version,
)
class_fields.append(generate_common_struct_array_field(field, version))
elif isinstance(field, CommonStructField):
yield from generate_dataclass(
schema=schema,
name=field.type.struct.name,
fields=field.type.struct.fields,
version=version,
)
class_fields.append(generate_common_struct_field(field, version))
else:
assert_never(field)

Expand Down
40 changes: 29 additions & 11 deletions codegen/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,18 @@ def parse_primitive_array_type(value: object) -> PrimitiveArrayType:
yield parse_primitive_array_type


def parse_common_struct_reference(struct_name: object) -> CommonStruct:
if not isinstance(struct_name, str):
raise ValueError("Common struct reference must be str")

try:
return structs_registry[struct_name]
except KeyError:
raise ValueError(f"No registered common struct named {struct_name!r}") from None


class CommonStructArrayType(NamedTuple):
item_type: CommonStruct
struct: CommonStruct

@classmethod
def __get_validators__(cls) -> Iterator[Callable[[object], CommonStructArrayType]]:
Expand All @@ -146,21 +156,25 @@ def parse_common_struct_array_type(value: object) -> CommonStructArrayType:
raise ValueError("CommonStructArrayType must be str")
if not value.startswith("[]"):
raise ValueError("CommonStructArrayType must start with '[]'")

struct_name = value.removeprefix("[]")

try:
struct = structs_registry[struct_name]
except KeyError:
raise ValueError(
f"No registered common struct named {struct_name!r}"
) from None

struct = parse_common_struct_reference(struct_name)
return CommonStructArrayType(struct)

yield parse_common_struct_array_type


class CommonStructType(NamedTuple):
struct: CommonStruct

@classmethod
def __get_validators__(cls) -> Iterator[Callable[[object], CommonStructType]]:
def parse_common_struct_type(struct_name: object) -> CommonStructType:
struct = parse_common_struct_reference(struct_name)
return CommonStructType(struct)

yield parse_common_struct_type


class _BaseField(BaseModel):
name: str
versions: VersionRange
Expand Down Expand Up @@ -222,7 +236,7 @@ def get_tag(self, version: int) -> int | None:

# Defining this union before its members allows not having to call
# EntityField.update_forward_refs().
Field: TypeAlias = "PrimitiveField | PrimitiveArrayField | EntityArrayField | CommonStructArrayField | EntityField"
Field: TypeAlias = "PrimitiveField | PrimitiveArrayField | EntityArrayField | CommonStructArrayField | EntityField | CommonStructField"
timedelta_names: Final = frozenset(
{
"timeoutMs",
Expand Down Expand Up @@ -343,6 +357,10 @@ class CommonStructArrayField(_BaseField):
type: CommonStructArrayType


class CommonStructField(_BaseField):
type: CommonStructType


class EntityArrayField(_BaseField):
type: EntityArrayType
fields: tuple[Field, ...]
Expand Down

0 comments on commit 8af236e

Please sign in to comment.