Skip to content

Commit

Permalink
🐛 ensure index names do not collide with table names
Browse files Browse the repository at this point in the history
  • Loading branch information
techouse committed Jan 13, 2024
1 parent bcaa34b commit 91bd38f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 2.1.8

* [FIX] ensure index names do not collide with table names

# 2.1.7

* [FIX] use more precise foreign key constraints
Expand Down
34 changes: 25 additions & 9 deletions mysql_to_sqlite3/transporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,29 @@ def _build_create_table_sql(self, table_name: str) -> str:
""",
(self._mysql_database, table_name),
)
for index in self._mysql_cur_dict.fetchall():
mysql_indices: t.Sequence[t.Optional[t.Dict[str, ToPythonOutputTypes]]] = self._mysql_cur_dict.fetchall()
for index in mysql_indices:
if index is not None:
index_name: str
if isinstance(index["name"], bytes):
index_name = index["name"].decode()
elif isinstance(index["name"], str):
index_name = index["name"]
else:
index_name = str(index["name"])

# check if the index name collides with any table name
self._mysql_cur_dict.execute(
"""
SELECT COUNT(*)
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = %s
AND TABLE_NAME = %s
""",
(self._mysql_database, index_name),
)
index_name_collision: t.Optional[t.Dict[str, ToPythonOutputTypes]] = self._mysql_cur_dict.fetchone()

columns: str = ""
if isinstance(index["columns"], bytes):
columns = index["columns"].decode()
Expand All @@ -421,14 +442,9 @@ def _build_create_table_sql(self, table_name: str) -> str:
else:
indices += """CREATE {unique} INDEX IF NOT EXISTS "{name}" ON "{table}" ({columns});""".format(
unique="UNIQUE" if index["unique"] in {1, "1"} else "",
name="{table}_{name}".format(
table=table_name,
name=index["name"].decode() if isinstance(index["name"], bytes) else index["name"],
)
if self._prefix_indices
else index["name"].decode()
if isinstance(index["name"], bytes)
else index["name"],
name=f"{table_name}_{index_name}"
if (index_name_collision is not None or self._prefix_indices)
else index_name,
table=table_name,
columns=", ".join(f'"{column}"' for column in columns.split(",")),
)
Expand Down

0 comments on commit 91bd38f

Please sign in to comment.