From 55048d7175361443e81d8f6356610ee8aec2a407 Mon Sep 17 00:00:00 2001 From: Ola Okelola <10857143+lolopinto@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:20:41 -0700 Subject: [PATCH] Missing custom sql in generated schema.sql (#1618) --- examples/simple/src/schema/schema.sql | 14 +++ .../63ec20382c27_2023720223223_ggsgsf.py | 2 +- internal/db/db_schema.go | 9 +- .../auto_schema/clearable_string_io.py | 6 + python/auto_schema/auto_schema/command.py | 20 ++-- python/auto_schema/auto_schema/renderers.py | 2 +- python/auto_schema/auto_schema/runner.py | 107 +++++++++++++----- python/auto_schema/setup.py | 5 +- python/auto_schema/tests/runner_test.py | 1 + 9 files changed, 123 insertions(+), 43 deletions(-) create mode 100644 python/auto_schema/auto_schema/clearable_string_io.py diff --git a/examples/simple/src/schema/schema.sql b/examples/simple/src/schema/schema.sql index 4cde6ddea..413217cbe 100644 --- a/examples/simple/src/schema/schema.sql +++ b/examples/simple/src/schema/schema.sql @@ -319,3 +319,17 @@ INSERT INTO assoc_edge_config(edge_name, edge_type, edge_table, symmetric_edge, ('UserToMaybeEventsEdge', '8d5b1dee-ce65-452e-9f8d-78eca1993800', 'event_rsvps_edges', false, 'b0f6311b-fdab-4c26-b6bf-b751e0997735', now() AT TIME ZONE 'UTC', now() AT TIME ZONE 'UTC'), ('UserToSelfContactEdge', 'd504201d-cf3f-4eef-b6a0-0b46a7ae186b', 'user_self_contact_edges', false, NULL, now() AT TIME ZONE 'UTC', now() AT TIME ZONE 'UTC') ON CONFLICT DO NOTHING; +-- custom sql for rev 63ec20382c27 +CREATE OR REPLACE FUNCTION users_notify() +RETURNS trigger AS +$$ +BEGIN + PERFORM pg_notify('users_created', NEW.id::text); + RETURN NEW; +END; +$$ LANGUAGE 'plpgsql'; + +CREATE OR REPLACE TRIGGER users_created BEFORE INSERT OR UPDATE + ON users + FOR EACH ROW EXECUTE PROCEDURE users_notify();; + diff --git a/examples/simple/src/schema/versions/63ec20382c27_2023720223223_ggsgsf.py b/examples/simple/src/schema/versions/63ec20382c27_2023720223223_ggsgsf.py index 3a7796266..a2a989152 100644 --- a/examples/simple/src/schema/versions/63ec20382c27_2023720223223_ggsgsf.py +++ b/examples/simple/src/schema/versions/63ec20382c27_2023720223223_ggsgsf.py @@ -27,7 +27,7 @@ def upgrade(): RETURNS trigger AS $$ BEGIN - PERFORM pg_notify(’users_created’, NEW.id::text); + PERFORM pg_notify('users_created', NEW.id::text); RETURN NEW; END; $$ LANGUAGE 'plpgsql'; diff --git a/internal/db/db_schema.go b/internal/db/db_schema.go index 4f3c2bd91..35ea513a3 100644 --- a/internal/db/db_schema.go +++ b/internal/db/db_schema.go @@ -663,7 +663,11 @@ func checkDBFilesInfo(cfg *codegen.Config) *dbFileInfo { } } -func compareDbFilesInfo(before, after *dbFileInfo) error { +func compareDbFilesInfo(processor *codegen.Processor, before, after *dbFileInfo) error { + // TODO we need a different way to handle this + if processor.Config.WriteAllFiles() { + return nil + } // nothing to do here if !before.useVersionsInfo || !after.useVersionsInfo { return nil @@ -676,6 +680,7 @@ func compareDbFilesInfo(before, after *dbFileInfo) error { // TODO does this account for future formatting differences??? // time is not enough. we need to check the contents of the files + // should this be a prompt???? if after.schemaPy.exists && before.schemaPy.exists && string(before.schemaPy.contents) != string(after.schemaPy.contents) { return fmt.Errorf("schema.py changed when no version files changed. there's an ent db error. you should file a bug report about what you were trying to do. it's probably unsupported") @@ -717,7 +722,7 @@ func (s *dbSchema) makeDBChanges(processor *codegen.Processor) error { after := checkDBFilesInfo(cfg) - return compareDbFilesInfo(s.before, after) + return compareDbFilesInfo(processor, s.before, after) } func UpgradeDB(cfg *codegen.Config, revision string, sql bool) error { diff --git a/python/auto_schema/auto_schema/clearable_string_io.py b/python/auto_schema/auto_schema/clearable_string_io.py new file mode 100644 index 000000000..986286c9b --- /dev/null +++ b/python/auto_schema/auto_schema/clearable_string_io.py @@ -0,0 +1,6 @@ +import io + +class ClearableStringIO(io.StringIO): + def clear(self): + self.seek(0) + self.truncate(0) \ No newline at end of file diff --git a/python/auto_schema/auto_schema/command.py b/python/auto_schema/auto_schema/command.py index d4981fd4a..f0d9e46c5 100644 --- a/python/auto_schema/auto_schema/command.py +++ b/python/auto_schema/auto_schema/command.py @@ -56,14 +56,15 @@ def revision(self, message, autogenerate=True): return command.revision(self.alembic_cfg, message, autogenerate=autogenerate, head=head) - def get_heads(self): - script = ScriptDirectory.from_config(self.alembic_cfg) - return script.get_heads() + def get_script_directory(self) -> ScriptDirectory: + return ScriptDirectory.from_config(self.alembic_cfg) + + def get_heads(self): + return self.get_script_directory().get_heads() def get_revisions(self, revs): - script = ScriptDirectory.from_config(self.alembic_cfg) - return script.get_revisions(revs) + return self.get_script_directory().get_revisions(revs) # Simulates running the `alembic upgrade` command @@ -94,8 +95,7 @@ def downgrade(self, revision='', delete_files=True): os.remove(os.path.join(location, path)) def _get_paths_to_delete(self, revision): - script = ScriptDirectory.from_config(self.alembic_cfg) - revs = list(script.revision_map.iterate_revisions( + revs = list(self.get_script_directory().revision_map.iterate_revisions( self.get_heads(), revision, select_for_downgrade=True )) @@ -116,8 +116,7 @@ def _get_paths_to_delete(self, revision): return result def get_history(self): - script = ScriptDirectory.from_config(self.alembic_cfg) - return list(script.walk_revisions()) + return list(self.get_script_directory().walk_revisions()) # Simulates running the `alembic history` command def history(self, verbose=False, last=None, rev_range=None): @@ -125,8 +124,7 @@ def history(self, verbose=False, last=None, rev_range=None): raise ValueError( "cannot pass both last and rev_range. please pick one") if last is not None: - script = ScriptDirectory.from_config(self.alembic_cfg) - revs = list(script.revision_map.iterate_revisions( + revs = list(self.get_script_directory().revision_map.iterate_revisions( self.get_heads(), '-%d' % int(last), select_for_downgrade=True )) rev_range = '%s:current' % revs[-1].revision diff --git a/python/auto_schema/auto_schema/renderers.py b/python/auto_schema/auto_schema/renderers.py index 9b6d2d2fb..508e95e30 100644 --- a/python/auto_schema/auto_schema/renderers.py +++ b/python/auto_schema/auto_schema/renderers.py @@ -205,7 +205,7 @@ def render_drop_full_text_index(autogen_context: AutogenContext, op: ops.DropFul @renderers.dispatch_for(ops.ExecuteSQL) def render_execute_sql(autogen_context: AutogenContext, op: ops.ExecuteSQL) -> str: return ( - "op.execute_sql('%(sql)s')" % { + 'op.execute_sql("""%(sql)s""")' % { "sql": op.sql, } ) diff --git a/python/auto_schema/auto_schema/runner.py b/python/auto_schema/auto_schema/runner.py index bf78b7956..53113ef29 100644 --- a/python/auto_schema/auto_schema/runner.py +++ b/python/auto_schema/auto_schema/runner.py @@ -1,25 +1,30 @@ from argparse import Namespace +import io import json import sys from collections.abc import Mapping from alembic.operations import Operations +from alembic.util.langhelpers import Dispatcher from uuid import UUID +from functools import wraps from .diff import Diff from .clause_text import get_clause_text +from .clearable_string_io import ClearableStringIO import sqlalchemy as sa from sqlalchemy.sql.elements import TextClause from sqlalchemy.engine.url import make_url -from alembic.migration import MigrationContext from alembic.autogenerate import produce_migrations from alembic.autogenerate import render_python_code +from alembic.migration import MigrationContext +from alembic import context from alembic.util.exc import CommandError +from alembic.script import ScriptDirectory from sqlalchemy.dialects import postgresql import alembic.operations.ops as alembicops -from alembic.operations import Operations -from typing import Optional, Dict +from typing import Optional, Dict, Any from . import command from . import config @@ -48,16 +53,20 @@ def __init__(self, metadata, engine, connection, schema_path, args: Optional[Dic self.mc = MigrationContext.configure( connection=self.connection, - # note that any change here also needs a comparable change in env.py - opts={ + opts=Runner.get_opts(), + ) + self.cmd = command.Command(self.connection, self.schema_path) + + @classmethod + def get_opts(cls): + # note that any change here also needs a comparable change in env.py + return { "compare_type": Runner.compare_type, "include_object": Runner.include_object, "compare_server_default": Runner.compare_server_default, "transaction_per_migration": True, "render_item": Runner.render_item, - }, - ) - self.cmd = command.Command(self.connection, self.schema_path) + } @classmethod def from_command_line(cls, metadata, args: Namespace): @@ -294,6 +303,58 @@ def merge(self, revisions, message=None): def squash(self, squash): self.cmd.squash(self.revision, squash) + + def _get_custom_sql(self, connection, dialect) -> io.StringIO: + script_directory = self.cmd.get_script_directory() + revs = script_directory.walk_revisions() + + # this is cleared after each upgrade + temp_buffer = ClearableStringIO() + + opts = Runner.get_opts() + opts['as_sql'] = True + opts['output_buffer'] = temp_buffer + mc = MigrationContext.configure( + connection=connection, + dialect_name=dialect, + opts=opts, + ) + + custom_sql_buffer = io.StringIO() + + # monkey patch the Dispatcher.dispatch method to know what's being changed/dispatched + # for each upgrade path, we'll know what the last object was and can make decisions based on that + + last_obj = None + + def my_decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + nonlocal last_obj + last_obj = args[0] + return func(self, *args, **kwargs) + return wrapper + + Dispatcher.dispatch = my_decorator(Dispatcher.dispatch) + + # order is flipped, it goes from most recent to oldest + # we want to go from oldest -> most recent + revs = list(revs) + revs.reverse() + + with Operations.context(mc): + for rev in revs: + # run upgrade(), we capture what's being changed via the dispatcher and see if it's custom sql + rev.module.upgrade() + + if isinstance(last_obj, ops.ExecuteSQL) or isinstance(last_obj, alembicops.ExecuteSQLOp): + custom_sql_buffer.write("-- custom sql for rev %s\n" % rev.revision) + custom_sql_buffer.write(temp_buffer.getvalue()) + + temp_buffer.clear() + + return custom_sql_buffer + # doesn't invoke env.py. completely different flow # progressive_sql and upgrade range do go through offline path def all_sql(self, file=None, database=''): @@ -318,34 +379,22 @@ def all_sql(self, file=None, database=''): mc = MigrationContext.configure( connection=connection, dialect_name=dialect, - # note that any change here also needs a comparable change in env.py - opts={ - "compare_type": Runner.compare_type, - "include_object": Runner.include_object, - "compare_server_default": Runner.compare_server_default, - "transaction_per_migration": True, - "render_item": Runner.render_item, - }, + opts=Runner.get_opts(), ) migrations = produce_migrations(mc, self.metadata) - + # default is stdout so let's use it buffer = sys.stdout if file is not None: buffer = open(file, 'w') # use different migrations context with as_sql so that we don't have issues + opts = Runner.get_opts() + opts['as_sql'] = True + opts['output_buffer'] = buffer mc2 = MigrationContext.configure( connection=connection, - # note that any change here also needs a comparable change in env.py - opts={ - "compare_type": Runner.compare_type, - "include_object": Runner.include_object, - "compare_server_default": Runner.compare_server_default, - "render_item": Runner.render_item, - "as_sql": True, - "output_buffer": buffer, - }, + opts=opts ) # let's do a consistent (not runtime dependent) sort of constraints by using name instead of _creation_order @@ -369,6 +418,12 @@ def invoke(op): for op in migrations.upgrade_ops.ops: invoke(op) + + custom_sql_buffer = self._get_custom_sql(connection, dialect) + + # add custom sql at the end + buffer.write(custom_sql_buffer.getvalue()) + def progressive_sql(self, file=None): if file is not None: diff --git a/python/auto_schema/setup.py b/python/auto_schema/setup.py index c7da64589..b27a8f7c9 100644 --- a/python/auto_schema/setup.py +++ b/python/auto_schema/setup.py @@ -6,8 +6,9 @@ # https://pypi.org/project/auto-schema/#history # https://test.pypi.org/project/auto-schema-test/#history setuptools.setup( - name="auto_schema", # auto_schema_test to test - version="0.0.29", # 0.0.24 was last test version + name="auto_schema_test", # auto_schema_test to test + # 0.0.29 production + version="0.0.26", # 0.0.26 was last test version author="Ola Okelola", author_email="email@email.com", description="auto schema for a db", diff --git a/python/auto_schema/tests/runner_test.py b/python/auto_schema/tests/runner_test.py index 14e59246a..a9d1e3590 100644 --- a/python/auto_schema/tests/runner_test.py +++ b/python/auto_schema/tests/runner_test.py @@ -1638,6 +1638,7 @@ def test_custom_sql(self, new_test_runner, metadata_with_table): assert upgrade_start != -1 assert downgrade_start != -1 + # TODO render_python_code in alembic could be helpful? # "edit the file " to add types new_upgrade = """def upgrade(): op.execute_sql("CREATE TYPE rainbow as ENUM ('red', 'orange', 'yellow', 'green', 'blue', 'indigo', 'violet')")