From bab4912f12452bb9cec1987fe99ee27c85a0a0d1 Mon Sep 17 00:00:00 2001 From: Henri Maurer Date: Thu, 21 Mar 2024 16:28:52 +0000 Subject: [PATCH] Backport https://github.com/vitessio/vitess/pull/15275 --- go/vt/mysqlctl/schema.go | 18 ++++++------------ go/vt/vtexplain/vtexplain_vttablet.go | 8 ++++---- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index 406b5c59499..757eab757bc 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -60,12 +60,6 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error { return mysqld.executeMysqlScript(params, strings.NewReader(sql)) } -func encodeEntityName(name string) string { - var buf strings.Builder - sqltypes.NewVarChar(name).EncodeSQL(&buf) - return buf.String() -} - // tableListSQL returns an IN clause "('t1', 't2'...) for a list of tables." func tableListSQL(tables []string) (string, error) { if len(tables) == 0 { @@ -74,7 +68,7 @@ func tableListSQL(tables []string) (string, error) { encodedTables := make([]string, len(tables)) for i, tableName := range tables { - encodedTables[i] = encodeEntityName(tableName) + encodedTables[i] = sqltypes.EncodeStringSQL(tableName) } return "(" + strings.Join(encodedTables, ", ") + ")", nil @@ -304,9 +298,9 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql if dbName == "" { dbName2 = "database()" } else { - dbName2 = encodeEntityName(dbName) + dbName2 = sqltypes.EncodeStringSQL(dbName) } - query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sqlescape.UnescapeID(tableName))) + query := fmt.Sprintf(GetColumnNamesQuery, dbName2, sqltypes.EncodeStringSQL(sqlescape.UnescapeID(tableName))) qr, err := exec(query, -1, true) if err != nil { return "", err @@ -393,7 +387,7 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t FROM information_schema.STATISTICS WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary' ORDER BY table_name, SEQ_IN_INDEX` - sql = fmt.Sprintf(sql, encodeEntityName(dbName), tableList) + sql = fmt.Sprintf(sql, sqltypes.EncodeStringSQL(dbName), tableList) qr, err := conn.ExecuteFetch(sql, len(tables)*100, true) if err != nil { return nil, err @@ -622,8 +616,8 @@ func (mysqld *Mysqld) GetPrimaryKeyEquivalentColumns(ctx context.Context, dbName ) AS pke ON index_cols.INDEX_NAME = pke.INDEX_NAME WHERE index_cols.TABLE_SCHEMA = %s AND index_cols.TABLE_NAME = %s AND NON_UNIQUE = 0 AND NULLABLE != 'YES' ORDER BY SEQ_IN_INDEX ASC` - encodedDbName := encodeEntityName(dbName) - encodedTable := encodeEntityName(table) + encodedDbName := sqltypes.EncodeStringSQL(dbName) + encodedTable := sqltypes.EncodeStringSQL(table) sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable) qr, err := conn.ExecuteFetch(sql, 1000, true) if err != nil { diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index 4f0a3f7d102..9a82cc915bf 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -437,8 +437,8 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options) (*tablet } tEnv.addResult(query, tEnv.getResult(likeQuery)) - likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(likeTable)) - query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(table)) + likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(likeTable)) + query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(table)) if tEnv.getResult(likeQuery) == nil { return nil, fmt.Errorf("check your schema, table[%s] doesn't exist", likeTable) } @@ -477,7 +477,7 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options) (*tablet tEnv.addResult("SELECT * FROM "+backtickedTable+" WHERE 1 != 1", &sqltypes.Result{ Fields: rowTypes, }) - query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(table)) + query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(table)) tEnv.addResult(query, &sqltypes.Result{ Fields: colTypes, Rows: colValues, @@ -558,7 +558,7 @@ func (t *explainTablet) handleSelect(query string) (*sqltypes.Result, error) { // Gen4 supports more complex queries so we now need to // handle multiple FROM clauses - tables := make([]*sqlparser.AliasedTableExpr, len(selStmt.From)) + tables := make([]*sqlparser.AliasedTableExpr, 0, len(selStmt.From)) for _, from := range selStmt.From { tables = append(tables, getTables(from)...) }