diff --git a/atlas_provider_django/management/commands/atlas-provider-django.py b/atlas_provider_django/management/commands/atlas-provider-django.py index 92c4fec..3b9fabf 100644 --- a/atlas_provider_django/management/commands/atlas-provider-django.py +++ b/atlas_provider_django/management/commands/atlas-provider-django.py @@ -21,6 +21,7 @@ class Dialect(str, Enum): mysql = "mysql" + mariadb = "mariadb" sqlite = "sqlite" postgresql = "postgresql" @@ -74,26 +75,41 @@ def has_native_uuid_field(self): return False +class MockMariaDBDatabaseFeatures(MySQLDatabaseFeatures): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def has_native_uuid_field(self): + return True + + # Returns the database connection wrapper for the given dialect. # Mocks some methods in order to get the sql statements without db connection. def get_connection_by_dialect(dialect): conn = None - if dialect == Dialect.sqlite: - conn = Sqlite3DatabaseWrapper({ - "ENGINE": "django.db.backends.sqlite3", - }, "sqlite3") - conn.SchemaEditorClass = MockSqliteSchemaEditor - elif dialect == Dialect.postgresql: - conn = PGDatabaseWrapper({ - "ENGINE": "django.db.backends.postgresql", - }, "postgresql") - conn.SchemaEditorClass = MockPGDatabaseSchemaEditor - elif dialect == Dialect.mysql: - conn = MySQLDatabaseWrapper({ - "ENGINE": "django.db.backends.mysql", - }, "mysql") - conn.SchemaEditorClass = MockMySQLDatabaseSchemaEditor - conn.features = MockMySQLDatabaseFeatures + match dialect: + case Dialect.sqlite: + conn = Sqlite3DatabaseWrapper({ + "ENGINE": "django.db.backends.sqlite3", + }, "sqlite3") + conn.SchemaEditorClass = MockSqliteSchemaEditor + case Dialect.postgresql: + conn = PGDatabaseWrapper({ + "ENGINE": "django.db.backends.postgresql", + }, "postgresql") + conn.SchemaEditorClass = MockPGDatabaseSchemaEditor + case Dialect.mysql: + conn = MySQLDatabaseWrapper({ + "ENGINE": "django.db.backends.mysql", + }, "mysql") + conn.SchemaEditorClass = MockMySQLDatabaseSchemaEditor + conn.features = MockMySQLDatabaseFeatures(conn) + case Dialect.mariadb: + conn = MySQLDatabaseWrapper({ + "ENGINE": "django.db.backends.mysql", + }, "mysql") + conn.SchemaEditorClass = MockMySQLDatabaseSchemaEditor + conn.features = MockMariaDBDatabaseFeatures(conn) return conn