From c4020ef52951b7a7251ed37da82df61bffa9a854 Mon Sep 17 00:00:00 2001 From: Victor Petrovykh Date: Fri, 31 Jan 2025 04:03:51 -0500 Subject: [PATCH] Update gel-orm generators. Make generated models sorted by name so that the output is stable and does not change when the schema does not change. Expose arrays as valid scalars for SQLAlchemy and Django. Don't attempt to reflect computeds because we don't have then in SQL. Reformat to avoid long lines. --- gel/orm/django/generator.py | 45 +++++-- gel/orm/introspection.py | 4 +- gel/orm/sqla.py | 221 +++++++++++++++++++++------------- gel/orm/sqlmodel.py | 182 +++++++++++++++++----------- tests/dbsetup/base.edgeql | 10 ++ tests/dbsetup/base.esdl | 16 ++- tests/dbsetup/sqlmodel.edgeql | 10 ++ tests/dbsetup/sqlmodel.esdl | 11 ++ tests/test_django_basic.py | 59 +++++++++ tests/test_sqla_basic.py | 71 +++++++++++ tests/test_sqlmodel_basic.py | 75 ++++++++++++ 11 files changed, 541 insertions(+), 163 deletions(-) diff --git a/gel/orm/django/generator.py b/gel/orm/django/generator.py index 6ce63aea..bfcd1328 100644 --- a/gel/orm/django/generator.py +++ b/gel/orm/django/generator.py @@ -22,9 +22,9 @@ # values are controlled in Django via settings (USE_TZ) and are mutually # exclusive in the same app under default circumstances. 'std::datetime': 'DateTimeField', - 'cal::local_date': 'DateField', - 'cal::local_datetime': 'DateTimeField', - 'cal::local_time': 'TimeField', + 'std::cal::local_date': 'DateField', + 'std::cal::local_datetime': 'DateTimeField', + 'std::cal::local_time': 'TimeField', # all kinds of durations are not supported due to this error: # iso_8601 intervalstyle currently not supported } @@ -38,6 +38,8 @@ # from django.db import models +from django.contrib.postgres import fields as pgf + class GelUUIDField(models.UUIDField): # This field must be treated as a auto-generated UUID. @@ -55,6 +57,17 @@ class GelPGMeta: ''' CLOSEPAR_RE = re.compile(r'\)(?=\s+#|$)') +ARRAY_RE = re.compile(r'^array<(?P.+)>$') +NAME_RE = re.compile(r'^(?P\w+?)(?P\d*)$') + + +def field_name_sort(item): + key, val = item + + match = NAME_RE.fullmatch(key) + res = (match.group('alpha'), int(match.group('num') or -1)) + + return res class ModelClass(object): @@ -166,11 +179,11 @@ def build_models(self, maps): mod.links['source'] = ( f"LTForeignKey({source!r}, models.DO_NOTHING, " - f"db_column='source', primary_key=True)" + f"db_column='source')" ) mod.links['target'] = ( f"LTForeignKey({target!r}, models.DO_NOTHING, " - f"db_column='target')" + f"db_column='target', primary_key=True)" ) # Update the source model with the corresponding @@ -197,6 +210,12 @@ def render_prop(self, prop): req = 'blank=True, null=True' target = prop['target']['name'] + is_array = False + match = ARRAY_RE.fullmatch(target) + if match: + is_array = True + target = match.group('el') + try: ftype = GEL_SCALAR_MAP[target] except KeyError: @@ -206,7 +225,10 @@ def render_prop(self, prop): ) return '' - return f'models.{ftype}({req})' + if is_array: + return f'pgf.ArrayField(models.{ftype}({req}))' + else: + return f'models.{ftype}({req})' def render_link(self, link, bklink=None): if link['required']: @@ -267,7 +289,7 @@ def render_models(self, spec): self.out = f self.write(BASE_STUB) - for mod in modmap.values(): + for mod in sorted(modmap.values(), key=lambda x: x.name): self.write() self.write() self.render_model_class(mod) @@ -284,19 +306,22 @@ def render_model_class(self, mod): if mod.props: self.write() self.write(f'# properties as Fields') - for name, val in mod.props.items(): + props = sorted(mod.props.items(), key=field_name_sort) + for name, val in props: self.write(f'{name} = {val}') if mod.links: self.write() self.write(f'# links as ForeignKeys') - for name, val in mod.links.items(): + links = sorted(mod.links.items(), key=field_name_sort) + for name, val in links: self.write(f'{name} = {val}') if mod.mlinks: self.write() self.write(f'# multi links as ManyToManyFields') - for name, val in mod.mlinks.items(): + mlinks = sorted(mod.mlinks.items(), key=field_name_sort) + for name, val in mlinks: self.write(f'{name} = {val}') if '.' not in mod.table: diff --git a/gel/orm/introspection.py b/gel/orm/introspection.py index ec43464f..39e436dc 100644 --- a/gel/orm/introspection.py +++ b/gel/orm/introspection.py @@ -31,7 +31,7 @@ ), target: {name}, }, - } filter .name != '__type__', + } filter .name != '__type__' and not exists .expr, properties: { name, readonly, @@ -42,7 +42,7 @@ filter .name = 'std::exclusive' ), target: {name}, - }, + } filter not exists .expr, backlinks := >[], } filter diff --git a/gel/orm/sqla.py b/gel/orm/sqla.py index 9af5172e..f9949383 100644 --- a/gel/orm/sqla.py +++ b/gel/orm/sqla.py @@ -9,48 +9,64 @@ GEL_SCALAR_MAP = { - 'std::bool': ('bool', 'Boolean'), - 'std::str': ('str', 'String'), - 'std::int16': ('int', 'Integer'), - 'std::int32': ('int', 'Integer'), - 'std::int64': ('int', 'Integer'), - 'std::float32': ('float', 'Float'), - 'std::float64': ('float', 'Float'), - 'std::uuid': ('uuid.UUID', 'Uuid'), + 'std::bool': ('bool', 'sa.Boolean'), + 'std::str': ('str', 'sa.String'), + 'std::int16': ('int', 'sa.Integer'), + 'std::int32': ('int', 'sa.Integer'), + 'std::int64': ('int', 'sa.Integer'), + 'std::float32': ('float', 'sa.Float'), + 'std::float64': ('float', 'sa.Float'), + 'std::uuid': ('uuid.UUID', 'sa.Uuid'), + 'std::bytes': ('bytes', 'sa.LargeBinary'), + 'std::cal::local_date': ('datetime.date', 'sa.Date'), + 'std::cal::local_time': ('datetime.time', 'sa.Time'), + 'std::cal::local_datetime': ('datetime.datetime', 'sa.DateTime'), + 'std::datetime': ('datetime.datetime', 'sa.TIMESTAMP'), } -CLEAN_RE = re.compile(r'[^A-Za-z0-9]+') +ARRAY_RE = re.compile(r'^array<(?P.+)>$') +NAME_RE = re.compile(r'^(?P\w+?)(?P\d*)$') COMMENT = '''\ # # Automatically generated from Gel schema. +# +# Do not edit directly as re-generating this file will overwrite any changes. #\ ''' BASE_STUB = f'''\ {COMMENT} -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import orm as orm -class Base(DeclarativeBase): +class Base(orm.DeclarativeBase): pass\ ''' MODELS_STUB = f'''\ {COMMENT} +import datetime import uuid -from typing import List -from typing import Optional +from typing import List, Optional -from sqlalchemy import MetaData, Table, Column, ForeignKey -from sqlalchemy import String, Uuid, Integer, Float, Boolean -from sqlalchemy.orm import Mapped, mapped_column, relationship +import sqlalchemy as sa +from sqlalchemy import orm as orm ''' +def field_name_sort(spec): + key = spec['name'] + + match = NAME_RE.fullmatch(key) + res = (match.group('alpha'), int(match.group('num') or -1)) + + return res + + class ModelGenerator(FilePrinter): def __init__(self, *, outdir=None, basemodule=None): # set the output to be stdout by default, but this is generally @@ -118,9 +134,9 @@ def init_module(self, mod, modules): def get_fk(self, mod, table, curmod): if mod == curmod: # No need for anything fancy within the same schema - return f'ForeignKey("{table}.id")' + return f'sa.ForeignKey("{table}.id")' else: - return f'ForeignKey("{mod}.{table}.id")' + return f'sa.ForeignKey("{mod}.{table}.id")' def get_py_name(self, mod, name, curmod): if False and mod == curmod: @@ -179,7 +195,8 @@ def render_models(self, spec): self.write(MODELS_STUB) self.write(f'from ._sqlabase import Base') - for rec in spec['link_tables']: + link_tables = sorted(spec['link_tables'], key=lambda x: x['name']) + for rec in link_tables: self.write() self.render_link_table(rec) @@ -189,11 +206,19 @@ def render_models(self, spec): # skip apparently empty modules continue - for lobj in maps.get('link_objects', {}).values(): + link_objects = sorted( + maps.get('link_objects', {}).values(), + key=lambda x: x['name'], + ) + for lobj in link_objects: self.write() self.render_link_object(lobj, modules) - for rec in maps.get('object_types', {}).values(): + object_types = sorted( + maps.get('object_types', {}).values(), + key=lambda x: x['name'], + ) + for rec in object_types: self.write() self.render_type(rec, modules) @@ -204,13 +229,13 @@ def render_link_table(self, spec): t_fk = self.get_fk(tmod, target, 'default') self.write() - self.write(f'{spec["name"]} = Table(') + self.write(f'{spec["name"]} = sa.Table(') self.indent() self.write(f'{spec["table"]!r},') self.write(f'Base.metadata,') # source is in the same module as this table - self.write(f'Column("source", {s_fk}),') - self.write(f'Column("target", {t_fk}),') + self.write(f'sa.Column("source", {s_fk}),') + self.write(f'sa.Column("target", {t_fk}),') self.write(f'schema={mod!r},') self.dedent() self.write(f')') @@ -243,10 +268,13 @@ def render_link_object(self, spec, modules): tmod, target = get_mod_and_name(link['target']['name']) fk = self.get_fk(tmod, target, mod) pyname = self.get_py_name(tmod, target, mod) - self.write(f'{lname}_id: Mapped[uuid.UUID] = mapped_column(') + self.write(f'{lname}_id: orm.Mapped[uuid.UUID] = orm.mapped_column(') self.indent() - self.write(f'{lname!r}, Uuid(), {fk},') - self.write(f'primary_key=True, nullable=False,') + self.write(f'{lname!r},') + self.write(f'sa.Uuid(),') + self.write(f'{fk},') + self.write(f'primary_key=True,') + self.write(f'nullable=False,') self.dedent() self.write(')') @@ -260,8 +288,8 @@ def render_link_object(self, spec, modules): ) self.write( - f'{lname}: Mapped[{pyname}] = ' - f'relationship(back_populates={bklink!r})' + f'{lname}: orm.Mapped[{pyname}] = ' + f'orm.relationship(back_populates={bklink!r})' ) if spec['properties']: @@ -292,25 +320,30 @@ def render_type(self, spec, modules): self.write() # Add two fields that all objects have - self.write(f'id: Mapped[uuid.UUID] = mapped_column(') + self.write(f'id: orm.Mapped[uuid.UUID] = orm.mapped_column(') self.indent() - self.write( - f"Uuid(), primary_key=True, server_default='uuid_generate_v4()')") + self.write(f"sa.Uuid(),") + self.write(f"primary_key=True,") + self.write(f"server_default='uuid_generate_v4()',") self.dedent() + self.write(f')') # This is maintained entirely by Gel, the server_default simply # indicates to SQLAlchemy that this value may be omitted. - self.write(f'gel_type_id: Mapped[uuid.UUID] = mapped_column(') + self.write(f'gel_type_id: orm.Mapped[uuid.UUID] = orm.mapped_column(') self.indent() - self.write( - f"'__type__', Uuid(), server_default='PLACEHOLDER')") + self.write(f"'__type__',") + self.write(f"sa.Uuid(),") + self.write(f"server_default='PLACEHOLDER',") self.dedent() + self.write(f")") if spec['properties']: self.write() self.write('# Properties:') - for prop in spec['properties']: + properties = sorted(spec['properties'], key=field_name_sort) + for prop in properties: if prop['name'] != 'id': self.render_prop(prop, mod, name, modules) @@ -318,14 +351,16 @@ def render_type(self, spec, modules): self.write() self.write('# Links:') - for link in spec['links']: + links = sorted(spec['links'], key=field_name_sort) + for link in links: self.render_link(link, mod, name, modules) if spec['backlinks']: self.write() self.write('# Back-links:') - for link in spec['backlinks']: + backlinks = sorted(spec['backlinks'], key=field_name_sort) + for link in backlinks: self.render_backlink(link, mod, modules) self.dedent() @@ -336,8 +371,15 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): cardinality = spec['cardinality'] target = spec['target']['name'] + is_array = False + match = ARRAY_RE.fullmatch(target) + if match: + is_array = True + target = match.group('el') + try: pytype, sqlatype = GEL_SCALAR_MAP[target] + sqlatype = sqlatype + '()' except KeyError: warnings.warn( f'Scalar type {target} is not supported', @@ -346,23 +388,29 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): # Skip rendering this one return + if is_array: + pytype = f'List[{pytype}]' + sqlatype = f'sa.ARRAY({sqlatype})' + if is_pk: # special case of a primary key property (should only happen to # 'target' in multi property table) - self.write( - f'{name}: Mapped[{pytype}] = mapped_column(' - f'{sqlatype}(), primary_key=True, nullable=False)' - ) + self.write(f'{name}: orm.Mapped[{pytype}] = orm.mapped_column(') + self.indent() + self.write(f'{sqlatype}, primary_key=True, nullable=False,') + self.dedent() + self.write(f')') elif cardinality == 'Many': # skip it return else: # plain property - self.write( - f'{name}: Mapped[{pytype}] = ' - f'mapped_column({sqlatype}(), nullable={nullable})' - ) + self.write(f'{name}: orm.Mapped[{pytype}] = orm.mapped_column(') + self.indent() + self.write(f'{sqlatype}, nullable={nullable},') + self.dedent() + self.write(f')') def render_link(self, spec, mod, parent, modules): name = spec['name'] @@ -383,22 +431,22 @@ def render_link(self, spec, mod, parent, modules): if cardinality == 'One': self.write( - f'{name}: Mapped[{pyname}] = ' - f"relationship(back_populates='source')" + f'{name}: orm.Mapped[{pyname}] = ' + f"orm.relationship(back_populates='source')" ) elif cardinality == 'Many': self.write( - f'{name}: Mapped[List[{pyname}]] = ' - f"relationship(back_populates='source')" + f'{name}: orm.Mapped[List[{pyname}]] = ' + f"orm.relationship(back_populates='source')" ) if cardinality == 'One': - tmap = f'Mapped[{pyname}]' + tmap = f'orm.Mapped[{pyname}]' elif cardinality == 'Many': - tmap = f'Mapped[List[{pyname}]]' + tmap = f'orm.Mapped[List[{pyname}]]' # We want the cascade to delete orphans here as the intermediate # objects represent links and must not exist without source. - self.write(f'{name}: {tmap} = relationship(') + self.write(f'{name}: {tmap} = orm.relationship(') self.indent() self.write(f"back_populates='source',") self.write(f"cascade='all, delete-orphan',") @@ -411,26 +459,29 @@ def render_link(self, spec, mod, parent, modules): if cardinality == 'One': self.write( - f'{name}_id: Mapped[uuid.UUID] = ' - f'mapped_column(Uuid(), ' - f'{fk}, nullable={nullable})' - ) - self.write( - f'{name}: Mapped[{pyname}] = ' - f'relationship(back_populates={bklink!r})' - ) + f'{name}_id: orm.Mapped[uuid.UUID] = orm.mapped_column(') + self.indent() + self.write(f'sa.Uuid(), {fk}, nullable={nullable},') + self.dedent() + self.write(f')') + + self.write(f'{name}: orm.Mapped[{pyname}] = orm.relationship(') + self.indent() + self.write(f'back_populates={bklink!r},') + self.dedent() + self.write(f')') elif cardinality == 'Many': secondary = f'{parent}_{name}_table' + self.write( - f'{name}: Mapped[List[{pyname}]] = relationship(') + f'{name}: orm.Mapped[List[{pyname}]] = orm.relationship(') self.indent() - self.write( - f'{pyname}, secondary={secondary}, ' - f'back_populates={bklink!r},' - ) + self.write(f'{pyname},') + self.write(f'secondary={secondary},') + self.write(f'back_populates={bklink!r},') self.dedent() - self.write(')') + self.write(f')') def render_backlink(self, spec, mod, modules): name = spec['name'] @@ -449,12 +500,12 @@ def render_backlink(self, spec, mod, modules): pyname = self.get_py_name(tmod, target, mod) if cardinality == 'One': - tmap = f'Mapped[{pyname}]' + tmap = f'orm.Mapped[{pyname}]' elif cardinality == 'Many': - tmap = f'Mapped[List[{pyname}]]' + tmap = f'orm.Mapped[List[{pyname}]]' # We want the cascade to delete orphans here as the intermediate # objects represent links and must not exist without target. - self.write(f'{name}: {tmap} = relationship(') + self.write(f'{name}: {tmap} = orm.relationship(') self.indent() self.write(f"back_populates='target',") self.write(f"cascade='all, delete-orphan',") @@ -467,25 +518,29 @@ def render_backlink(self, spec, mod, modules): # This is a backlink from a single link. There is no link table # involved. if cardinality == 'One': - self.write( - f'{name}: Mapped[{pyname}] = ' - f'relationship(back_populates={bklink!r})' - ) + self.write(f'{name}: orm.Mapped[{pyname}] = \\') + self.indent() + self.write(f'orm.relationship(back_populates={bklink!r})') + self.dedent() + elif cardinality == 'Many': - self.write( - f'{name}: Mapped[List[{pyname}]] = ' - f'relationship(back_populates={bklink!r})' - ) + self.write(f'{name}: orm.Mapped[List[{pyname}]] = \\') + self.indent() + self.write(f'orm.relationship(back_populates={bklink!r})') + self.dedent() else: # This backlink involves a link table, so we still treat it as # a Many-to-Many. secondary = f'{target}_{bklink}_table' - self.write(f'{name}: Mapped[List[{pyname}]] = relationship(') + + self.write(f'{name}: orm.Mapped[List[{pyname}]] = \\') self.indent() - self.write( - f'{pyname}, secondary={secondary}, ' - f'back_populates={bklink!r},' - ) + self.write(f'orm.relationship(') + self.indent() + self.write(f'{pyname},') + self.write(f'secondary={secondary},') + self.write(f'back_populates={bklink!r},') self.dedent() self.write(')') + self.dedent() diff --git a/gel/orm/sqlmodel.py b/gel/orm/sqlmodel.py index f8deccce..8da16973 100644 --- a/gel/orm/sqlmodel.py +++ b/gel/orm/sqlmodel.py @@ -9,34 +9,52 @@ GEL_SCALAR_MAP = { - 'std::bool': 'bool', - 'std::str': 'str', - 'std::int16': 'int', - 'std::int32': 'int', - 'std::int64': 'int', - 'std::float32': 'float', - 'std::float64': 'float', - 'std::uuid': 'uuid.UUID', + 'std::bool': ('bool', None), + 'std::str': ('str', None), + 'std::int16': ('int', None), + 'std::int32': ('int', None), + 'std::int64': ('int', None), + 'std::float32': ('float', None), + 'std::float64': ('float', None), + 'std::uuid': ('uuid.UUID', None), + 'std::bytes': ('bytes', None), + 'std::cal::local_date': ('datetime.date', None), + 'std::cal::local_time': ('datetime.time', None), + 'std::cal::local_datetime': ('datetime.datetime', 'sa.DateTime()'), + 'std::datetime': ('datetime.datetime', 'sa.TIMESTAMP(timezone=True)'), } CLEAN_RE = re.compile(r'[^A-Za-z0-9]+') +NAME_RE = re.compile(r'^(?P\w+?)(?P\d*)$') COMMENT = '''\ # # Automatically generated from Gel schema. +# +# Do not edit directly as re-generating this file will overwrite any changes. #\ ''' MODELS_STUB = f'''\ {COMMENT} +import datetime import uuid -from sqlmodel import SQLModel, Field, Relationship -from sqlalchemy import Column, ForeignKey +import sqlmodel as sm +import sqlalchemy as sa ''' +def field_name_sort(spec): + key = spec['name'] + + match = NAME_RE.fullmatch(key) + res = (match.group('alpha'), int(match.group('num') or -1)) + + return res + + class ModelGenerator(FilePrinter): def __init__(self, *, outdir=None, basemodule=None): # set the output to be stdout by default, but this is generally @@ -105,9 +123,9 @@ def get_fk(self, mod, table, curmod): def get_sqla_fk(self, mod, table, curmod): if mod == curmod: # No need for anything fancy within the same schema - return f'ForeignKey("{table}.id")' + return f'sa.ForeignKey("{table}.id")' else: - return f'ForeignKey("{mod}.{table}.id")' + return f'sa.ForeignKey("{mod}.{table}.id")' def get_py_name(self, mod, name, curmod): if mod == curmod: @@ -182,10 +200,18 @@ def render_models(self, spec): # skip apparently empty modules return - for lobj in maps.get('link_objects', {}).values(): + link_objects = sorted( + maps.get('link_objects', {}).values(), + key=lambda x: x['name'] + ) + for lobj in link_objects: self.write() self.render_link_object(lobj, modules) + objects = sorted( + maps.get('object_types', {}).values(), + key=lambda x: x['name'] + ) for rec in maps.get('object_types', {}).values(): self.write() self.render_type(rec, modules) @@ -207,7 +233,7 @@ def render_link_table(self, spec): return self.write() - self.write(f'class {spec["name"]}(SQLModel, table=True):') + self.write(f'class {spec["name"]}(sm.SQLModel, table=True):') self.indent() self.write(f'__tablename__ = {spec["table"]!r}') if mod != 'default': @@ -218,8 +244,17 @@ def render_link_table(self, spec): self.write('__mapper_args__ = {"confirm_deleted_rows": False}') self.write() # source is in the same module as this table - self.write(f'source: uuid.UUID = Field({s_fk}, primary_key=True)') - self.write(f'target: uuid.UUID = Field({t_fk}, primary_key=True)') + self.write(f'source: uuid.UUID = sm.Field(') + self.indent() + self.write(f'{s_fk}, primary_key=True,') + self.dedent() + self.write(f')') + + self.write(f'target: uuid.UUID = sm.Field(') + self.indent() + self.write(f'{t_fk}, primary_key=True,') + self.dedent() + self.write(f')') self.dedent() def render_link_object(self, spec, modules): @@ -237,7 +272,7 @@ def render_link_object(self, spec, modules): return self.write() - self.write(f'class {name}(SQLModel, table=True):') + self.write(f'class {name}(sm.SQLModel, table=True):') self.indent() self.write(f'__tablename__ = {sql_name!r}') if mod != 'default': @@ -268,14 +303,15 @@ def render_link_object(self, spec, modules): fk = self.get_fk(tmod, target, mod) sqlafk = self.get_sqla_fk(tmod, target, mod) pyname = self.get_py_name(tmod, target, mod) - self.write(f'{lname}_id: uuid.UUID = Field(sa_column=Column(') + self.write( + f'{lname}_id: uuid.UUID = sm.Field(sa_column=sa.Column(') self.indent() self.write(f'{lname!r},') self.write(f'{sqlafk},') self.write(f'primary_key=True,') self.write(f'nullable=False,') self.dedent() - self.write('))') + self.write(f'))') if lname == 'source': bklink = source_link @@ -287,9 +323,11 @@ def render_link_object(self, spec, modules): ) self.write( - f'{lname}: {pyname} = ' - f'Relationship(back_populates={bklink!r})' - ) + f'{lname}: {pyname} = sm.Relationship(') + self.indent() + self.write(f'back_populates={bklink!r},') + self.dedent() + self.write(f')') if spec['properties']: self.write() @@ -315,7 +353,7 @@ def render_type(self, spec, modules): return self.write() - self.write(f'class {name}(SQLModel, table=True):') + self.write(f'class {name}(sm.SQLModel, table=True):') self.indent() self.write(f'__tablename__ = {sql_name!r}') if mod != 'default': @@ -327,10 +365,10 @@ def render_type(self, spec, modules): self.write() # Add two fields that all objects have - self.write(f'id: uuid.UUID | None = Field(') + self.write(f'id: uuid.UUID | None = sm.Field(') self.indent() - self.write( - f"default=None, primary_key=True,") + self.write(f"default=None,") + self.write(f"primary_key=True,") self.write( f"sa_column_kwargs=dict(server_default='uuid_generate_v4()'),") self.dedent() @@ -338,12 +376,11 @@ def render_type(self, spec, modules): # This is maintained entirely by Gel, the server_default simply # indicates to SQLAlchemy that this value may be omitted. - self.write(f'gel_type_id: uuid.UUID | None = Field(') + self.write(f'gel_type_id: uuid.UUID | None = sm.Field(') self.indent() + self.write(f"default=None,") self.write( - f"default=None,") - self.write( - f"sa_column=Column('__type__', server_default='PLACEHOLDER'),") + f"sa_column=sa.Column('__type__', server_default='PLACEHOLDER'),") self.dedent() self.write(')') @@ -351,7 +388,8 @@ def render_type(self, spec, modules): self.write() self.write('# Properties:') - for prop in spec['properties']: + properties = sorted(spec['properties'], key=field_name_sort) + for prop in properties: if prop['name'] != 'id': self.render_prop(prop, mod, name, modules) @@ -359,14 +397,16 @@ def render_type(self, spec, modules): self.write() self.write('# Links:') - for link in spec['links']: + links = sorted(spec['links'], key=field_name_sort) + for link in links: self.render_link(link, mod, name, modules) if spec['backlinks']: self.write() self.write('# Back-links:') - for link in spec['backlinks']: + backlinks = sorted(spec['backlinks'], key=field_name_sort) + for link in backlinks: self.render_backlink(link, mod, modules) self.dedent() @@ -378,7 +418,8 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): target = spec['target']['name'] try: - pytype = GEL_SCALAR_MAP[target] + pytype, sa_col = GEL_SCALAR_MAP[target] + except KeyError: warnings.warn( f'Scalar type {target} is not supported', @@ -387,21 +428,23 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): # Skip rendering this one return - if is_pk: - # special case of a primary key property (should only happen to - # 'target' in multi property table) - self.write( - f'{name}: {pytype} = Field(primary_key=True, nullable=False)' - ) - elif cardinality == 'Many': + if cardinality == 'Many': # skip it return else: # plain property - self.write( - f'{name}: {pytype} = Field(nullable={nullable})' - ) + if sa_col: + self.write(f'{name}: {pytype} = sm.Field(sa_column=sa.Column(') + self.indent() + self.write(f'{sa_col},') + self.write(f'nullable={nullable},') + self.dedent() + self.write(f'))') + else: + self.write( + f'{name}: {pytype} = sm.Field(nullable={nullable})' + ) def render_link(self, spec, mod, parent, modules): name = spec['name'] @@ -431,12 +474,12 @@ def render_link(self, spec, mod, parent, modules): if cardinality == 'One': self.write( f'{name}: {pyname} = ' - f"Relationship(back_populates='source')" + f"sm.Relationship(back_populates='source')" ) elif cardinality == 'Many': self.write( f'{name}: list[{pyname}] = ' - f"Relationship(back_populates='source')" + f"sm.Relationship(back_populates='source')" ) if cardinality == 'One': @@ -445,7 +488,7 @@ def render_link(self, spec, mod, parent, modules): tmap = f'list[{pyname}]' # We want the cascade to delete orphans here as the intermediate # objects represent links and must not exist without source. - self.write(f'{name}: {tmap} = Relationship(') + self.write(f'{name}: {tmap} = sm.Relationship(') self.indent() self.write(f"back_populates='source',") self.write(f"cascade_delete=True,") @@ -457,19 +500,22 @@ def render_link(self, spec, mod, parent, modules): pyname = self.get_py_name(tmod, target, mod) if cardinality == 'One': - self.write( - f'{name}_id: uuid.UUID = Field({fk}, nullable={nullable})' - ) - self.write( - f'{name}: {pyname} = ' - f'Relationship(back_populates={bklink!r})' - ) + self.write(f'{name}_id: uuid.UUID = sm.Field(') + self.indent() + self.write(f'{fk},') + self.write(f'nullable={nullable},') + self.dedent() + self.write(')') + + self.write(f'{name}: {pyname} = sm.Relationship(') + self.indent() + self.write(f'back_populates={bklink!r},') + self.dedent() + self.write(')') elif cardinality == 'Many': secondary = f'{parent}_{name}_table' - self.write( - f'{name}: list[{pyname}] = Relationship(' - ) + self.write(f'{name}: list[{pyname}] = sm.Relationship(') self.indent() self.write(f'back_populates={bklink!r},') self.write(f'link_model={secondary},') @@ -506,7 +552,7 @@ def render_backlink(self, spec, mod, modules): tmap = f'list[{pyname}]' # We want the cascade to delete orphans here as the intermediate # objects represent links and must not exist without target. - self.write(f'{name}: {tmap} = Relationship(') + self.write(f'{name}: {tmap} = sm.Relationship(') self.indent() self.write(f"back_populates='target',") self.write(f"cascade_delete=True,") @@ -519,22 +565,24 @@ def render_backlink(self, spec, mod, modules): # This is a backlink from a single link. There is no link table # involved. if cardinality == 'One': - self.write( - f'{name}: {pyname} = ' - f'Relationship(back_populates={bklink!r})' - ) + self.write(f'{name}: {pyname} = sm.Relationship(') + self.indent() + self.write(f"back_populates={bklink!r},") + self.dedent() + self.write(')') elif cardinality == 'Many': - self.write( - f'{name}: list[{pyname}] = ' - f'Relationship(back_populates={bklink!r})' - ) + self.write(f'{name}: list[{pyname}] = sm.Relationship(') + self.indent() + self.write(f"back_populates={bklink!r},") + self.dedent() + self.write(')') else: # This backlink involves a link table, so we still treat it as # a Many-to-Many. secondary = f'{target}_{bklink}_table' self.write( - f'{name}: list[{pyname}] = Relationship(' + f'{name}: list[{pyname}] = sm.Relationship(' ) self.indent() self.write(f'back_populates={bklink!r},') diff --git a/tests/dbsetup/base.edgeql b/tests/dbsetup/base.edgeql index cfbf46e8..582bb254 100644 --- a/tests/dbsetup/base.edgeql +++ b/tests/dbsetup/base.edgeql @@ -42,3 +42,13 @@ insert Post { author := assert_single((select User filter .name = 'Elsa')), body := '*magic stuff*', }; + +insert AssortedScalars { + name:= 'hello world', + vals := ['brown', 'fox'], + bstr := b'word\x00\x0b', + time := '20:13:45.678', + date:= '2025-01-26', + ts:='2025-01-26T20:13:45+00:00', + lts:='2025-01-26T20:13:45', +}; \ No newline at end of file diff --git a/tests/dbsetup/base.esdl b/tests/dbsetup/base.esdl index 4a5c02c0..735749b1 100644 --- a/tests/dbsetup/base.esdl +++ b/tests/dbsetup/base.esdl @@ -15,9 +15,23 @@ type GameSession { }; } -type User extending Named; +type User extending Named { + # test computed backlink + groups := .; + + date: cal::local_date; + time: cal::local_time; + ts: datetime; + lts: cal::local_datetime; + bstr: bytes; +} \ No newline at end of file diff --git a/tests/dbsetup/sqlmodel.edgeql b/tests/dbsetup/sqlmodel.edgeql index 20f699aa..2d7bee80 100644 --- a/tests/dbsetup/sqlmodel.edgeql +++ b/tests/dbsetup/sqlmodel.edgeql @@ -58,4 +58,14 @@ set { update HasLinkPropsB set { children += (select Child{@b := 'world'} filter .num = 1) +}; + +insert AssortedScalars { + name:= 'hello world', + vals := ['brown', 'fox'], + bstr := b'word\x00\x0b', + time := '20:13:45.678', + date:= '2025-01-26', + ts:='2025-01-26T20:13:45+00:00', + lts:='2025-01-26T20:13:45', }; \ No newline at end of file diff --git a/tests/dbsetup/sqlmodel.esdl b/tests/dbsetup/sqlmodel.esdl index 8e1347f4..d456dd81 100644 --- a/tests/dbsetup/sqlmodel.esdl +++ b/tests/dbsetup/sqlmodel.esdl @@ -38,4 +38,15 @@ type HasLinkPropsB { multi link children: Child { property b: str; } +} + +type AssortedScalars { + required name: str; + vals: array; + + date: cal::local_date; + time: cal::local_time; + ts: datetime; + lts: cal::local_datetime; + bstr: bytes; } \ No newline at end of file diff --git a/tests/test_django_basic.py b/tests/test_django_basic.py index 069f9c8f..488e0f4c 100644 --- a/tests/test_django_basic.py +++ b/tests/test_django_basic.py @@ -16,6 +16,7 @@ # limitations under the License. # +import datetime as dt import os import uuid import unittest @@ -299,6 +300,28 @@ def test_django_read_models_07(self): } ) + def test_django_read_models_08(self): + # test arrays, bytes and various date/time scalars + + res = self.m.AssortedScalars.objects.all()[0] + + self.assertEqual(res.name, 'hello world') + self.assertEqual(res.vals, ['brown', 'fox']) + self.assertEqual(bytes(res.bstr), b'word\x00\x0b') + self.assertEqual( + res.time, + dt.time(20, 13, 45, 678_000), + ) + self.assertEqual( + res.date, + dt.date(2025, 1, 26), + ) + # time zone aware (default for Django) + self.assertEqual( + res.ts, + dt.datetime.fromisoformat('2025-01-26T20:13:45+00:00'), + ) + def test_django_create_models_01(self): vals = self.m.User.objects.filter(name='Yvonne').all() self.assertEqual(list(vals), []) @@ -492,3 +515,39 @@ def test_django_update_models_04(self): post = self.m.Post.objects.get(id=post_id) self.assertEqual(post.author.name, 'Zoe') + + def test_django_update_models_05(self): + # test arrays, bytes and various date/time scalars + # + # For the purpose of sending data creating and updating a model are + # both testing accurate data transfer. + + res = self.m.AssortedScalars.objects.all()[0] + + res.name = 'New Name' + res.vals.append('jumped') + res.bstr = b'\x01success\x02' + res.time = dt.time(8, 23, 54, 999_000) + res.date = dt.date(2020, 2, 14) + res.ts = res.ts - dt.timedelta(days=6) + + res.save() + + upd = self.m.AssortedScalars.objects.all()[0] + + self.assertEqual(upd.name, 'New Name') + self.assertEqual(upd.vals, ['brown', 'fox', 'jumped']) + self.assertEqual(bytes(upd.bstr), b'\x01success\x02') + self.assertEqual( + upd.time, + dt.time(8, 23, 54, 999_000), + ) + self.assertEqual( + upd.date, + dt.date(2020, 2, 14), + ) + # time zone aware (default for Django) + self.assertEqual( + upd.ts, + dt.datetime.fromisoformat('2025-01-20T20:13:45+00:00'), + ) diff --git a/tests/test_sqla_basic.py b/tests/test_sqla_basic.py index 35934a6b..00394561 100644 --- a/tests/test_sqla_basic.py +++ b/tests/test_sqla_basic.py @@ -16,6 +16,7 @@ # limitations under the License. # +import datetime as dt import os import uuid import unittest @@ -344,6 +345,33 @@ def test_sqla_read_models_07(self): } ) + def test_sqla_read_models_08(self): + # test arrays, bytes and various date/time scalars + + res = self.sess.query(self.sm.AssortedScalars).one() + + self.assertEqual(res.name, 'hello world') + self.assertEqual(res.vals, ['brown', 'fox']) + self.assertEqual(res.bstr, b'word\x00\x0b') + self.assertEqual( + res.time, + dt.time(20, 13, 45, 678_000), + ) + self.assertEqual( + res.date, + dt.date(2025, 1, 26), + ) + # time zone aware + self.assertEqual( + res.ts, + dt.datetime.fromisoformat('2025-01-26T20:13:45+00:00'), + ) + # naive datetime + self.assertEqual( + res.lts, + dt.datetime.fromisoformat('2025-01-26T20:13:45'), + ) + def test_sqla_create_models_01(self): vals = self.sess.query(self.sm.User).filter_by(name='Yvonne').all() self.assertEqual(list(vals), []) @@ -578,3 +606,46 @@ def test_sqla_update_models_04(self): post = self.sess.get(self.sm.Post, post_id) self.assertEqual(post.author.name, 'Zoe') + + def test_sqla_update_models_05(self): + # test arrays, bytes and various date/time scalars + # + # For the purpose of sending data creating and updating a model are + # both testing accurate data transfer. + + res = self.sess.query(self.sm.AssortedScalars).one() + + res.name = 'New Name' + res.vals.append('jumped') + res.bstr = b'\x01success\x02' + res.time = dt.time(8, 23, 54, 999_000) + res.date = dt.date(2020, 2, 14) + res.ts = res.ts - dt.timedelta(days=6) + res.lts = res.lts + dt.timedelta(days=6) + + self.sess.add(res) + self.sess.flush() + + upd = self.sess.query(self.sm.AssortedScalars).one() + + self.assertEqual(upd.name, 'New Name') + self.assertEqual(upd.vals, ['brown', 'fox', 'jumped']) + self.assertEqual(upd.bstr, b'\x01success\x02') + self.assertEqual( + upd.time, + dt.time(8, 23, 54, 999_000), + ) + self.assertEqual( + upd.date, + dt.date(2020, 2, 14), + ) + # time zone aware + self.assertEqual( + upd.ts, + dt.datetime.fromisoformat('2025-01-20T20:13:45+00:00'), + ) + # naive datetime + self.assertEqual( + upd.lts, + dt.datetime.fromisoformat('2025-02-01T20:13:45'), + ) diff --git a/tests/test_sqlmodel_basic.py b/tests/test_sqlmodel_basic.py index ba8a1fc5..9a201535 100644 --- a/tests/test_sqlmodel_basic.py +++ b/tests/test_sqlmodel_basic.py @@ -16,6 +16,7 @@ # limitations under the License. # +import datetime as dt import os import uuid import unittest @@ -323,6 +324,34 @@ def test_sqlmodel_read_models_07(self): } ) + def test_sqlmodel_read_models_08(self): + # test arrays, bytes and various date/time scalars + + res = self.sess.exec( + select(self.sm.AssortedScalars) + ).one() + + self.assertEqual(res.name, 'hello world') + self.assertEqual(res.bstr, b'word\x00\x0b') + self.assertEqual( + res.time, + dt.time(20, 13, 45, 678_000), + ) + self.assertEqual( + res.date, + dt.date(2025, 1, 26), + ) + # time zone aware + self.assertEqual( + res.ts, + dt.datetime.fromisoformat('2025-01-26T20:13:45+00:00'), + ) + # naive datetime + self.assertEqual( + res.lts, + dt.datetime.fromisoformat('2025-01-26T20:13:45'), + ) + def test_sqlmodel_create_models_01(self): vals = self.sess.exec( select(self.sm.User).where( @@ -591,6 +620,52 @@ def test_sqlmodel_update_models_04(self): post = self.sess.get(self.sm.Post, post_id) self.assertEqual(post.author.name, 'Zoe') + def test_sqlmodel_update_models_05(self): + # test arrays, bytes and various date/time scalars + # + # For the purpose of sending data creating and updating a model are + # both testing accurate data transfer. + + res = self.sess.exec( + select(self.sm.AssortedScalars) + ).one() + + res.name = 'New Name' + # res.vals.append('jumped') + res.bstr = b'\x01success\x02' + res.time = dt.time(8, 23, 54, 999_000) + res.date = dt.date(2020, 2, 14) + res.ts = res.ts - dt.timedelta(days=6) + res.lts = res.lts + dt.timedelta(days=6) + + self.sess.add(res) + self.sess.flush() + + upd = self.sess.exec( + select(self.sm.AssortedScalars) + ).one() + + self.assertEqual(upd.name, 'New Name') + self.assertEqual(upd.bstr, b'\x01success\x02') + self.assertEqual( + upd.time, + dt.time(8, 23, 54, 999_000), + ) + self.assertEqual( + upd.date, + dt.date(2020, 2, 14), + ) + # time zone aware + self.assertEqual( + upd.ts, + dt.datetime.fromisoformat('2025-01-20T20:13:45+00:00'), + ) + # naive datetime + self.assertEqual( + upd.lts, + dt.datetime.fromisoformat('2025-02-01T20:13:45'), + ) + def test_sqlmodel_linkprops_01(self): val = self.sess.exec(select(self.sm.HasLinkPropsA)).one() self.assertEqual(val.child.target.num, 0)