diff --git a/conn_pool.go b/conn_pool.go index c83adb6..32c8f59 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -36,7 +36,7 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) pool.sharding.querys.Store("last_query", stQuery) if table != "" { - if r, ok := pool.sharding.configs[table]; ok { + if r, ok := pool.sharding.Configs[table]; ok { if r.DoubleWrite { pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { result, _ := pool.ConnPool.ExecContext(ctx, ftQuery, args...) diff --git a/dialector.go b/dialector.go index 455ab4c..383bba1 100644 --- a/dialector.go +++ b/dialector.go @@ -107,7 +107,7 @@ func (m ShardingMigrator) splitShardingDsts(dsts ...any) (shardingDsts []shardin return } - if cfg, ok := m.sharding.configs[stmt.Table]; ok { + if cfg, ok := m.sharding.Configs[stmt.Table]; ok { // support sharding table suffixs := cfg.ShardingSuffixs() if len(suffixs) == 0 { diff --git a/sharding.go b/sharding.go index 85b2738..1d5c3c2 100644 --- a/sharding.go +++ b/sharding.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "hash/crc32" + "log/slog" "strconv" "strings" "sync" @@ -27,7 +28,7 @@ var ( type Sharding struct { *gorm.DB ConnPool *ConnPool - configs map[string]Config + Configs map[string]Config querys sync.Map snowflakeNodes []*snowflake.Node @@ -111,23 +112,23 @@ func Register(config Config, tables ...any) *Sharding { } func (s *Sharding) compile() error { - if s.configs == nil { - s.configs = make(map[string]Config) + if s.Configs == nil { + s.Configs = make(map[string]Config) } for _, table := range s._tables { if t, ok := table.(string); ok { - s.configs[t] = s._config + s.Configs[t] = s._config } else { stmt := &gorm.Statement{DB: s.DB} if err := stmt.Parse(table); err == nil { - s.configs[stmt.Table] = s._config + s.Configs[stmt.Table] = s._config } else { return err } } } - for t, c := range s.configs { + for t, c := range s.Configs { if c.NumberOfShards > 1024 && c.PrimaryKeyGenerator == PKSnowflake { panic("Snowflake NumberOfShards should less than 1024") } @@ -216,7 +217,7 @@ func (s *Sharding) compile() error { } } } - s.configs[t] = c + s.Configs[t] = c } return nil @@ -242,7 +243,7 @@ func (s *Sharding) Initialize(db *gorm.DB) error { s.DB = db s.registerCallbacks(db) - for t, c := range s.configs { + for t, c := range s.Configs { if c.PrimaryKeyGenerator == PKPGSequence { err := s.DB.Exec("CREATE SEQUENCE IF NOT EXISTS " + pgSeqName(t)).Error if err != nil { @@ -295,12 +296,39 @@ func (s *Sharding) switchConn(db *gorm.DB) { s.mutex.Unlock() } } +func replaceConditionTableName(oldTableName, tableName string, expr sqlparser.Expr) error { + + err := sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error { + if n, ok := node.(*sqlparser.BinaryExpr); ok { + if q, ok2 := n.X.(*sqlparser.QualifiedRef); ok2 { + if q.Table.Name == oldTableName { + n.X.(*sqlparser.QualifiedRef).Table.Name = tableName + } + } + } + return nil + }), expr) + return err +} +func replaceSelectFieldTableName(oldTableName, tableName string, columns *sqlparser.OutputNames) *sqlparser.OutputNames { + rcs := []*sqlparser.ResultColumn(*columns) + for i := 0; i < len(rcs); i++ { + rc := rcs[i] + if n, ok := rc.Expr.(*sqlparser.QualifiedRef); ok { + if n.Table.Name == oldTableName { + n.Table.Name = tableName + } + } + } + r := sqlparser.OutputNames(rcs) + return &r +} // resolve split the old query to full table query and sharding table query func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableName string, err error) { ftQuery = query stQuery = query - if len(s.configs) == 0 { + if len(s.Configs) == 0 { return } @@ -344,7 +372,7 @@ func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableNa } tableName = table.Name.Name - r, ok := s.configs[tableName] + r, ok := s.Configs[tableName] if !ok { return } @@ -416,6 +444,7 @@ func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableNa ftQuery = insertStmt.String() insertStmt.TableName = newTable stQuery = insertStmt.String() + slog.Debug(stQuery) } else { var value any @@ -438,16 +467,21 @@ func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableNa ftQuery = stmt.String() stmt.FromItems = newTable stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name) + replaceConditionTableName(tableName, newTable.Name.Name, stmt.Condition) + stmt.Columns = replaceSelectFieldTableName(tableName, newTable.Name.Name, stmt.Columns) stQuery = stmt.String() case *sqlparser.UpdateStatement: ftQuery = stmt.String() stmt.TableName = newTable + replaceConditionTableName(tableName, newTable.Name.Name, stmt.Condition) stQuery = stmt.String() case *sqlparser.DeleteStatement: ftQuery = stmt.String() stmt.TableName = newTable + replaceConditionTableName(tableName, newTable.Name.Name, stmt.Condition) stQuery = stmt.String() } + slog.Debug(stQuery) } return @@ -500,7 +534,14 @@ func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sql func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) { err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error { if n, ok := node.(*sqlparser.BinaryExpr); ok { - if x, ok := n.X.(*sqlparser.Ident); ok { + x, ok := n.X.(*sqlparser.Ident) + if !ok { + if q, ok2 := n.X.(*sqlparser.QualifiedRef); ok2 { + x = q.Column + ok = true + } + } + if ok { if x.Name == key && n.Op == sqlparser.EQ { keyFind = true switch expr := n.Y.(type) {