Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor fixes for ibims orm creation #99

Merged
merged 13 commits into from
Sep 2, 2024
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
dependencies = [
"datamodel-code-generator",
"click",
"autoflake"
] # add all the dependencies from requirements.in here, too
dynamic = ["readme", "version"]

Expand Down
2 changes: 0 additions & 2 deletions requirements.in

This file was deleted.

10 changes: 8 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,24 @@
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile requirements.in
# pip-compile '.\pyproject.toml'
#
annotated-types==0.6.0
# via pydantic
argcomplete==3.1.2
# via datamodel-code-generator
autoflake==2.3.1
# via BO4E-Python-Generator (pyproject.toml)
black==24.8.0
# via datamodel-code-generator
click==8.1.7
# via
# -r requirements.in
# BO4E-Python-Generator (pyproject.toml)
# black
datamodel-code-generator==0.25.9
# via -r requirements.in
colorama==0.4.6
# via click
dnspython==2.4.2
# via email-validator
email-validator==2.0.0.post2
Expand Down Expand Up @@ -46,6 +50,8 @@ pydantic[email]==2.4.2
# via datamodel-code-generator
pydantic-core==2.10.1
# via pydantic
pyflakes==3.2.0
# via autoflake
pyyaml==6.0.1
# via datamodel-code-generator
typing-extensions==4.8.0
Expand Down
6 changes: 6 additions & 0 deletions src/bo4e_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
parse_bo4e_schemas,
)
from bo4e_generator.schema import get_namespace, get_version
from bo4e_generator.sqlparser import remove_unused_imports


