diff --git a/src/sqlite3_to_mysql/mysql_utils.py b/src/sqlite3_to_mysql/mysql_utils.py index c5523ce..08d8d45 100644 --- a/src/sqlite3_to_mysql/mysql_utils.py +++ b/src/sqlite3_to_mysql/mysql_utils.py @@ -6,7 +6,6 @@ from mysql.connector import CharacterSet from mysql.connector.charsets import MYSQL_CHARACTER_SETS from packaging import version -from packaging.version import Version # Shamelessly copied from SQLAlchemy's dialects/mysql/__init__.py @@ -112,30 +111,32 @@ def get_mysql_version(version_string: str) -> version.Version: def check_mysql_json_support(version_string: str) -> bool: """Check for MySQL JSON support.""" - mysql_version: Version = get_mysql_version(version_string) - if version_string.lower().endswith("-mariadb"): - if mysql_version.major >= 10 and mysql_version.minor >= 2 and mysql_version.micro >= 7: - return True - else: - if mysql_version.major >= 8: - return True - if mysql_version.minor >= 7 and mysql_version.micro >= 8: - return True - return False + mysql_version: version.Version = get_mysql_version(version_string) + if "-mariadb" in version_string.lower(): + return mysql_version >= version.parse("10.2.7") + return mysql_version >= version.parse("5.7.8") + + +def check_mysql_values_alias_support(version_string: str) -> bool: + """Check for VALUES alias support. + + Returns: + bool: True if VALUES alias is supported (MySQL 8.0.19+), False for MariaDB + or older MySQL versions. + """ + mysql_version: version.Version = get_mysql_version(version_string) + if "-mariadb" in version_string.lower(): + return False + # Only MySQL 8.0.19 and later support VALUES alias + return mysql_version >= version.parse("8.0.19") def check_mysql_fulltext_support(version_string: str) -> bool: """Check for FULLTEXT indexing support.""" - mysql_version: Version = get_mysql_version(version_string) - if version_string.lower().endswith("-mariadb"): - if mysql_version.major >= 10 and mysql_version.minor >= 0 and mysql_version.micro >= 5: - return True - else: - if mysql_version.major >= 8: - return True - if mysql_version.minor >= 6: - return True - return False + mysql_version: version.Version = get_mysql_version(version_string) + if "-mariadb" in version_string.lower(): + return mysql_version >= version.parse("10.0.5") + return mysql_version >= version.parse("5.6.0") def safe_identifier_length(identifier_name: str, max_length: int = 64) -> str: diff --git a/src/sqlite3_to_mysql/transporter.py b/src/sqlite3_to_mysql/transporter.py index 6e034b5..9240533 100644 --- a/src/sqlite3_to_mysql/transporter.py +++ b/src/sqlite3_to_mysql/transporter.py @@ -39,6 +39,7 @@ MYSQL_TEXT_COLUMN_TYPES_WITH_JSON, check_mysql_fulltext_support, check_mysql_json_support, + check_mysql_values_alias_support, safe_identifier_length, ) from .types import SQLite3toMySQLAttributes, SQLite3toMySQLParams @@ -88,7 +89,7 @@ def __init__(self, **kwargs: tx.Unpack[SQLite3toMySQLParams]): self._mysql_database = kwargs.get("mysql_database", "transfer") or "transfer" - self._mysql_insert_method = str(kwargs.get("mysql_integer_type", "IGNORE")).upper() + self._mysql_insert_method = str(kwargs.get("mysql_insert_method", "IGNORE")).upper() if self._mysql_insert_method not in MYSQL_INSERT_METHOD: self._mysql_insert_method = "IGNORE" @@ -722,21 +723,26 @@ def transfer(self) -> None: columns: t.List[str] = [ safe_identifier_length(column[0]) for column in self._sqlite_cur.description ] + sql: str if self._mysql_insert_method.upper() == "UPDATE": - sql: str = ( - """ + sql = """ INSERT INTO `{table}` ({fields}) - VALUES ({placeholders}) AS `__new__` + {values_clause} ON DUPLICATE KEY UPDATE {field_updates} """.format( - table=safe_identifier_length(table["name"]), - fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns), - placeholders=("%s, " * len(columns)).rstrip(" ,"), - field_updates=("`{}`=`__new__`.`{}`, " * len(columns)) - .rstrip(" ,") - .format(*list(chain.from_iterable((column, column) for column in columns))), - ) + table=safe_identifier_length(table["name"]), + fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns), + values_clause=( + "VALUES ({placeholders}) AS `__new__`" + if check_mysql_values_alias_support(self._mysql_version) + else "VALUES ({placeholders})" + ).format(placeholders=("%s, " * len(columns)).rstrip(" ,")), + field_updates=( + ("`{}`=`__new__`.`{}`, " * len(columns)).rstrip(" ,") + if check_mysql_values_alias_support(self._mysql_version) + else ("`{}`=`{}`, " * len(columns)).rstrip(" ,") + ).format(*list(chain.from_iterable((column, column) for column in columns))), ) else: sql = """ diff --git a/tests/func/sqlite3_to_mysql_test.py b/tests/func/sqlite3_to_mysql_test.py index a259554..bb887cb 100644 --- a/tests/func/sqlite3_to_mysql_test.py +++ b/tests/func/sqlite3_to_mysql_test.py @@ -59,7 +59,7 @@ def test_valid_sqlite_file_and_valid_mysql_credentials( mysql_credentials: MySQLCredentials, helpers: Helpers, quiet: bool, - ): + ) -> None: with helpers.not_raises(FileNotFoundError): SQLite3toMySQL( # type: ignore sqlite_file=sqlite_database, diff --git a/tests/unit/mysql_utils_test.py b/tests/unit/mysql_utils_test.py new file mode 100644 index 0000000..34379bc --- /dev/null +++ b/tests/unit/mysql_utils_test.py @@ -0,0 +1,87 @@ +import pytest +from packaging.version import Version + +from sqlite3_to_mysql.mysql_utils import ( + check_mysql_fulltext_support, + check_mysql_json_support, + check_mysql_values_alias_support, + get_mysql_version, + safe_identifier_length, +) + + +class TestMySQLUtils: + @pytest.mark.parametrize( + "version_string,expected", + [ + ("5.7.7", Version("5.7.7")), + ("5.7.8", Version("5.7.8")), + ("8.0.0", Version("8.0.0")), + ("9.0.0", Version("9.0.0")), + ("10.2.6-mariadb", Version("10.2.6")), + ("10.2.7-mariadb", Version("10.2.7")), + ("11.4.0-mariadb", Version("11.4.0")), + ], + ) + def test_get_mysql_version(self, version_string: str, expected: Version) -> None: + assert get_mysql_version(version_string) == expected + + @pytest.mark.parametrize( + "version_string,expected", + [ + ("5.7.7", False), + ("5.7.8", True), + ("8.0.0", True), + ("9.0.0", True), + ("10.2.6-mariadb", False), + ("10.2.7-mariadb", True), + ("11.4.0-mariadb", True), + ], + ) + def test_check_mysql_json_support(self, version_string: str, expected: bool) -> None: + assert check_mysql_json_support(version_string) == expected + + @pytest.mark.parametrize( + "version_string,expected", + [ + ("5.7.8", False), + ("8.0.0", False), + ("8.0.18", False), + ("8.0.19", True), + ("9.0.0", True), + ("10.2.6-mariadb", False), + ("10.2.7-mariadb", False), + ("11.4.0-mariadb", False), + ], + ) + def test_check_mysql_values_alias_support(self, version_string: str, expected: bool) -> None: + assert check_mysql_values_alias_support(version_string) == expected + + @pytest.mark.parametrize( + "version_string,expected", + [ + ("5.0.0", False), + ("5.5.0", False), + ("5.6.0", True), + ("8.0.0", True), + ("10.0.4-mariadb", False), + ("10.0.5-mariadb", True), + ("10.2.6-mariadb", True), + ("11.4.0-mariadb", True), + ], + ) + def test_check_mysql_fulltext_support(self, version_string: str, expected: bool) -> None: + assert check_mysql_fulltext_support(version_string) == expected + + @pytest.mark.parametrize( + "identifier,expected", + [ + ("a" * 67, "a" * 64), + ("a" * 66, "a" * 64), + ("a" * 65, "a" * 64), + ("a" * 64, "a" * 64), + ("a" * 63, "a" * 63), + ], + ) + def test_safe_identifier_length(self, identifier: str, expected: str) -> None: + assert safe_identifier_length(identifier) == expected