From e90d6b9257924a42216b9eb4e37b5b3b6eb6804e Mon Sep 17 00:00:00 2001 From: Ronen Lubin <63970571+ronenlu@users.noreply.github.com> Date: Tue, 23 Jan 2024 17:31:35 +0200 Subject: [PATCH] dont relay on exisitng django migration for the command to run (#5) --- .../commands/atlas-provider-django.py | 87 +++---------------- .../management/commands/migrations.py | 22 +++++ atlas_provider_django/settings.py | 2 +- tests/{app => app1}/__init__.py | 0 tests/{app => app1}/apps.py | 4 +- tests/{app => app1}/migrations/__init__.py | 0 tests/{app => app1}/models.py | 0 tests/app2/__init__.py | 0 tests/app2/apps.py | 6 ++ tests/app2/migrations/__init__.py | 0 tests/app2/models.py | 14 +++ 11 files changed, 59 insertions(+), 76 deletions(-) create mode 100644 atlas_provider_django/management/commands/migrations.py rename tests/{app => app1}/__init__.py (100%) rename tests/{app => app1}/apps.py (63%) rename tests/{app => app1}/migrations/__init__.py (100%) rename tests/{app => app1}/models.py (100%) create mode 100644 tests/app2/__init__.py create mode 100644 tests/app2/apps.py create mode 100644 tests/app2/migrations/__init__.py create mode 100644 tests/app2/models.py diff --git a/atlas_provider_django/management/commands/atlas-provider-django.py b/atlas_provider_django/management/commands/atlas-provider-django.py index 61a0854..7e48cf0 100644 --- a/atlas_provider_django/management/commands/atlas-provider-django.py +++ b/atlas_provider_django/management/commands/atlas-provider-django.py @@ -4,9 +4,8 @@ from django.apps import apps from django.core.management import call_command -from django.db.migrations.exceptions import NodeNotFoundError from django.db.migrations.graph import MigrationGraph -from django.db.migrations.loader import MigrationLoader, AmbiguityError +from django.db.migrations.loader import MigrationLoader from django.core.management.base import BaseCommand, CommandError from django.core.management.commands.sqlmigrate import Command as SqlMigrateCommand from django.db.backends.sqlite3.base import DatabaseWrapper as Sqlite3DatabaseWrapper @@ -17,6 +16,8 @@ from django.db.backends.mysql.schema import DatabaseSchemaEditor as MySQLDatabaseSchemaEditor from django.db.backends.mysql.features import DatabaseFeatures as MySQLDatabaseFeatures +from atlas_provider_django.management.commands.migrations import get_migrations + class Dialect(str, Enum): mysql = "mysql" @@ -98,65 +99,23 @@ def get_connection_by_dialect(dialect): # MockMigrationLoader loads migrations without db connection. class MockMigrationLoader(MigrationLoader): - def __init__(self, connection, replace_migrations=False, load=True): + def __init__(self, connection, replace_migrations=False, load=False): super().__init__(connection, replace_migrations, load) - # The method is almost the same as the original one, but it doesn't check if the migrations are applied or not. - # Copied from Django's MigrationLoader class: https://github.com/django/django/blob/8a1727dc7f66db7f0131d545812f77544f35aa57/django/db/migrations/loader.py#L222-L305 - # Code licensed under the BSD 3-Clause License: https://github.com/django/django/blob/main/LICENSE def build_graph(self): - self.load_disk() + self.disk_migrations = get_migrations() self.applied_migrations = {} + self.unmigrated_apps = set() + self.migrated_apps = set() self.graph = MigrationGraph() - self.replacements = {} for key, migration in self.disk_migrations.items(): self.graph.add_node(key, migration) - if migration.replaces: - self.replacements[key] = migration - for key, migration in self.disk_migrations.items(): - self.add_internal_dependencies(key, migration) - for key, migration in self.disk_migrations.items(): - self.add_external_dependencies(key, migration) - if self.replace_migrations: - for key, migration in self.replacements.items(): - applied_statuses = [ - (target in self.applied_migrations) for target in migration.replaces - ] - if all(applied_statuses): - self.applied_migrations[key] = migration - else: - self.applied_migrations.pop(key, None) - if all(applied_statuses) or (not any(applied_statuses)): - self.graph.remove_replaced_nodes(key, migration.replaces) - else: - self.graph.remove_replacement_node(key, migration.replaces) - try: - self.graph.validate_consistency() - except NodeNotFoundError as exc: - reverse_replacements = {} - for key, migration in self.replacements.items(): - for replaced in migration.replaces: - reverse_replacements.setdefault(replaced, set()).add(key) - if exc.node in reverse_replacements: - candidates = reverse_replacements.get(exc.node, set()) - is_replaced = any( - candidate in self.graph.nodes for candidate in candidates - ) - if not is_replaced: - tries = ", ".join("%s.%s" % c for c in candidates) - raise NodeNotFoundError( - "Migration {0} depends on nonexistent node ('{1}', '{2}'). " - "Django tried to replace migration {1}.{2} with any of [{3}] " - "but wasn't able to because some of the replaced migrations " - "are already applied.".format( - exc.origin, exc.node[0], exc.node[1], tries - ), - exc.node, - ) from exc - raise + self.graph.validate_consistency() self.graph.ensure_not_cyclic() # The method is almost the same as the original one, but it doesn't check if atomic transactions are enabled or not. + # Copied from Django's MigrationLoader class: https://github.com/django/django/blob/8a1727dc7f66db7f0131d545812f77544f35aa57/django/db/migrations/loader.py#L365-L385 + # Code licensed under the BSD 3-Clause License: https://github.com/django/django/blob/main/LICENSE def collect_sql(self, plan): statements = [] state = None @@ -188,15 +147,8 @@ def mock_handle(self, *args, **options): apps.get_app_config(app_label) except LookupError as err: raise CommandError(str(err)) - if app_label not in loader.migrated_apps: - raise CommandError("App '%s' does not have migrations" % app_label) try: migration = loader.get_migration_by_prefix(app_label, migration_name) - except AmbiguityError: - raise CommandError( - "More than one migration matches '%s' in app '%s'. Please be more " - "specific." % (migration_name, app_label) - ) except KeyError: raise CommandError( "Cannot find a migration matching '%s' from app '%s'. Is it in " @@ -210,13 +162,6 @@ def mock_handle(self, *args, **options): return "\n".join(sql_statements) -def order_migrations_by_dependency(): - loader = MigrationLoader(None) - graph = loader.graph - all_nodes = graph.nodes - return graph._generate_plan(nodes=all_nodes, at_end=True) - - class Command(BaseCommand): help = "Import Django migrations into Atlas" @@ -229,7 +174,6 @@ def add_arguments(self, parser): def handle(self, *args, **options): global current_dialect current_dialect = options.get("dialect", Dialect.sqlite) - self.create_migrations() print(self.get_ddl()) def create_migrations(self): @@ -247,11 +191,8 @@ def create_migrations(self): # Load migrations and get the sql statements describing the migrations. def get_ddl(self): - migration_loader = MigrationLoader(None, ignore_no_migrations=True) - migration_loader.load_disk() - migrations = "" - ordered_migrations = order_migrations_by_dependency() - for app_name, migration_name in ordered_migrations: + ddl = "" + for app_name, migration_name in get_migrations(): try: out = StringIO() call_command( @@ -261,7 +202,7 @@ def get_ddl(self): stdout=out, stderr=StringIO(), ) - migrations += out.getvalue() + ddl += out.getvalue() except Exception as e: traceback.print_exc() self.stderr.write( @@ -269,4 +210,4 @@ def get_ddl(self): ) exit(1) - return migrations + return ddl diff --git a/atlas_provider_django/management/commands/migrations.py b/atlas_provider_django/management/commands/migrations.py new file mode 100644 index 0000000..d43cb5e --- /dev/null +++ b/atlas_provider_django/management/commands/migrations.py @@ -0,0 +1,22 @@ +from django.db.migrations.autodetector import MigrationAutodetector +from django.db.migrations.state import ProjectState +from django.db.migrations.loader import MigrationLoader +from django.apps import apps + + +# Creates the migrations of the installed apps from empty baseline and returns them as a dictionary +def get_migrations(): + autodetector = MigrationAutodetector( + ProjectState(), + ProjectState.from_apps(apps), + ) + loader = MigrationLoader(None, ignore_no_migrations=True) + changes = autodetector.changes( + graph=loader.graph, + trim_to_apps=None, + convert_apps=None, + ) + migrations = {} + for app_label, app_migrations in changes.items(): + migrations[(app_label, app_migrations[0].name)] = app_migrations[0] + return migrations diff --git a/atlas_provider_django/settings.py b/atlas_provider_django/settings.py index b5abbc1..db074e3 100644 --- a/atlas_provider_django/settings.py +++ b/atlas_provider_django/settings.py @@ -1 +1 @@ -INSTALLED_APPS = ["atlas_provider_django", "tests.app"] +INSTALLED_APPS = ["atlas_provider_django", "tests.app1", "tests.app2"] diff --git a/tests/app/__init__.py b/tests/app1/__init__.py similarity index 100% rename from tests/app/__init__.py rename to tests/app1/__init__.py diff --git a/tests/app/apps.py b/tests/app1/apps.py similarity index 63% rename from tests/app/apps.py rename to tests/app1/apps.py index 40424a3..85b7aca 100644 --- a/tests/app/apps.py +++ b/tests/app1/apps.py @@ -1,6 +1,6 @@ from django.apps import AppConfig -class AppConfig(AppConfig): +class App1Config(AppConfig): default_auto_field = "django.db.models.BigAutoField" - name = "tests.app" + name = "tests.app1" diff --git a/tests/app/migrations/__init__.py b/tests/app1/migrations/__init__.py similarity index 100% rename from tests/app/migrations/__init__.py rename to tests/app1/migrations/__init__.py diff --git a/tests/app/models.py b/tests/app1/models.py similarity index 100% rename from tests/app/models.py rename to tests/app1/models.py diff --git a/tests/app2/__init__.py b/tests/app2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app2/apps.py b/tests/app2/apps.py new file mode 100644 index 0000000..03cc00f --- /dev/null +++ b/tests/app2/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class App2Config(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "tests.app2" diff --git a/tests/app2/migrations/__init__.py b/tests/app2/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app2/models.py b/tests/app2/models.py new file mode 100644 index 0000000..29327bb --- /dev/null +++ b/tests/app2/models.py @@ -0,0 +1,14 @@ +from django.db import models + + +class User(models.Model): + first_name = models.CharField(max_length=50) + last_name = models.CharField(max_length=50) + roll = models.CharField(max_length=100) + + +class Blog(models.Model): + author = models.ForeignKey(User, on_delete=models.CASCADE) + name = models.CharField(max_length=100) + created_at = models.DateField() + num_stars = models.IntegerField()