Skip to content

Commit

Permalink
add postgres and mysql implementations (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronenlu authored Jan 23, 2024
1 parent f978889 commit 2d4dde7
Showing 1 changed file with 48 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from django.core.management.commands.sqlmigrate import Command as SqlMigrateCommand
from django.db.backends.sqlite3.base import DatabaseWrapper as Sqlite3DatabaseWrapper
from django.db.backends.sqlite3.schema import DatabaseSchemaEditor as SqliteSchemaEditor
from django.db.backends.postgresql.schema import DatabaseSchemaEditor as PGDatabaseSchemaEditor
from django.db.backends.postgresql.base import DatabaseWrapper as PGDatabaseWrapper
from django.db.backends.mysql.base import DatabaseWrapper as MySQLDatabaseWrapper
from django.db.backends.mysql.schema import DatabaseSchemaEditor as MySQLDatabaseSchemaEditor
from django.db.backends.mysql.features import DatabaseFeatures as MySQLDatabaseFeatures


class Dialect(str, Enum):
Expand Down Expand Up @@ -38,15 +43,57 @@ def __exit__(self, exc_type, exc_value, traceback):
return super(SqliteSchemaEditor, self).__exit__(exc_type, exc_value, traceback)


class MockPGDatabaseSchemaEditor(PGDatabaseSchemaEditor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def execute(self, sql, params=()):
return super(PGDatabaseSchemaEditor, self).execute(sql, params)


class MockMySQLDatabaseSchemaEditor(MySQLDatabaseSchemaEditor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Override the method of MySQLDatabaseSchemaEditor since it checks the storage engine.
# We assume that the storage engine is InnoDB.
def _field_should_be_indexed(self, model, field):
if not super(MySQLDatabaseSchemaEditor, self)._field_should_be_indexed(model, field):
return False
if field.get_internal_type() == "ForeignKey" and field.db_constraint:
return False
return not self._is_limited_data_type(field)


class MockMySQLDatabaseFeatures(MySQLDatabaseFeatures):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def has_native_uuid_field(self):
return False


# Returns the database connection wrapper for the given dialect.
# Mocks some methods in order to get the sql statements without db connection.
def get_connection_by_dialect(dialect):
conn = None
if dialect == Dialect.sqlite:
conn = Sqlite3DatabaseWrapper({
"ENGINE": "django.db.backends.sqlite3",
}, "sqlite3")
conn.SchemaEditorClass = MockSqliteSchemaEditor
return conn
elif dialect == Dialect.postgresql:
conn = PGDatabaseWrapper({
"ENGINE": "django.db.backends.postgresql",
}, "postgresql")
conn.SchemaEditorClass = MockPGDatabaseSchemaEditor
elif dialect == Dialect.mysql:
conn = MySQLDatabaseWrapper({
"ENGINE": "django.db.backends.mysql",
}, "mysql")
conn.SchemaEditorClass = MockMySQLDatabaseSchemaEditor
conn.features = MockMySQLDatabaseFeatures
return conn


# MockMigrationLoader loads migrations without db connection.
Expand Down

0 comments on commit 2d4dde7

Please sign in to comment.