Skip to content

Commit

Permalink
🐛 fix typo in --mysql-insert-method parameter (#130)
Browse files Browse the repository at this point in the history
* fixed mysql insert method parameter

* 🐛 fix MariaDB INSERT ON DUPLICATE KEY UPDATE

---------

Co-authored-by: Kuba Suder <[email protected]>
  • Loading branch information
techouse and mackuba authored Oct 26, 2024
1 parent d5e5626 commit 3fc65aa
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 33 deletions.
43 changes: 22 additions & 21 deletions src/sqlite3_to_mysql/mysql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from mysql.connector import CharacterSet
from mysql.connector.charsets import MYSQL_CHARACTER_SETS
from packaging import version
from packaging.version import Version


# Shamelessly copied from SQLAlchemy's dialects/mysql/__init__.py
Expand Down Expand Up @@ -112,30 +111,32 @@ def get_mysql_version(version_string: str) -> version.Version:

def check_mysql_json_support(version_string: str) -> bool:
"""Check for MySQL JSON support."""
mysql_version: Version = get_mysql_version(version_string)
if version_string.lower().endswith("-mariadb"):
if mysql_version.major >= 10 and mysql_version.minor >= 2 and mysql_version.micro >= 7:
return True
else:
if mysql_version.major >= 8:
return True
if mysql_version.minor >= 7 and mysql_version.micro >= 8:
return True
return False
mysql_version: version.Version = get_mysql_version(version_string)
if "-mariadb" in version_string.lower():
return mysql_version >= version.parse("10.2.7")
return mysql_version >= version.parse("5.7.8")


def check_mysql_values_alias_support(version_string: str) -> bool:
"""Check for VALUES alias support.
Returns:
bool: True if VALUES alias is supported (MySQL 8.0.19+), False for MariaDB
or older MySQL versions.
"""
mysql_version: version.Version = get_mysql_version(version_string)
if "-mariadb" in version_string.lower():
return False
# Only MySQL 8.0.19 and later support VALUES alias
return mysql_version >= version.parse("8.0.19")


def check_mysql_fulltext_support(version_string: str) -> bool:
"""Check for FULLTEXT indexing support."""
mysql_version: Version = get_mysql_version(version_string)
if version_string.lower().endswith("-mariadb"):
if mysql_version.major >= 10 and mysql_version.minor >= 0 and mysql_version.micro >= 5:
return True
else:
if mysql_version.major >= 8:
return True
if mysql_version.minor >= 6:
return True
return False
mysql_version: version.Version = get_mysql_version(version_string)
if "-mariadb" in version_string.lower():
return mysql_version >= version.parse("10.0.5")
return mysql_version >= version.parse("5.6.0")


def safe_identifier_length(identifier_name: str, max_length: int = 64) -> str:
Expand Down
28 changes: 17 additions & 11 deletions src/sqlite3_to_mysql/transporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
MYSQL_TEXT_COLUMN_TYPES_WITH_JSON,
check_mysql_fulltext_support,
check_mysql_json_support,
check_mysql_values_alias_support,
safe_identifier_length,
)
from .types import SQLite3toMySQLAttributes, SQLite3toMySQLParams
Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(self, **kwargs: tx.Unpack[SQLite3toMySQLParams]):

self._mysql_database = kwargs.get("mysql_database", "transfer") or "transfer"

self._mysql_insert_method = str(kwargs.get("mysql_integer_type", "IGNORE")).upper()
self._mysql_insert_method = str(kwargs.get("mysql_insert_method", "IGNORE")).upper()
if self._mysql_insert_method not in MYSQL_INSERT_METHOD:
self._mysql_insert_method = "IGNORE"

