Skip to content

Commit

Permalink
Missing custom sql in generated schema.sql (#1618)
Browse files Browse the repository at this point in the history
  • Loading branch information
lolopinto authored Aug 22, 2023
1 parent 8583c5a commit 55048d7
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 43 deletions.
14 changes: 14 additions & 0 deletions examples/simple/src/schema/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,17 @@ INSERT INTO assoc_edge_config(edge_name, edge_type, edge_table, symmetric_edge,
('UserToMaybeEventsEdge', '8d5b1dee-ce65-452e-9f8d-78eca1993800', 'event_rsvps_edges', false, 'b0f6311b-fdab-4c26-b6bf-b751e0997735', now() AT TIME ZONE 'UTC', now() AT TIME ZONE 'UTC'),
('UserToSelfContactEdge', 'd504201d-cf3f-4eef-b6a0-0b46a7ae186b', 'user_self_contact_edges', false, NULL, now() AT TIME ZONE 'UTC', now() AT TIME ZONE 'UTC') ON CONFLICT DO NOTHING;

-- custom sql for rev 63ec20382c27
CREATE OR REPLACE FUNCTION users_notify()
RETURNS trigger AS
$$
BEGIN
PERFORM pg_notify('users_created', NEW.id::text);
RETURN NEW;
END;
$$ LANGUAGE 'plpgsql';

CREATE OR REPLACE TRIGGER users_created BEFORE INSERT OR UPDATE
ON users
FOR EACH ROW EXECUTE PROCEDURE users_notify();;

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions internal/db/db_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,11 @@ func checkDBFilesInfo(cfg *codegen.Config) *dbFileInfo {
}
}

func compareDbFilesInfo(before, after *dbFileInfo) error {
func compareDbFilesInfo(processor *codegen.Processor, before, after *dbFileInfo) error {
// TODO we need a different way to handle this
if processor.Config.WriteAllFiles() {
return nil
}
// nothing to do here
if !before.useVersionsInfo || !after.useVersionsInfo {
return nil
Expand All @@ -676,6 +680,7 @@ func compareDbFilesInfo(before, after *dbFileInfo) error {
// TODO does this account for future formatting differences???

// time is not enough. we need to check the contents of the files
// should this be a prompt????
if after.schemaPy.exists && before.schemaPy.exists &&
string(before.schemaPy.contents) != string(after.schemaPy.contents) {
return fmt.Errorf("schema.py changed when no version files changed. there's an ent db error. you should file a bug report about what you were trying to do. it's probably unsupported")
Expand Down Expand Up @@ -717,7 +722,7 @@ func (s *dbSchema) makeDBChanges(processor *codegen.Processor) error {

after := checkDBFilesInfo(cfg)

return compareDbFilesInfo(s.before, after)
return compareDbFilesInfo(processor, s.before, after)
}

func UpgradeDB(cfg *codegen.Config, revision string, sql bool) error {
Expand Down
6 changes: 6 additions & 0 deletions python/auto_schema/auto_schema/clearable_string_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import io

class ClearableStringIO(io.StringIO):
def clear(self):
self.seek(0)
self.truncate(0)
20 changes: 9 additions & 11 deletions python/auto_schema/auto_schema/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,15 @@ def revision(self, message, autogenerate=True):
return command.revision(self.alembic_cfg, message,
autogenerate=autogenerate, head=head)

def get_heads(self):
script = ScriptDirectory.from_config(self.alembic_cfg)

return script.get_heads()
def get_script_directory(self) -> ScriptDirectory:
return ScriptDirectory.from_config(self.alembic_cfg)

def get_heads(self):
return self.get_script_directory().get_heads()

def get_revisions(self, revs):
script = ScriptDirectory.from_config(self.alembic_cfg)
return script.get_revisions(revs)
return self.get_script_directory().get_revisions(revs)

# Simulates running the `alembic upgrade` command

Expand Down Expand Up @@ -94,8 +95,7 @@ def downgrade(self, revision='', delete_files=True):
os.remove(os.path.join(location, path))

def _get_paths_to_delete(self, revision):
script = ScriptDirectory.from_config(self.alembic_cfg)
revs = list(script.revision_map.iterate_revisions(
revs = list(self.get_script_directory().revision_map.iterate_revisions(
self.get_heads(), revision, select_for_downgrade=True
))

Expand All @@ -116,17 +116,15 @@ def _get_paths_to_delete(self, revision):
return result

def get_history(self):
script = ScriptDirectory.from_config(self.alembic_cfg)
return list(script.walk_revisions())
return list(self.get_script_directory().walk_revisions())

# Simulates running the `alembic history` command
def history(self, verbose=False, last=None, rev_range=None):
if rev_range is not None and last is not None:
raise ValueError(
"cannot pass both last and rev_range. please pick one")
if last is not None:
script = ScriptDirectory.from_config(self.alembic_cfg)
revs = list(script.revision_map.iterate_revisions(
revs = list(self.get_script_directory().revision_map.iterate_revisions(
self.get_heads(), '-%d' % int(last), select_for_downgrade=True
))
rev_range = '%s:current' % revs[-1].revision
Expand Down
2 changes: 1 addition & 1 deletion python/auto_schema/auto_schema/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def render_drop_full_text_index(autogen_context: AutogenContext, op: ops.DropFul
@renderers.dispatch_for(ops.ExecuteSQL)
def render_execute_sql(autogen_context: AutogenContext, op: ops.ExecuteSQL) -> str:
return (
"op.execute_sql('%(sql)s')" % {
'op.execute_sql("""%(sql)s""")' % {
"sql": op.sql,
}
)
107 changes: 81 additions & 26 deletions python/auto_schema/auto_schema/runner.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from argparse import Namespace
import io
import json
import sys
from collections.abc import Mapping
from alembic.operations import Operations
from alembic.util.langhelpers import Dispatcher
from uuid import UUID
from functools import wraps

from .diff import Diff
from .clause_text import get_clause_text
from .clearable_string_io import ClearableStringIO
import sqlalchemy as sa
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.engine.url import make_url

from alembic.migration import MigrationContext
from alembic.autogenerate import produce_migrations
from alembic.autogenerate import render_python_code
from alembic.migration import MigrationContext
from alembic import context
from alembic.util.exc import CommandError
from alembic.script import ScriptDirectory

from sqlalchemy.dialects import postgresql
import alembic.operations.ops as alembicops
from alembic.operations import Operations
from typing import Optional, Dict
from typing import Optional, Dict, Any

from . import command
from . import config
Expand Down Expand Up @@ -48,16 +53,20 @@ def __init__(self, metadata, engine, connection, schema_path, args: Optional[Dic

self.mc = MigrationContext.configure(
connection=self.connection,
# note that any change here also needs a comparable change in env.py
opts={
opts=Runner.get_opts(),
)
self.cmd = command.Command(self.connection, self.schema_path)

@classmethod
def get_opts(cls):
# note that any change here also needs a comparable change in env.py
return {
"compare_type": Runner.compare_type,
"include_object": Runner.include_object,
"compare_server_default": Runner.compare_server_default,
"transaction_per_migration": True,
"render_item": Runner.render_item,
},
)
self.cmd = command.Command(self.connection, self.schema_path)
}

@classmethod
def from_command_line(cls, metadata, args: Namespace):
Expand Down Expand Up @@ -294,6 +303,58 @@ def merge(self, revisions, message=None):
def squash(self, squash):
self.cmd.squash(self.revision, squash)


def _get_custom_sql(self, connection, dialect) -> io.StringIO:
script_directory = self.cmd.get_script_directory()
revs = script_directory.walk_revisions()

# this is cleared after each upgrade
temp_buffer = ClearableStringIO()

opts = Runner.get_opts()
opts['as_sql'] = True
opts['output_buffer'] = temp_buffer
mc = MigrationContext.configure(
connection=connection,
dialect_name=dialect,
opts=opts,
)

custom_sql_buffer = io.StringIO()

# monkey patch the Dispatcher.dispatch method to know what's being changed/dispatched
# for each upgrade path, we'll know what the last object was and can make decisions based on that

last_obj = None

def my_decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
nonlocal last_obj
last_obj = args[0]
return func(self, *args, **kwargs)
return wrapper

Dispatcher.dispatch = my_decorator(Dispatcher.dispatch)

# order is flipped, it goes from most recent to oldest
# we want to go from oldest -> most recent
revs = list(revs)
revs.reverse()

with Operations.context(mc):
for rev in revs:
# run upgrade(), we capture what's being changed via the dispatcher and see if it's custom sql
rev.module.upgrade()

if isinstance(last_obj, ops.ExecuteSQL) or isinstance(last_obj, alembicops.ExecuteSQLOp):
custom_sql_buffer.write("-- custom sql for rev %s\n" % rev.revision)
custom_sql_buffer.write(temp_buffer.getvalue())

temp_buffer.clear()

return custom_sql_buffer

# doesn't invoke env.py. completely different flow
# progressive_sql and upgrade range do go through offline path
def all_sql(self, file=None, database=''):
Expand All @@ -318,34 +379,22 @@ def all_sql(self, file=None, database=''):
mc = MigrationContext.configure(
connection=connection,
dialect_name=dialect,
# note that any change here also needs a comparable change in env.py
opts={
"compare_type": Runner.compare_type,
"include_object": Runner.include_object,
"compare_server_default": Runner.compare_server_default,
"transaction_per_migration": True,
"render_item": Runner.render_item,
},
opts=Runner.get_opts(),
)
migrations = produce_migrations(mc, self.metadata)

# default is stdout so let's use it
buffer = sys.stdout
if file is not None:
buffer = open(file, 'w')

# use different migrations context with as_sql so that we don't have issues
opts = Runner.get_opts()
opts['as_sql'] = True
opts['output_buffer'] = buffer
mc2 = MigrationContext.configure(
connection=connection,
# note that any change here also needs a comparable change in env.py
opts={
"compare_type": Runner.compare_type,
"include_object": Runner.include_object,
"compare_server_default": Runner.compare_server_default,
"render_item": Runner.render_item,
"as_sql": True,
"output_buffer": buffer,
},
opts=opts
)

# let's do a consistent (not runtime dependent) sort of constraints by using name instead of _creation_order
Expand All @@ -369,6 +418,12 @@ def invoke(op):

for op in migrations.upgrade_ops.ops:
invoke(op)

custom_sql_buffer = self._get_custom_sql(connection, dialect)

# add custom sql at the end
buffer.write(custom_sql_buffer.getvalue())


def progressive_sql(self, file=None):
if file is not None:
Expand Down
5 changes: 3 additions & 2 deletions python/auto_schema/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
# https://pypi.org/project/auto-schema/#history
# https://test.pypi.org/project/auto-schema-test/#history
setuptools.setup(
name="auto_schema", # auto_schema_test to test
version="0.0.29", # 0.0.24 was last test version
name="auto_schema_test", # auto_schema_test to test
# 0.0.29 production
version="0.0.26", # 0.0.26 was last test version
author="Ola Okelola",
author_email="[email protected]",
description="auto schema for a db",
Expand Down
1 change: 1 addition & 0 deletions python/auto_schema/tests/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,7 @@ def test_custom_sql(self, new_test_runner, metadata_with_table):
assert upgrade_start != -1
assert downgrade_start != -1

# TODO render_python_code in alembic could be helpful?
# "edit the file " to add types
new_upgrade = """def upgrade():
op.execute_sql("CREATE TYPE rainbow as ENUM ('red', 'orange', 'yellow', 'green', 'blue', 'indigo', 'violet')")
Expand Down

0 comments on commit 55048d7

Please sign in to comment.