def resolve_paths(input_directory: Path, output_directory: Path) -> tuple[Path, Path]:
Expand Down Expand Up @@ -52,6 +53,11 @@ def generate_bo4e_schemas(
for relative_file_path, file_content in file_contents.items():
file_path = output_directory / relative_file_path
file_path.parent.mkdir(parents=True, exist_ok=True)
if (
relative_file_path.name not in ["__init__.py", "__version__.py"]
and OutputType[output_type] == OutputType.SQL_MODEL
):
file_content = remove_unused_imports(file_content)
file_content = formatter.format_code(file_content)
file_path.write_text(file_content, encoding="utf-8")
print(f"Created {file_path}")
Expand Down
22 changes: 16 additions & 6 deletions src/bo4e_generator/custom_templates/BaseModel.jinja2
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
{%- if SQL and SQL['imports']%}
{%- for class_name, module_path in SQL['imports'].items() %}
{%- for import_class_name, module_path in SQL['imports'].items() %}
{%- if module_path[:4] == 'enum'%}
from borm.models.{{ module_path }} import {{ class_name }}
from ..{{ module_path }} import {{ import_class_name }}
{%- elif module_path == 'Link'%}
from borm.models.many import {{ class_name }}
{% if class_name == 'ZusatzAttribut'%}
from .many import {{ import_class_name }}
{% else %}
from ..many import {{ import_class_name }}
{% endif %}
{%- else %}
from {{ module_path }} import {{ class_name }}
from {{ module_path }} import {{ import_class_name }}
{%- endif %}
{%- endfor -%}
{%- endif %}
{%- if SQL and SQL['relationimports']%}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{%- for class_name, module_path in SQL['relationimports'].items() %}
from borm.models.{{ module_path }} import {{ class_name }}
{% if class_name == 'ZusatzAttribut'%}
{%- for import_class_name, module_path in SQL['relationimports'].items() %}
from .{{ module_path }} import {{ import_class_name }}
{%- endfor -%}
{% else %}
{%- for import_class_name, module_path in SQL['relationimports'].items() %}
from ..{{ module_path }} import {{ import_class_name }}
{%- endfor -%}
{%- endif %}
{%- endif %}
{% for decorator in decorators -%}
{{ decorator }}
Expand Down
4 changes: 2 additions & 2 deletions src/bo4e_generator/custom_templates/ManyLinks.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class {{class1}}{{class2[1]}}Link(SQLModel, table=True):
"""
class linking m-n relation of tables {{class1}} and {{class2[0]}} for field {{ class2[1]}}.
"""
{{class1.lower()}}_id: Optional[uuid_pkg.UUID] = Field(sa_column=Column(UUID(as_uuid=True), ForeignKey("{{class1.lower()}}.{{class1.lower()}}_sqlid", ondelete="CASCADE"), primary_key=True))
{{class2[0].lower()}}_id: Optional[uuid_pkg.UUID] = Field(sa_column=Column(UUID(as_uuid=True), ForeignKey("{{class2[0].lower()}}.{{class2[0].lower()}}_sqlid", ondelete="CASCADE"), primary_key=True))
{{class1.lower()}}_id: Optional[uuid_pkg.UUID] = Field(sa_column=Column(UUID(as_uuid=True), ForeignKey("{{class1.lower()}}.id", ondelete="CASCADE"), primary_key=True))
{{class2[0].lower()}}_id: Optional[uuid_pkg.UUID] = Field(sa_column=Column(UUID(as_uuid=True), ForeignKey("{{class2[0].lower()}}.id", ondelete="CASCADE"), primary_key=True))

{%- endfor -%}
{%- endfor -%}
79 changes: 49 additions & 30 deletions src/bo4e_generator/sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
"""

import json
import os
import re
import subprocess
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Any, DefaultDict, Union

import black
import isort
from jinja2 import Environment, FileSystemLoader

from bo4e_generator.schema import SchemaMetadata
from bo4e_generator.schema import SchemaMetadata, camel_to_snake


def remove_pydantic_field_import(python_code: str) -> str:
Expand Down Expand Up @@ -42,29 +43,32 @@ def adapt_parse_for_sql(
for schema_metadata in namespace.values():
if schema_metadata.module_path[0] != "enum":
# list of fields which will be replaced by modified versions
del_fields = []
del_fields = set()
for field, val in schema_metadata.schema_parsed["properties"].items():
# type Any field
if "type" not in str(val):
add_relation, relation_imports = create_sql_any(
field, schema_metadata.class_name, namespace, add_relation, relation_imports
)
del_fields.append(field)
del_fields.add(field)
# modify decimal fields
if "number" in str(val) and "string" in str(val):
relation_imports[schema_metadata.class_name + "ADD"]["Decimal"] = "decimal"
if "array" in str(val) and "$ref" not in str(val):
add_relation, relation_imports = create_sql_list(
field, schema_metadata.class_name, namespace, add_relation, relation_imports
)
del_fields.append(field)
del_fields.add(field)
if "$ref" in str(val): # or "array" in str(val):
add_relation, relation_imports = create_sql_field(
field, schema_metadata.class_name, namespace, add_relation, relation_imports
)
del_fields.append(field)
del_fields.add(field)
for field in del_fields:
del schema_metadata.schema_parsed["properties"][field]
# delete id field as it is replaced below
if schema_metadata.schema_parsed["properties"].get("_id"):
del schema_metadata.schema_parsed["properties"]["_id"]
# store the reduced version. The modified fields will be added in the BaseModel.jinja2 schema
schema_metadata.schema_text = json.dumps(schema_metadata.schema_parsed, indent=2, ensure_ascii=False)

Expand Down Expand Up @@ -104,9 +108,8 @@ def additional_sql_arguments(
if schema_metadata.module_path[0] != "enum":
# add primary key
additional_sql_data[schema_metadata.class_name]["SQL"] = {
"primary": schema_metadata.class_name.lower()
+ "_sqlid: uuid_pkg.UUID = Field( default_factory=uuid_pkg.uuid4, primary_key=True, index=True, "
"nullable=False )"
"primary": "id: uuid_pkg.UUID = Field( default_factory=uuid_pkg.uuid4, primary_key=True, index=True, "
'nullable=False, alias="_id", title=" Id" )'
}
if schema_metadata.class_name in add_relation:
additional_sql_data[schema_metadata.class_name]["SQL"]["relations"] = add_relation[
Expand Down Expand Up @@ -184,7 +187,7 @@ def create_sql_list(
add_imports[class_name + "ADD"]["Column, ARRAY"] = "sqlalchemy"
add_imports[class_name + "ADD"][sa_type] = "sqlalchemy"

add_fields[class_name][f"{field_name}"] = (
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
f"List[{type_hint}] "
+ is_optional
+ f' = Field({default}, title="{field_name}", sa_column=Column( ARRAY( {sa_type} )))'
Expand All @@ -209,12 +212,14 @@ def sql_reference_enum(
returns field which references enums.
"""
if is_list:
add_fields[class_name][f"{field_name}"] = (
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
f"List[{reference_name}]" + is_optional + f" = Field({default},"
f' sa_column=Column( ARRAY( Enum( {reference_name}, name="{reference_name.lower()}"))))'
)
else:
add_fields[class_name][f"{field_name}"] = f"{reference_name}" + is_optional + f"= Field({default})"
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
f"{reference_name}" + is_optional + f"= Field({default})"
)