Expand Down Expand Up @@ -722,21 +723,26 @@ def transfer(self) -> None:
columns: t.List[str] = [
safe_identifier_length(column[0]) for column in self._sqlite_cur.description
]
sql: str
if self._mysql_insert_method.upper() == "UPDATE":
sql: str = (
"""
sql = """
INSERT
INTO `{table}` ({fields})
VALUES ({placeholders}) AS `__new__`
{values_clause}
ON DUPLICATE KEY UPDATE {field_updates}
""".format(
table=safe_identifier_length(table["name"]),
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
placeholders=("%s, " * len(columns)).rstrip(" ,"),
field_updates=("`{}`=`__new__`.`{}`, " * len(columns))
.rstrip(" ,")
.format(*list(chain.from_iterable((column, column) for column in columns))),
)
table=safe_identifier_length(table["name"]),
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
values_clause=(
"VALUES ({placeholders}) AS `__new__`"
if check_mysql_values_alias_support(self._mysql_version)
else "VALUES ({placeholders})"
).format(placeholders=("%s, " * len(columns)).rstrip(" ,")),
field_updates=(
("`{}`=`__new__`.`{}`, " * len(columns)).rstrip(" ,")
if check_mysql_values_alias_support(self._mysql_version)
else ("`{}`=`{}`, " * len(columns)).rstrip(" ,")
).format(*list(chain.from_iterable((column, column) for column in columns))),
)
else:
sql = """
Expand Down
2 changes: 1 addition & 1 deletion tests/func/sqlite3_to_mysql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_valid_sqlite_file_and_valid_mysql_credentials(
mysql_credentials: MySQLCredentials,
helpers: Helpers,
quiet: bool,
):
) -> None:
with helpers.not_raises(FileNotFoundError):
SQLite3toMySQL( # type: ignore
sqlite_file=sqlite_database,
Expand Down
87 changes: 87 additions & 0 deletions tests/unit/mysql_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pytest
from packaging.version import Version

from sqlite3_to_mysql.mysql_utils import (
check_mysql_fulltext_support,
check_mysql_json_support,
check_mysql_values_alias_support,
get_mysql_version,
safe_identifier_length,
)


class TestMySQLUtils:
@pytest.mark.parametrize(
"version_string,expected",
[
("5.7.7", Version("5.7.7")),
("5.7.8", Version("5.7.8")),
("8.0.0", Version("8.0.0")),
("9.0.0", Version("9.0.0")),
("10.2.6-mariadb", Version("10.2.6")),
("10.2.7-mariadb", Version("10.2.7")),
("11.4.0-mariadb", Version("11.4.0")),
],
)
def test_get_mysql_version(self, version_string: str, expected: Version) -> None:
assert get_mysql_version(version_string) == expected

@pytest.mark.parametrize(
"version_string,expected",
[
("5.7.7", False),
("5.7.8", True),
("8.0.0", True),
("9.0.0", True),
("10.2.6-mariadb", False),
("10.2.7-mariadb", True),
("11.4.0-mariadb", True),
],
)
def test_check_mysql_json_support(self, version_string: str, expected: bool) -> None:
assert check_mysql_json_support(version_string) == expected

@pytest.mark.parametrize(
"version_string,expected",
[
("5.7.8", False),
("8.0.0", False),
("8.0.18", False),
("8.0.19", True),
("9.0.0", True),
("10.2.6-mariadb", False),
("10.2.7-mariadb", False),
("11.4.0-mariadb", False),
],
)
def test_check_mysql_values_alias_support(self, version_string: str, expected: bool) -> None:
assert check_mysql_values_alias_support(version_string) == expected

@pytest.mark.parametrize(
"version_string,expected",
[
("5.0.0", False),
("5.5.0", False),
("5.6.0", True),
("8.0.0", True),
("10.0.4-mariadb", False),
("10.0.5-mariadb", True),
("10.2.6-mariadb", True),
("11.4.0-mariadb", True),
],
)
def test_check_mysql_fulltext_support(self, version_string: str, expected: bool) -> None:
assert check_mysql_fulltext_support(version_string) == expected

@pytest.mark.parametrize(
"identifier,expected",
[
("a" * 67, "a" * 64),
("a" * 66, "a" * 64),
("a" * 65, "a" * 64),
("a" * 64, "a" * 64),
("a" * 63, "a" * 63),
],
)
def test_safe_identifier_length(self, identifier: str, expected: str) -> None:
assert safe_identifier_length(identifier) == expected

0 comments on commit 3fc65aa

Please sign in to comment.