Skip to content

Commit

Permalink
ticdc: Support Vector data type (#11620)
Browse files Browse the repository at this point in the history
ref #11530
  • Loading branch information
wk989898 committed Sep 24, 2024
1 parent d3aef11 commit 021fd64
Show file tree
Hide file tree
Showing 37 changed files with 672 additions and 70 deletions.
3 changes: 3 additions & 0 deletions cdc/entry/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ func unflatten(datum types.Datum, ft *types.FieldType, loc *time.Location) (type
byteSize := (ft.GetFlen() + 7) >> 3
datum.SetUint64(0)
datum.SetMysqlBit(types.NewBinaryLiteralFromUint(val, byteSize))
case mysql.TypeTiDBVectorFloat32:
datum.SetVectorFloat32(types.ZeroVectorFloat32)
return datum, nil
}
return datum, nil
}
5 changes: 5 additions & 0 deletions cdc/entry/mounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ func newDatum(value interface{}, ft types.FieldType) (types.Datum, error) {
return types.NewFloat32Datum(value.(float32)), nil
case mysql.TypeDouble:
return types.NewFloat64Datum(value.(float64)), nil
case mysql.TypeTiDBVectorFloat32:
return types.NewVectorFloat32Datum(value.(types.VectorFloat32)), nil
default:
log.Panic("unexpected mysql type found", zap.Any("type", ft.GetType()))
}
Expand Down Expand Up @@ -888,6 +890,9 @@ func formatColVal(datum types.Datum, col *timodel.ColumnInfo) (
}
const sizeOfV = unsafe.Sizeof(v)
return v, int(sizeOfV), warn, nil
case mysql.TypeTiDBVectorFloat32:
b := datum.GetVectorFloat32()
return b, b.Len(), "", nil
default:
// NOTICE: GetValue() may return some types that go sql not support, which will cause sink DML fail
// Make specified convert upper if you need
Expand Down
9 changes: 9 additions & 0 deletions cdc/entry/mounter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1682,4 +1682,13 @@ func TestFormatColVal(t *testing.T) {
require.NoError(t, err)
require.Equal(t, float32(0), value)
require.NotZero(t, warn)

vector, _ := types.ParseVectorFloat32("[1,2,3,4,5]")
ftTypeVector := types.NewFieldType(mysql.TypeTiDBVectorFloat32)
col = &timodel.ColumnInfo{FieldType: *ftTypeVector}
datum.SetVectorFloat32(vector)
value, _, warn, err = formatColVal(datum, col)
require.NoError(t, err)
require.Equal(t, vector, value)
require.Zero(t, warn)
}
64 changes: 64 additions & 0 deletions cdc/sink/ddlsink/mysql/format_ddl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package mysql

import (
"bytes"

"github.com/pingcap/log"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/format"
"github.com/pingcap/tidb/pkg/parser/mysql"
"go.uber.org/zap"
)

type visiter struct{}

func (f *visiter) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
switch v := n.(type) {
case *ast.ColumnDef:
if v.Tp != nil {
switch v.Tp.GetType() {
case mysql.TypeTiDBVectorFloat32:
v.Tp.SetType(mysql.TypeLongBlob)
v.Tp.SetCharset("")
v.Tp.SetCollate("")
v.Tp.SetFlen(-1)
v.Options = []*ast.ColumnOption{} // clear COMMENT
}
}
}
return n, false
}

func (f *visiter) Leave(n ast.Node) (node ast.Node, ok bool) {
return n, true
}

func formatQuery(sql string) string {
p := parser.New()
stmt, err := p.ParseOneStmt(sql, "", "")
if err != nil {
log.Error("format query parse one stmt failed", zap.Error(err))
}
stmt.Accept(&visiter{})

buf := new(bytes.Buffer)
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, buf)
if err = stmt.Restore(restoreCtx); err != nil {
log.Error("format query restore failed", zap.Error(err))
}
return buf.String()
}
47 changes: 47 additions & 0 deletions cdc/sink/ddlsink/mysql/format_ddl_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package mysql

import (
"bytes"
"testing"

"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/format"
"github.com/stretchr/testify/require"
)

func TestFormatQuery(t *testing.T) {
sql := "CREATE TABLE `test` (`id` INT PRIMARY KEY,`data` VECTOR(5))"
expectSQL := "CREATE TABLE `test` (`id` INT PRIMARY KEY,`data` LONGTEXT)"
p := parser.New()
stmt, err := p.ParseOneStmt(sql, "", "")
require.NoError(t, err)
stmt.Accept(&visiter{})

buf := new(bytes.Buffer)
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, buf)
err = stmt.Restore(restoreCtx)
require.NoError(t, err)
require.Equal(t, buf.String(), expectSQL)
}

