Skip to content

Commit

Permalink
dont relay on exisitng django migration for the command to run (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronenlu authored Jan 23, 2024
1 parent 2d4dde7 commit e90d6b9
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 76 deletions.
87 changes: 14 additions & 73 deletions atlas_provider_django/management/commands/atlas-provider-django.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from django.apps import apps
from django.core.management import call_command
from django.db.migrations.exceptions import NodeNotFoundError
from django.db.migrations.graph import MigrationGraph
from django.db.migrations.loader import MigrationLoader, AmbiguityError
from django.db.migrations.loader import MigrationLoader
from django.core.management.base import BaseCommand, CommandError
from django.core.management.commands.sqlmigrate import Command as SqlMigrateCommand
from django.db.backends.sqlite3.base import DatabaseWrapper as Sqlite3DatabaseWrapper
Expand All @@ -17,6 +16,8 @@
from django.db.backends.mysql.schema import DatabaseSchemaEditor as MySQLDatabaseSchemaEditor
from django.db.backends.mysql.features import DatabaseFeatures as MySQLDatabaseFeatures

from atlas_provider_django.management.commands.migrations import get_migrations


class Dialect(str, Enum):
mysql = "mysql"
Expand Down Expand Up @@ -98,65 +99,23 @@ def get_connection_by_dialect(dialect):

# MockMigrationLoader loads migrations without db connection.
class MockMigrationLoader(MigrationLoader):
def __init__(self, connection, replace_migrations=False, load=True):
def __init__(self, connection, replace_migrations=False, load=False):
super().__init__(connection, replace_migrations, load)

# The method is almost the same as the original one, but it doesn't check if the migrations are applied or not.
# Copied from Django's MigrationLoader class: https://github.com/django/django/blob/8a1727dc7f66db7f0131d545812f77544f35aa57/django/db/migrations/loader.py#L222-L305
# Code licensed under the BSD 3-Clause License: https://github.com/django/django/blob/main/LICENSE
def build_graph(self):
self.load_disk()
self.disk_migrations = get_migrations()
self.applied_migrations = {}
self.unmigrated_apps = set()
self.migrated_apps = set()
self.graph = MigrationGraph()
self.replacements = {}
for key, migration in self.disk_migrations.items():
self.graph.add_node(key, migration)
if migration.replaces:
self.replacements[key] = migration
for key, migration in self.disk_migrations.items():
self.add_internal_dependencies(key, migration)
for key, migration in self.disk_migrations.items():
self.add_external_dependencies(key, migration)
if self.replace_migrations:
for key, migration in self.replacements.items():
applied_statuses = [
(target in self.applied_migrations) for target in migration.replaces
]
if all(applied_statuses):
self.applied_migrations[key] = migration
else:
self.applied_migrations.pop(key, None)
if all(applied_statuses) or (not any(applied_statuses)):
self.graph.remove_replaced_nodes(key, migration.replaces)
else:
self.graph.remove_replacement_node(key, migration.replaces)
try:
self.graph.validate_consistency()
except NodeNotFoundError as exc:
reverse_replacements = {}
for key, migration in self.replacements.items():
for replaced in migration.replaces:
reverse_replacements.setdefault(replaced, set()).add(key)
if exc.node in reverse_replacements:
candidates = reverse_replacements.get(exc.node, set())
is_replaced = any(
candidate in self.graph.nodes for candidate in candidates
)
if not is_replaced:
tries = ", ".join("%s.%s" % c for c in candidates)
raise NodeNotFoundError(
"Migration {0} depends on nonexistent node ('{1}', '{2}'). "
"Django tried to replace migration {1}.{2} with any of [{3}] "
"but wasn't able to because some of the replaced migrations "
"are already applied.".format(
exc.origin, exc.node[0], exc.node[1], tries
),
exc.node,
) from exc
raise
self.graph.validate_consistency()
self.graph.ensure_not_cyclic()

# The method is almost the same as the original one, but it doesn't check if atomic transactions are enabled or not.
# Copied from Django's MigrationLoader class: https://github.com/django/django/blob/8a1727dc7f66db7f0131d545812f77544f35aa57/django/db/migrations/loader.py#L365-L385
# Code licensed under the BSD 3-Clause License: https://github.com/django/django/blob/main/LICENSE
def collect_sql(self, plan):
statements = []
state = None
Expand Down Expand Up @@ -188,15 +147,8 @@ def mock_handle(self, *args, **options):
apps.get_app_config(app_label)
except LookupError as err:
raise CommandError(str(err))
if app_label not in loader.migrated_apps:
raise CommandError("App '%s' does not have migrations" % app_label)
try:
migration = loader.get_migration_by_prefix(app_label, migration_name)
except AmbiguityError:
raise CommandError(
"More than one migration matches '%s' in app '%s'. Please be more "
"specific." % (migration_name, app_label)
)
except KeyError:
raise CommandError(
"Cannot find a migration matching '%s' from app '%s'. Is it in "
Expand All @@ -210,13 +162,6 @@ def mock_handle(self, *args, **options):
return "\n".join(sql_statements)


def order_migrations_by_dependency():
loader = MigrationLoader(None)
graph = loader.graph
all_nodes = graph.nodes
return graph._generate_plan(nodes=all_nodes, at_end=True)


class Command(BaseCommand):
help = "Import Django migrations into Atlas"

Expand All @@ -229,7 +174,6 @@ def add_arguments(self, parser):
def handle(self, *args, **options):
global current_dialect
current_dialect = options.get("dialect", Dialect.sqlite)
self.create_migrations()
print(self.get_ddl())

def create_migrations(self):
Expand All @@ -247,11 +191,8 @@ def create_migrations(self):

# Load migrations and get the sql statements describing the migrations.
def get_ddl(self):
migration_loader = MigrationLoader(None, ignore_no_migrations=True)
migration_loader.load_disk()
migrations = ""
ordered_migrations = order_migrations_by_dependency()
for app_name, migration_name in ordered_migrations:
ddl = ""
for app_name, migration_name in get_migrations():
try:
out = StringIO()
call_command(
Expand All @@ -261,12 +202,12 @@ def get_ddl(self):
stdout=out,
stderr=StringIO(),
)
migrations += out.getvalue()
ddl += out.getvalue()
except Exception as e:
traceback.print_exc()
self.stderr.write(
f"failed to get migration {app_name} {migration_name}, {e}"
)
exit(1)

return migrations
return ddl
22 changes: 22 additions & 0 deletions atlas_provider_django/management/commands/migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from django.db.migrations.autodetector import MigrationAutodetector
from django.db.migrations.state import ProjectState
from django.db.migrations.loader import MigrationLoader
from django.apps import apps


# Creates the migrations of the installed apps from empty baseline and returns them as a dictionary
def get_migrations():
autodetector = MigrationAutodetector(
ProjectState(),
ProjectState.from_apps(apps),
)
loader = MigrationLoader(None, ignore_no_migrations=True)
changes = autodetector.changes(
graph=loader.graph,
trim_to_apps=None,
convert_apps=None,
)
migrations = {}
for app_label, app_migrations in changes.items():
migrations[(app_label, app_migrations[0].name)] = app_migrations[0]
return migrations
2 changes: 1 addition & 1 deletion atlas_provider_django/settings.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSTALLED_APPS = ["atlas_provider_django", "tests.app"]
INSTALLED_APPS = ["atlas_provider_django", "tests.app1", "tests.app2"]
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/app/apps.py → tests/app1/apps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.apps import AppConfig


class AppConfig(AppConfig):
class App1Config(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "tests.app"
name = "tests.app1"
File renamed without changes.
File renamed without changes.
Empty file added tests/app2/__init__.py
Empty file.
6 changes: 6 additions & 0 deletions tests/app2/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class App2Config(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "tests.app2"
Empty file.
14 changes: 14 additions & 0 deletions tests/app2/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from django.db import models


class User(models.Model):
first_name = models.CharField(max_length=50)
last_name = models.CharField(max_length=50)
roll = models.CharField(max_length=100)


class Blog(models.Model):
author = models.ForeignKey(User, on_delete=models.CASCADE)
name = models.CharField(max_length=100)
created_at = models.DateField()
num_stars = models.IntegerField()

0 comments on commit e90d6b9

Please sign in to comment.