Skip to content

Commit

Permalink
add option to get ddl only from selected apps (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronenlu authored Jan 24, 2024
1 parent fbd497e commit 8074450
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ jobs:
run: poetry install
- name: Run lint
run: poetry run ruff --output-format=github .
- name: Run unit tests
run: poetry run python manage.py test

integration-tests:
strategy:
Expand Down
28 changes: 9 additions & 19 deletions atlas_provider_django/management/commands/atlas-provider-django.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __str__(self):
return self.value


current_dialect = Dialect.sqlite
current_dialect = Dialect.mysql


class MockSqliteSchemaEditor(SqliteSchemaEditor):
Expand Down Expand Up @@ -182,33 +182,23 @@ class Command(BaseCommand):
help = "Import Django migrations into Atlas"

def add_arguments(self, parser):
parser.add_argument("--dialect", type=Dialect, choices=list(Dialect), help="The database dialect to use.",
default=Dialect.sqlite)
parser.add_argument("--dialect", type=Dialect, choices=list(Dialect),
help="The database dialect to use, Default: mysql",
default=Dialect.mysql)
parser.add_argument("--apps", nargs="+", help="List of apps to get ddl for.")

SqlMigrateCommand.handle = mock_handle

def handle(self, *args, **options):
global current_dialect
current_dialect = options.get("dialect", Dialect.sqlite)
print(self.get_ddl())

def create_migrations(self):
try:
call_command(
"makemigrations",
"--no-input",
stdout=StringIO(),
stderr=StringIO()
)
except Exception as e:
traceback.print_exc()
self.stderr.write(f"failed to create migrations, {e}")
exit(1)
selected_apps = options.get("apps", None)
return self.get_ddl(selected_apps)

# Load migrations and get the sql statements describing the migrations.
def get_ddl(self):
def get_ddl(self, selected_apps):
ddl = ""
for app_name, migration_name in get_migrations():
for app_name, migration_name in get_migrations(selected_apps):
try:
out = StringIO()
call_command(
Expand Down
8 changes: 5 additions & 3 deletions atlas_provider_django/management/commands/migrations.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from django.db.migrations.autodetector import MigrationAutodetector
from django.db.migrations.state import ProjectState
from django.db.migrations.loader import MigrationLoader
from django.apps import apps
from django.apps import apps as all_apps


# Creates the migrations of the installed apps from empty baseline and returns them as a dictionary
def get_migrations():
def get_migrations(apps=None):
autodetector = MigrationAutodetector(
ProjectState(),
ProjectState.from_apps(apps),
ProjectState.from_apps(all_apps),
)
loader = MigrationLoader(None, ignore_no_migrations=True)
changes = autodetector.changes(
Expand All @@ -18,5 +18,7 @@ def get_migrations():
)
migrations = {}
for app_label, app_migrations in changes.items():
if apps and app_label not in apps:
continue
migrations[(app_label, app_migrations[0].name)] = app_migrations[0]
return migrations
8 changes: 8 additions & 0 deletions atlas_provider_django/settings.py
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
INSTALLED_APPS = ["atlas_provider_django", "tests.app1", "tests.app2"]

# if there are no databases defined, the tests tear down will fail
DATABASES = {
"default": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": "atlas_provider_django.db",
}
}
22 changes: 22 additions & 0 deletions tests/expected_all_apps.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
BEGIN;
--
-- Create model Musician
--
CREATE TABLE `app1_musician` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, `first_name` varchar(50) NOT NULL, `last_name` varchar(50) NOT NULL, `instrument` varchar(100) NOT NULL);
--
-- Create model Album
--
CREATE TABLE `app1_album` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(100) NOT NULL, `release_date` date NOT NULL, `num_stars` integer NOT NULL, `artist_id` bigint NOT NULL);
ALTER TABLE `app1_album` ADD CONSTRAINT `app1_album_artist_id_aed0987a_fk_app1_musician_id` FOREIGN KEY (`artist_id`) REFERENCES `app1_musician` (`id`);
COMMIT;
BEGIN;
--
-- Create model User
--
CREATE TABLE `app2_user` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, `first_name` varchar(50) NOT NULL, `last_name` varchar(50) NULL, `roll` varchar(100) NOT NULL);
--
-- Create model Blog
--
CREATE TABLE `app2_blog` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(100) NOT NULL, `created_at` date NOT NULL, `num_stars` integer NOT NULL, `author_id` bigint NOT NULL);
ALTER TABLE `app2_blog` ADD CONSTRAINT `app2_blog_author_id_1675e606_fk_app2_user_id` FOREIGN KEY (`author_id`) REFERENCES `app2_user` (`id`);
COMMIT;
11 changes: 11 additions & 0 deletions tests/expected_app1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
BEGIN;
--
-- Create model Musician
--
CREATE TABLE `app1_musician` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, `first_name` varchar(50) NOT NULL, `last_name` varchar(50) NOT NULL, `instrument` varchar(100) NOT NULL);
--
-- Create model Album
--
CREATE TABLE `app1_album` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(100) NOT NULL, `release_date` date NOT NULL, `num_stars` integer NOT NULL, `artist_id` bigint NOT NULL);
ALTER TABLE `app1_album` ADD CONSTRAINT `app1_album_artist_id_aed0987a_fk_app1_musician_id` FOREIGN KEY (`artist_id`) REFERENCES `app1_musician` (`id`);
COMMIT;
17 changes: 17 additions & 0 deletions tests/tests_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from django.test import TestCase
from django.core.management import call_command
from io import StringIO


class TestAtlasProviderDjango(TestCase):
def test_atlas_provider_django_all_apps(self):
out = StringIO()
call_command("atlas-provider-django", stdout=out)
with open("tests/expected_all_apps.sql", "r") as f:
self.assertEqual(out.getvalue(), f.read())

def test_atlas_provider_django_specific_app(self):
out = StringIO()
call_command("atlas-provider-django", "--app", "app1", stdout=out)
with open("tests/expected_app1.sql", "r") as f:
self.assertEqual(out.getvalue(), f.read())

0 comments on commit 8074450

Please sign in to comment.