diff --git a/cdc/model/sink.go b/cdc/model/sink.go index 019728f40fd..3632072bfbd 100644 --- a/cdc/model/sink.go +++ b/cdc/model/sink.go @@ -1347,3 +1347,34 @@ func (x ColumnDataX) GetDefaultValue() interface{} { func (x ColumnDataX) GetColumnInfo() *model.ColumnInfo { return x.info } + +// Column2ColumnDataForTest is for tests. +func Column2ColumnDataForTest(columns []*Column) ([]*ColumnData, *TableInfo) { + info := &TableInfo{ + TableInfo: &model.TableInfo{ + Columns: make([]*model.ColumnInfo, len(columns)), + }, + ColumnsFlag: make(map[int64]*ColumnFlagType, len(columns)), + columnsOffset: make(map[int64]int), + } + colDatas := make([]*ColumnData, 0, len(columns)) + + for i, column := range columns { + var columnID int64 = int64(i) + info.columnsOffset[columnID] = i + + info.Columns[i] = &model.ColumnInfo{} + info.Columns[i].Name.O = column.Name + info.Columns[i].SetType(column.Type) + info.Columns[i].SetCharset(column.Charset) + info.Columns[i].SetCollate(column.Collation) + info.Columns[i].DefaultValue = column.Default + + info.ColumnsFlag[columnID] = new(ColumnFlagType) + *info.ColumnsFlag[columnID] = column.Flag + + colDatas = append(colDatas, &ColumnData{ColumnID: columnID, Value: column.Value}) + } + + return colDatas, info +} diff --git a/cdc/sink/dmlsink/txn/event_test.go b/cdc/sink/dmlsink/txn/event_test.go index bdbb5cd7db9..31e2e05bf16 100644 --- a/cdc/sink/dmlsink/txn/event_test.go +++ b/cdc/sink/dmlsink/txn/event_test.go @@ -25,24 +25,24 @@ import ( func TestGenKeyListCaseInSensitive(t *testing.T) { t.Parallel() - columns := []*model.Column{ + columns, tb := model.Column2ColumnDataForTest([]*model.Column{ { Value: "XyZ", Type: mysql.TypeVarchar, Collation: "utf8_unicode_ci", }, - } + }) - first := genKeyList(columns, 0, []int{0}, 1) + first := genKeyList(columns, tb, 0, []int{0}, 1) - columns = []*model.Column{ + columns, tb = model.Column2ColumnDataForTest([]*model.Column{ { Value: "xYZ", Type: mysql.TypeVarchar, Collation: "utf8_unicode_ci", }, - } - second := genKeyList(columns, 0, []int{0}, 1) + }) + second := genKeyList(columns, tb, 0, []int{0}, 1) require.Equal(t, first, second) } diff --git a/cdc/sink/dmlsink/txn/mysql/dml.go b/cdc/sink/dmlsink/txn/mysql/dml.go index abbe1067874..92802c9ad41 100644 --- a/cdc/sink/dmlsink/txn/mysql/dml.go +++ b/cdc/sink/dmlsink/txn/mysql/dml.go @@ -157,29 +157,25 @@ func prepareDelete(quoteTable string, cols []*model.ColumnData, tb *model.TableI // whereSlice builds a parametric WHERE clause as following // sql: `WHERE {} = ? AND {} > ?` func whereSlice(cols []*model.ColumnData, tb *model.TableInfo, forceReplicate bool) (colNames []string, args []interface{}) { - // If no explicit row id but force replicate, use all key-values in where condition. + // Try to use unique key values when available + for _, col := range cols { + colx := model.GetColumnDataX(col, tb) + if colx.ColumnData == nil || !colx.GetFlag().IsHandleKey() { + continue + } + colNames = append(colNames, colx.GetName()) + args = appendQueryArgs(args, colx) + } + // if no explicit row id but force replicate, use all key-values in where condition if len(colNames) == 0 && forceReplicate { colNames = make([]string, 0, len(cols)) args = make([]interface{}, 0, len(cols)) for _, col := range cols { colx := model.GetColumnDataX(col, tb) - if colx.ColumnData == nil { - continue - } - colNames = append(colNames, colx.GetName()) - args = appendQueryArgs(args, colx) - } - } else { // Try to use unique key values when available. - for _, col := range cols { - colx := model.GetColumnDataX(col, tb) - if colx.ColumnData == nil || !colx.GetFlag().IsHandleKey() { - continue - } colNames = append(colNames, colx.GetName()) args = appendQueryArgs(args, colx) } } - return } diff --git a/cdc/sink/dmlsink/txn/mysql/dml_test.go b/cdc/sink/dmlsink/txn/mysql/dml_test.go index dab7c4b104f..20f22d63c2b 100644 --- a/cdc/sink/dmlsink/txn/mysql/dml_test.go +++ b/cdc/sink/dmlsink/txn/mysql/dml_test.go @@ -250,7 +250,9 @@ func TestPrepareUpdate(t *testing.T) { }, } for _, tc := range testCases { - query, args := prepareUpdate(tc.quoteTable, tc.preCols, tc.cols, false) + preDatas, info := model.Column2ColumnDataForTest(tc.preCols) + datas, _ := model.Column2ColumnDataForTest(tc.cols) + query, args := prepareUpdate(tc.quoteTable, preDatas, datas, info, false) require.Equal(t, tc.expectedSQL, query) require.Equal(t, tc.expectedArgs, args) } @@ -392,7 +394,8 @@ func TestPrepareDelete(t *testing.T) { }, } for _, tc := range testCases { - query, args := prepareDelete(tc.quoteTable, tc.preCols, false) + preDatas, info := model.Column2ColumnDataForTest(tc.preCols) + query, args := prepareDelete(tc.quoteTable, preDatas, info, false) require.Equal(t, tc.expectedSQL, query) require.Equal(t, tc.expectedArgs, args) } @@ -601,9 +604,10 @@ func TestWhereSlice(t *testing.T) { expectedArgs: []interface{}{1, "你好", 100}, }, } - for _, tc := range testCases { - colNames, args := whereSlice(tc.cols, tc.forceReplicate) - require.Equal(t, tc.expectedColNames, colNames) + for i, tc := range testCases { + datas, info := model.Column2ColumnDataForTest(tc.cols) + colNames, args := whereSlice(datas, info, tc.forceReplicate) + require.Equal(t, tc.expectedColNames, colNames, "case %d fails", i) require.Equal(t, tc.expectedArgs, args) } } @@ -713,7 +717,8 @@ func TestMapReplace(t *testing.T) { for _, tc := range testCases { // multiple times to verify the stability of column sequence in query string for i := 0; i < 10; i++ { - query, args := prepareReplace(tc.quoteTable, tc.cols, false, false) + datas, info := model.Column2ColumnDataForTest(tc.cols) + query, args := prepareReplace(tc.quoteTable, datas, info, false, false) require.Equal(t, tc.expectedQuery, query) require.Equal(t, tc.expectedArgs, args) }