Skip to content

Commit

Permalink
Merge branch 'main' into fixes_for_sqlmodel
Browse files Browse the repository at this point in the history
  • Loading branch information
DeltaDaniel authored Jul 30, 2024
2 parents 0bfa650 + f17a391 commit d16f6cd
Show file tree
Hide file tree
Showing 116 changed files with 2,426 additions and 701 deletions.
12 changes: 9 additions & 3 deletions src/bo4e_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@

import click

from bo4e_generator.parser import OutputType, bo4e_init_file_content, bo4e_version_file_content, parse_bo4e_schemas
from bo4e_generator.parser import (
OutputType,
bo4e_init_file_content,
bo4e_version_file_content,
get_formatter,
parse_bo4e_schemas,
)
from bo4e_generator.schema import get_namespace, get_version
from bo4e_generator.sqlparser import format_code


def resolve_paths(input_directory: Path, output_directory: Path) -> tuple[Path, Path]:
Expand Down Expand Up @@ -43,10 +48,11 @@ def generate_bo4e_schemas(
if clear_output and output_directory.exists():
shutil.rmtree(output_directory)

formatter = get_formatter()
for relative_file_path, file_content in file_contents.items():
file_path = output_directory / relative_file_path
file_path.parent.mkdir(parents=True, exist_ok=True)
file_content = format_code(file_content)
file_content = formatter.format_code(file_content)
file_path.write_text(file_content, encoding="utf-8")
print(f"Created {file_path}")
print("Done.")
Expand Down
129 changes: 129 additions & 0 deletions src/bo4e_generator/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
Contains monkey patches related to imports
"""

import inspect
from collections import defaultdict
from typing import Iterable, List, Optional, Set, Union

from datamodel_code_generator.imports import Import
from datamodel_code_generator.imports import Imports as _Imports

from bo4e_generator.schema import SchemaMetadata


# pylint: disable=too-many-statements
def monkey_patch_imports(namespace: dict[str, SchemaMetadata]):
"""
Overwrites the behaviour how imports are rendered. They are not going through jinja templates.
They Imports class has a __str__ method, which we will overwrite here.
"""
namespace = {k: v for k, v in namespace.items() if k not in ("Typ", "Landescode")}
# "Typ" and "Landescode" must not be wrapped inside the "if TYPE_CHECKING" block because they are used explicitly
# to set default values.
import_type_checking = Import.from_full_path("typing.TYPE_CHECKING")

# pylint: disable=missing-function-docstring
class Imports(_Imports):
"""
Re-implement some methods to customize the import rendering
"""

def __str__(self) -> str:
return self.dump()

def _set_alias(self, from_: Optional[str], imports: Set[str]) -> List[str]:
return [
f"{i} as {self.alias[from_][i]}" if i in self.alias[from_] and i != self.alias[from_][i] else i
for i in sorted(imports)
]

def create_line(self, from_: Optional[str], imports: Set[str]) -> str:
if from_:
return f"from {from_} import {', '.join(self._set_alias(from_, imports))}"
return "\n".join(f"import {i}" for i in self._set_alias(from_, imports))

def dump(self) -> str:
imports_type_checking = defaultdict(set)
imports_no_type_checking = defaultdict(set)
for from_, imports in self.items():
for import_ in imports:
if import_ in namespace:
imports_type_checking[from_].add(import_)
else:
imports_no_type_checking[from_].add(import_)
imports_dump = "\n".join(
self.create_line(from_, imports) for from_, imports in imports_no_type_checking.items()
)
if len(imports_type_checking) > 0:
imports_dump += "\n\n"
imports_dump += "if TYPE_CHECKING:\n "
imports_dump += "\n ".join(
self.create_line(from_, imports) for from_, imports in imports_type_checking.items()
)
return imports_dump

def append(self, imports: Union[Import, Iterable[Import], None]) -> None:
if imports:
if isinstance(imports, Import):
imports = [imports]
for import_ in imports:
if import_.reference_path:
self.reference_paths[import_.reference_path] = import_
if (
import_type_checking.from_ not in self
or import_type_checking.import_ not in self[import_type_checking.from_]
):
self.append(import_type_checking)
if "." in import_.import_:
self[None].add(import_.import_)
self.counter[(None, import_.import_)] += 1
else:
self[import_.from_].add(import_.import_)
self.counter[(import_.from_, import_.import_)] += 1
if import_.alias:
self.alias[import_.from_][import_.import_] = import_.alias

def remove(
self, imports: Union[Import, Iterable[Import]], __intended_type_checking_remove: bool = False
) -> None:
if isinstance(imports, Import): # pragma: no cover
imports = [imports]
for import_ in imports:
if not __intended_type_checking_remove and import_ == import_type_checking:
continue
if "." in import_.import_: # pragma: no cover
self.counter[(None, import_.import_)] -= 1
if self.counter[(None, import_.import_)] == 0: # pragma: no cover
self[None].remove(import_.import_)
if not self[None]:
del self[None]
else:
self.counter[(import_.from_, import_.import_)] -= 1 # pragma: no cover
if self.counter[(import_.from_, import_.import_)] == 0: # pragma: no cover
self[import_.from_].remove(import_.import_)
if not self[import_.from_]:
del self[import_.from_]
if import_.alias: # pragma: no cover
del self.alias[import_.from_][import_.import_]
if not self.alias[import_.from_]:
del self.alias[import_.from_]

if (
import_type_checking.from_ in self
and import_type_checking.import_ in self[import_type_checking.from_]
and not any(
imp_str in namespace for imp_str_sets in self.values() for imp_str in imp_str_sets
)
):
self.remove(
import_type_checking,
__intended_type_checking_remove=True, # type: ignore[call-arg]
)

def remove_referenced_imports(self, reference_path: str) -> None:
if reference_path in self.reference_paths:
self.remove(self.reference_paths[reference_path])

for name, func in inspect.getmembers(Imports, inspect.isfunction):
setattr(_Imports, name, func)
59 changes: 58 additions & 1 deletion src/bo4e_generator/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
import datamodel_code_generator.parser.base
import datamodel_code_generator.reference
from datamodel_code_generator import DataModelType, PythonVersion
from datamodel_code_generator.format import CodeFormatter
from datamodel_code_generator.imports import IMPORT_DATETIME
from datamodel_code_generator.model import DataModelSet, get_data_model_types
from datamodel_code_generator.model.enum import Enum as _Enum
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser
from datamodel_code_generator.types import DataType, StrictTypes, Types

from bo4e_generator.imports import monkey_patch_imports
from bo4e_generator.schema import SchemaMetadata
from bo4e_generator.sqlparser import adapt_parse_for_sql, remove_pydantic_field_import, write_many_many_links

Expand Down Expand Up @@ -75,6 +77,26 @@ class BO4EDataTypeManager(data_model_types.data_type_manager): # type: ignore[n
featured in pydantic v2. Instead, the standard datetime type will be used.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

class DataTypeWithForwardRef(self.data_type):
"""
Override the data type to replace explicit type references with forward references if the type
is present in namespace.
Also, the AwareDateTime import is replaced with the standard datetime import.
"""

@property
def type_hint(self) -> str:
"""Return the type hint for the data type."""
type_ = super().type_hint
if self.reference and type_ in namespace:
type_ = f'"{type_}"'
return type_

self.data_type = DataTypeWithForwardRef

def type_map_factory(
self,
data_type: Type[DataType],
Expand All @@ -86,6 +108,8 @@ def type_map_factory(
result[Types.date_time] = data_type.from_import(IMPORT_DATETIME)
return result

monkey_patch_imports(namespace)

return DataModelSet(
data_model=BO4EDataModel,
root_model=data_model_types.root_model,
Expand Down Expand Up @@ -172,6 +196,16 @@ def bo4e_init_file_content(namespace: dict[str, SchemaMetadata], version: str) -
init_file_content += f"from .{'.'.join(schema_metadata.module_path)} import {schema_metadata.class_name}\n"
init_file_content += "\nfrom .__version__ import __version__\n"

init_file_content += (
"from pydantic import BaseModel as _PydanticBaseModel\n"
"\n\n# Resolve all ForwardReferences. This design prevents circular import errors.\n"
"for cls_name in __all__:\n"
" cls = globals().get(cls_name, None)\n"
" if cls is None or not isinstance(cls, type) or not issubclass(cls, _PydanticBaseModel):\n"
" continue\n"
" cls.model_rebuild(force=True)\n"
)

return init_file_content


Expand All @@ -182,6 +216,13 @@ def remove_future_import(python_code: str) -> str:
return re.sub(r"from __future__ import annotations\n\n", "", python_code)


def remove_model_rebuild(python_code: str, class_name: str) -> str:
"""
Remove the model_rebuild call from the generated code.
"""
return re.sub(rf"{class_name}\.model_rebuild\(\)\n", "", python_code)


def parse_bo4e_schemas(
input_directory: Path, namespace: dict[str, SchemaMetadata], output_type: OutputType
) -> dict[Path, str]:
Expand Down Expand Up @@ -218,7 +259,7 @@ def parse_bo4e_schemas(
use_schema_description=True,
use_subclass_enum=True,
use_standard_collections=True,
use_union_operator=True,
use_union_operator=False,
use_field_description=True,
set_default_enum_member=True,
snake_case_field=True,
Expand Down Expand Up @@ -251,6 +292,7 @@ def parse_bo4e_schemas(
)

python_code = remove_future_import(parse_result.pop(module_path).body)
python_code = remove_model_rebuild(python_code, schema_metadata.class_name)
if output_type is OutputType.SQL_MODEL.name:
# remove pydantic field
python_code = remove_pydantic_field_import(python_code)
Expand All @@ -266,3 +308,18 @@ def parse_bo4e_schemas(
file_contents[Path("many.py")] = write_many_many_links(links)

return file_contents


def get_formatter() -> CodeFormatter:
"""
Returns a formatter to apply black and isort
"""
return CodeFormatter(
PythonVersion.PY_311,
None,
None,
skip_string_normalization=False,
known_third_party=None,
custom_formatters=None,
custom_formatters_kwargs=None,
)
4 changes: 3 additions & 1 deletion unittests/test_data/bo4e_schemas/ZusatzAttribut.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
{
"description": "Viele Datenobjekte weisen in unterschiedlichen Systemen eine eindeutige ID (Kundennummer, GP-Nummer etc.) auf.\nBeim Austausch von Datenobjekten zwischen verschiedenen Systemen ist es daher hilfreich,\nsich die eindeutigen IDs der anzubindenden Systeme zu merken.\n\n.. raw:: html\n\n <object data=\"../_static/images/bo4e/com/ZusatzAttribut.svg\" type=\"image/svg+xml\"></object>\n\n.. HINT::\n `ZusatzAttribut JSON Schema <https://json-schema.app/view/%23?url=https://raw.githubusercontent.com/Hochfrequenz/BO4E-Schemas/v202401.0.1-/src/bo4e_schemas/ZusatzAttribut.json>`_",
"description": "Viele Datenobjekte weisen in unterschiedlichen Systemen eine eindeutige ID (Kundennummer, GP-Nummer etc.) auf.\nBeim Austausch von Datenobjekten zwischen verschiedenen Systemen ist es daher hilfreich,\nsich die eindeutigen IDs der anzubindenden Systeme zu merken.\n\n.. raw:: html\n\n <object data=\"../_static/images/bo4e/com/ZusatzAttribut.svg\" type=\"image/svg+xml\"></object>\n\n.. HINT::\n `ZusatzAttribut JSON Schema <https://json-schema.app/view/%23?url=https://raw.githubusercontent.com/BO4E/BO4E-Schemas/v202401.3.2/src/bo4e_schemas/ZusatzAttribut.json>`_",
"title": "ZusatzAttribut",
"properties": {
"name": {
"description": "Bezeichnung der externen Referenz (z.B. \"microservice xyz\" oder \"SAP CRM GP-Nummer\")",
"title": "Name",
"anyOf": [
{
Expand All @@ -14,6 +15,7 @@
]
},
"wert": {
"description": "Bezeichnung der externen Referenz (z.B. \"microservice xyz\" oder \"SAP CRM GP-Nummer\")",
"title": "Wert"
}
},
Expand Down
Loading

0 comments on commit d16f6cd

Please sign in to comment.