diff --git a/CHANGELOG.md b/CHANGELOG.md index a6dc590..c807316 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 2.1.8 + +* [FIX] ensure index names do not collide with table names + # 2.1.7 * [FIX] use more precise foreign key constraints diff --git a/mysql_to_sqlite3/transporter.py b/mysql_to_sqlite3/transporter.py index b17e3a0..628c0ae 100644 --- a/mysql_to_sqlite3/transporter.py +++ b/mysql_to_sqlite3/transporter.py @@ -405,8 +405,29 @@ def _build_create_table_sql(self, table_name: str) -> str: """, (self._mysql_database, table_name), ) - for index in self._mysql_cur_dict.fetchall(): + mysql_indices: t.Sequence[t.Optional[t.Dict[str, ToPythonOutputTypes]]] = self._mysql_cur_dict.fetchall() + for index in mysql_indices: if index is not None: + index_name: str + if isinstance(index["name"], bytes): + index_name = index["name"].decode() + elif isinstance(index["name"], str): + index_name = index["name"] + else: + index_name = str(index["name"]) + + # check if the index name collides with any table name + self._mysql_cur_dict.execute( + """ + SELECT COUNT(*) + FROM information_schema.TABLES + WHERE TABLE_SCHEMA = %s + AND TABLE_NAME = %s + """, + (self._mysql_database, index_name), + ) + index_name_collision: t.Optional[t.Dict[str, ToPythonOutputTypes]] = self._mysql_cur_dict.fetchone() + columns: str = "" if isinstance(index["columns"], bytes): columns = index["columns"].decode() @@ -421,14 +442,9 @@ def _build_create_table_sql(self, table_name: str) -> str: else: indices += """CREATE {unique} INDEX IF NOT EXISTS "{name}" ON "{table}" ({columns});""".format( unique="UNIQUE" if index["unique"] in {1, "1"} else "", - name="{table}_{name}".format( - table=table_name, - name=index["name"].decode() if isinstance(index["name"], bytes) else index["name"], - ) - if self._prefix_indices - else index["name"].decode() - if isinstance(index["name"], bytes) - else index["name"], + name=f"{table_name}_{index_name}" + if (index_name_collision is not None or self._prefix_indices) + else index_name, table=table_name, columns=", ".join(f'"{column}"' for column in columns.split(",")), )