func BenchmarkFormatQuery(b *testing.B) {
sql := "CREATE TABLE `test` (`id` INT PRIMARY KEY,`data` LONGTEXT)"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
formatQuery(sql)
}
}
38 changes: 38 additions & 0 deletions cdc/sink/ddlsink/mysql/mysql_ddl_sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ import (
"net/url"
"time"

"github.com/coreos/go-semver/semver"
lru "github.com/hashicorp/golang-lru"
"github.com/pingcap/failpoint"
"github.com/pingcap/log"
"github.com/pingcap/tidb/br/pkg/version"
"github.com/pingcap/tidb/dumpling/export"
timodel "github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tiflow/cdc/model"
"github.com/pingcap/tiflow/cdc/sink/ddlsink"
Expand All @@ -42,6 +45,8 @@ const (

// networkDriftDuration is used to construct a context timeout for database operations.
networkDriftDuration = 5 * time.Second

defaultSupportVectorVersion = "8.3.0"
)

// GetDBConnImpl is the implementation of pmysql.IDBConnectionFactory.
Expand All @@ -66,6 +71,8 @@ type DDLSink struct {
// is running in downstream.
// map: model.TableName -> timodel.ActionType
lastExecutedNormalDDLCache *lru.Cache

needFormat bool
}

// NewDDLSink creates a new DDLSink.
Expand Down Expand Up @@ -102,12 +109,14 @@ func NewDDLSink(
if err != nil {
return nil, err
}

m := &DDLSink{
id: changefeedID,
db: db,
cfg: cfg,
statistics: metrics.NewStatistics(changefeedID, sink.TxnSink),
lastExecutedNormalDDLCache: lruCache,
needFormat: needFormatDDL(db, cfg),
}

log.Info("MySQL DDL sink is created",
Expand Down Expand Up @@ -195,6 +204,14 @@ func (m *DDLSink) execDDL(pctx context.Context, ddl *model.DDLEvent) error {

shouldSwitchDB := needSwitchDB(ddl)

// Convert vector type to string type for unsupport database
if m.needFormat {
if newQuery := formatQuery(ddl.Query); newQuery != ddl.Query {
log.Warn("format ddl query", zap.String("newQuery", newQuery), zap.String("query", ddl.Query), zap.String("collate", ddl.Collate), zap.String("charset", ddl.Charset))
ddl.Query = newQuery
}
}

failpoint.Inject("MySQLSinkExecDDLDelay", func() {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -268,6 +285,27 @@ func needSwitchDB(ddl *model.DDLEvent) bool {
return true
}

// needFormatDDL checks vector type support
func needFormatDDL(db *sql.DB, cfg *pmysql.Config) bool {
if !cfg.HasVectorType {
log.Warn("please set `has-vector-type` to be true if a column is vector type when the downstream is not TiDB or TiDB version less than specify version",
zap.Any("hasVectorType", cfg.HasVectorType), zap.Any("supportVectorVersion", defaultSupportVectorVersion))
return false
}
versionInfo, err := export.SelectVersion(db)
if err != nil {
log.Warn("fail to get version", zap.Error(err), zap.Bool("isTiDB", cfg.IsTiDB))
return false
}
serverInfo := version.ParseServerInfo(versionInfo)
version := semver.New(defaultSupportVectorVersion)
if !cfg.IsTiDB || serverInfo.ServerVersion.LessThan(*version) {
log.Error("downstream unsupport vector type. it will be converted to longtext", zap.String("version", serverInfo.ServerVersion.String()), zap.String("supportVectorVersion", defaultSupportVectorVersion), zap.Bool("isTiDB", cfg.IsTiDB))
return true
}
return false
}

// WriteCheckpointTs does nothing.
func (m *DDLSink) WriteCheckpointTs(_ context.Context, _ uint64, _ []*model.TableInfo) error {
// Only for RowSink for now.
Expand Down
17 changes: 9 additions & 8 deletions cdc/sink/dmlsink/txn/mysql/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strings"

"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tiflow/cdc/model"
"github.com/pingcap/tiflow/pkg/quotes"
)
Expand Down Expand Up @@ -109,16 +110,16 @@ func prepareReplace(
// will automatically set `_binary` charset for that column, which is not expected.
// See https://github.com/go-sql-driver/mysql/blob/ce134bfc/connection.go#L267
func appendQueryArgs(args []interface{}, col *model.Column) []interface{} {
if col.Charset != "" && col.Charset != charset.CharsetBin {
colValBytes, ok := col.Value.([]byte)
if ok {
args = append(args, string(colValBytes))
} else {
args = append(args, col.Value)
switch v := col.Value.(type) {
case []byte:
if col.Charset != "" && col.Charset != charset.CharsetBin {
args = append(args, string(v))
return args
}
} else {
args = append(args, col.Value)
case types.VectorFloat32:
col.Value = v.String()
}
args = append(args, col.Value)

return args
}
Expand Down
50 changes: 50 additions & 0 deletions cdc/sink/dmlsink/txn/mysql/dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (

"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tiflow/cdc/model"
"github.com/pingcap/tiflow/pkg/util"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -248,6 +250,37 @@ func TestPrepareUpdate(t *testing.T) {
expectedSQL: "UPDATE `test`.`t1` SET `a` = ?, `b` = ? WHERE `a` = ? AND `b` = ? LIMIT 1",
expectedArgs: []interface{}{2, "世界", 1, "你好"},
},
{
quoteTable: "`test`.`t1`",
preCols: []*model.Column{
{
Name: "a",
Type: mysql.TypeLong,
Flag: model.MultipleKeyFlag | model.HandleKeyFlag,
Value: 1,
},
{
Name: "b",
Type: mysql.TypeTiDBVectorFloat32,
Value: util.Must(types.ParseVectorFloat32("[1.0,-2,0.33,-4.4,55]")),
},
},
cols: []*model.Column{
{
Name: "a",
Type: mysql.TypeLong,
Flag: model.MultipleKeyFlag | model.HandleKeyFlag,
Value: 1,
},
{
Name: "b",
Type: mysql.TypeTiDBVectorFloat32,
Value: util.Must(types.ParseVectorFloat32("[1,2,3,4,5]")),
},
},
expectedSQL: "UPDATE `test`.`t1` SET `a` = ?, `b` = ? WHERE `a` = ? LIMIT 1",
expectedArgs: []interface{}{1, "[1,2,3,4,5]", 1},
},
}
for _, tc := range testCases {
query, args := prepareUpdate(tc.quoteTable, tc.preCols, tc.cols, false)
Expand Down Expand Up @@ -709,6 +742,23 @@ func TestMapReplace(t *testing.T) {
[]byte("你好,世界"),
},
},
{
quoteTable: "`test`.`t1`",
cols: []*model.Column{
{
Name: "a",
Type: mysql.TypeTiDBVectorFloat32,
Value: util.Must(types.ParseVectorFloat32("[1.0,-2,0.3,-4.4,55]")),
},
{
Name: "b",
Type: mysql.TypeTiDBVectorFloat32,
Value: util.Must(types.ParseVectorFloat32("[1,2,3,4,5]")),
},
},
expectedQuery: "REPLACE INTO `test`.`t1` (`a`,`b`) VALUES ",
expectedArgs: []interface{}{"[1,-2,0.3,-4.4,55]", "[1,2,3,4,5]"},
},
}
for _, tc := range testCases {
// multiple times to verify the stability of column sequence in query string
Expand Down
19 changes: 11 additions & 8 deletions cdc/sink/dmlsink/txn/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tiflow/cdc/model"
"github.com/pingcap/tiflow/cdc/sink/dmlsink"
"github.com/pingcap/tiflow/cdc/sink/metrics"
Expand Down Expand Up @@ -335,17 +336,19 @@ func convert2RowChanges(
return res
}

func convertBinaryToString(cols []*model.ColumnData, tableInfo *model.TableInfo) {
func convertValue(cols []*model.ColumnData, tableInfo *model.TableInfo) {
for i, col := range cols {
if col == nil {
continue
}
colInfo := tableInfo.ForceGetColumnInfo(col.ColumnID)
if colInfo.GetCharset() != "" && colInfo.GetCharset() != charset.CharsetBin {
colValBytes, ok := col.Value.([]byte)
if ok {
cols[i].Value = string(colValBytes)
switch v := col.Value.(type) {
case []byte:
colInfo := tableInfo.ForceGetColumnInfo(col.ColumnID)
if colInfo.GetCharset() != "" && colInfo.GetCharset() != charset.CharsetBin {
cols[i].Value = string(v)
}
case types.VectorFloat32:
cols[i].Value = v.String()
}
}
}
Expand All @@ -364,8 +367,8 @@ func (s *mysqlBackend) groupRowsByType(
deleteRow := make([]*sqlmodel.RowChange, 0, preAllocateSize)

for _, row := range event.Event.Rows {
convertBinaryToString(row.Columns, tableInfo)
convertBinaryToString(row.PreColumns, tableInfo)
convertValue(row.Columns, tableInfo)
convertValue(row.PreColumns, tableInfo)

if row.IsInsert() {
insertRow = append(
Expand Down
Loading

0 comments on commit 021fd64

Please sign in to comment.