# import enums
if is_list:
Expand Down Expand Up @@ -265,23 +270,23 @@ def create_sql_field(
add_fields["MANY"][class_name] = [[reference_name, field_name]]
elif reference_name not in add_fields["MANY"][class_name]:
add_fields["MANY"][class_name].append([reference_name, field_name])
add_fields[class_name][f"{field_name}"] = (
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
f'List["{reference_name}"] ='
f' Relationship(back_populates="{class_name.lower()}_{field_name.lower()}_link", '
f"link_model={class_name}{field_name}Link)"
)
add_fields[reference_name][f"{class_name.lower()}_{field_name.lower()}_link"] = (
f'List["{class_name}"] ='
f' Relationship(back_populates="{field_name}", '
f' Relationship(back_populates="{camel_to_snake(field_name)}", '
f"link_model={class_name}{field_name}Link)"
)
add_imports[class_name + "ADD"][f"{class_name}{field_name}Link)"] = "Link"
add_imports[reference_name + "ADD"][f"{class_name}{field_name}Link)"] = "Link"
else:
# cf. https://github.com/tiangolo/sqlmodel/pull/610
add_fields[class_name][f"{field_name}_id"] = (
add_fields[class_name][f"{camel_to_snake(field_name)}_id"] = (
"uuid_pkg.UUID " + is_optional + f" = Field(sa_column=Column(UUID(as_uuid=True),"
f' ForeignKey("{reference_name.lower()}.{reference_name.lower()}_sqlid"'
f' ForeignKey("{reference_name.lower()}.id"'
f', ondelete="SET NULL")))'
)
add_imports[class_name + "ADD"]["Column"] = "sqlalchemy"
Expand All @@ -291,20 +296,20 @@ def create_sql_field(
# pylint: disable= fixme
# todo: check default

add_fields[class_name][f"{field_name}"] = (
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
f'"{reference_name}" ='
f' Relationship(back_populates="{class_name.lower()}_{field_name}",'
f' sa_relationship_kwargs= {{ "foreign_keys":"[{class_name}.{field_name}_id]" }})'
f' Relationship(back_populates="{class_name.lower()}_{camel_to_snake(field_name)}",'
f' sa_relationship_kwargs= {{ "foreign_keys":"[{class_name}.{camel_to_snake(field_name)}_id]" }})'
)

# cf. https://github.com/tiangolo/sqlmodel/issues/10
# https://github.com/tiangolo/sqlmodel/issues/213
# https://dev.to/whchi/disable-sqlmodel-foreign-key-constraint-55kp
add_fields[reference_name][f"{class_name.lower()}_{field_name}"] = (
f'List["{class_name}"] = Relationship(back_populates="{field_name}",'
add_fields[reference_name][f"{class_name.lower()}_{camel_to_snake(field_name)}"] = (
f'List["{class_name}"] = Relationship(back_populates="{camel_to_snake(field_name)}",'
f"sa_relationship_kwargs="
f'{{"primaryjoin":'
f' "{class_name}.{field_name}_id=={reference_name}.{reference_name.lower()}_sqlid",'
f' "{class_name}.{camel_to_snake(field_name)}_id=={reference_name}.id",'
f' "lazy": "joined"}})'
)
# add_relation_import
Expand Down Expand Up @@ -346,11 +351,11 @@ def create_sql_any(
if is_list:
add_imports[class_name + "ADD"]["List"] = "typing"
add_imports[class_name + "ADD"]["ARRAY"] = "sqlalchemy"
add_fields[class_name][f"{field_name}"] = (
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
"List[Any]" + is_optional + f" = Field({default}," f" sa_column=Column( ARRAY( PickleType)))"
)
else:
add_fields[class_name][f"{field_name}"] = (
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
"Any" + is_optional + f" = Field({default}," f" sa_column=Column( PickleType))"
)

Expand All @@ -365,13 +370,27 @@ def write_many_many_links(links: dict[str, str]) -> str:
environment = Environment(loader=FileSystemLoader(template_path))
template = environment.get_template("ManyLinks.jinja2")
python_code = template.render({"class": links})
python_code = format_code(python_code)
# python_code = format_code(python_code)
return python_code


def format_code(code: str) -> str:
def remove_unused_imports(code):
"""
perform isort and black on code
Removes unused imports from the given code using autoflake.
"""
code = black.format_str(code, mode=black.Mode())
return isort.code(code, known_local_folder=["borm"])
# Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmp_file:
tmp_file_name = tmp_file.name
tmp_file.write(code.encode("utf-8"))

# Run autoflake to remove unused imports
subprocess.run(["autoflake", "--remove-all-unused-imports", "--in-place", tmp_file_name], check=True)

# Read the cleaned code from the temporary file
with open(tmp_file_name, "r", encoding="utf-8") as tmp_file:
cleaned_code = tmp_file.read()

# Clean up the temporary file
os.remove(tmp_file_name)

return cleaned_code
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ deps =
pre-commit
commands =
python -m pip install --upgrade pip
pip-compile requirements.in
pip-compile .\pyproject.toml
pip install -r requirements.txt
pre-commit install

Expand Down
14 changes: 0 additions & 14 deletions unittests/test_sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from bo4e_generator.sqlparser import (
adapt_parse_for_sql,
create_sql_field,
format_code,
remove_pydantic_field_import,
return_ref,
write_many_many_links,
Expand Down Expand Up @@ -76,16 +75,3 @@ def test_write_many_many_links(self) -> None:
file_contents = write_many_many_links(links)
keywords = ["AngebotzusatzAttributeLink", "angebot_id", "zusatzattribut_id"]
assert all(substring in file_contents for substring in keywords)

def test_format_code(self) -> None:
unsorted = (
"from sqlmodel import Field, Relationship, SQLModel\n"
"from typing import TYPE_CHECKING, List\n"
"from borm.models.enum.anrede import Anrede"
)
resorted = (
"from typing import TYPE_CHECKING, List\n\n"
"from sqlmodel import Field, Relationship, SQLModel\n\n"
"from borm.models.enum.anrede import Anrede\n"
)
assert resorted == format_code(unsorted)