diff --git a/conn_pool.go b/conn_pool.go index c0b8e85..ed80563 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -3,7 +3,6 @@ package sharding import ( "context" "database/sql" - "fmt" "time" "gorm.io/gorm" @@ -38,14 +37,17 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) if table != "" { key := table - if shardingKey, ok := ctx.Value("sharding_key").(string); ok { - key = fmt.Sprintf("%s_%v", table, shardingKey) + key, err = pool.sharding.getConfigKey(ctx, table) + if err != nil { + return nil, err } if r, ok := pool.sharding.configs[key]; ok { if r.DoubleWrite { pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { result, _ := pool.ConnPool.ExecContext(ctx, ftQuery, args...) - rowsAffected, _ = result.RowsAffected() + if result != nil { + rowsAffected, _ = result.RowsAffected() + } return pool.sharding.Explain(ftQuery, args...), rowsAffected }, pool.sharding.Error) } @@ -55,7 +57,9 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) var result sql.Result result, err = pool.ConnPool.ExecContext(ctx, stQuery, args...) pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { - rowsAffected, _ = result.RowsAffected() + if result != nil { + rowsAffected, _ = result.RowsAffected() + } return pool.sharding.Explain(stQuery, args...), rowsAffected }, pool.sharding.Error) diff --git a/sharding.go b/sharding.go index ab37625..bbef9ee 100644 --- a/sharding.go +++ b/sharding.go @@ -16,15 +16,20 @@ import ( ) var ( - ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =") - ErrInvalidID = errors.New("invalid id format") - ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ") + ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =") + ErrInvalidID = errors.New("invalid id format") + ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ") + ErrShardingKeyNotExistInContext = errors.New("the value passed in the context is not the sharding key") + ErrMissingTableName = errors.New("table name is required") ) var ( ShardingIgnoreStoreKey = "sharding_ignore" ) +// ContextKeyForShardingKey is the context key for sharding key. +const ContextKeyForShardingKey = "sharding_key" + type Sharding struct { *gorm.DB ConnPool *ConnPool @@ -47,6 +52,10 @@ type Config struct { // For example, for a product order table, you may want to split the rows by `user_id`. ShardingKey string + // logical table name.Suport multiple table names with same sharding key. + // For example, for user and order table, you may want to shard by `user_id`. + TableNames []string + // NumberOfShards specifies how many tables you want to sharding. NumberOfShards uint @@ -112,10 +121,53 @@ func Register(config Config, tables ...any) *Sharding { } // enables sharding for a single table with flexible support for multiple partition keys. -func RegisterWithKeys(configs map[string]Config) *Sharding { +func RegisterWithKeys(configs []Config) (*Sharding, error) { + mapConfig := make(map[string]Config, len(configs)) + for _, config := range configs { + for _, tableName := range config.TableNames { + configKey, err := generateConfigsKey(tableName, config.ShardingKey) + if err != nil { + return nil, err + } + mapConfig[configKey] = config + } + } return &Sharding{ - configs: configs, + configs: mapConfig, + }, nil +} + +// generates the key for the sharding config. +func generateConfigsKey(tableName, shardingKey string) (string, error) { + // Table name cannot be empty + if tableName == "" { + return "", ErrMissingTableName + } + if shardingKey == "" { + return "", ErrMissingShardingKey } + return fmt.Sprintf("%s_%s", tableName, shardingKey), nil +} + +// get the configs key for using it to get the sharding config. +func (s *Sharding) getConfigKey(ctx context.Context, tableName string) (string, error) { + configKey := tableName + if shardingKey, ok := ctx.Value(ContextKeyForShardingKey).(string); ok { + // If sharding key is set in context, use it to get the sharding config. + configKey = fmt.Sprintf("%s_%s", tableName, shardingKey) + } else { + // If sharding key is not set in context, use the table name as the key. + return configKey, nil + } + + // check if the sharding key exists in the configs. + _, exis := s.configs[configKey] + if !exis { + return "", ErrShardingKeyNotExistInContext + } + + // If sharding key is not set in context, use the table name as the key. + return configKey, nil } func (s *Sharding) compile() error { @@ -353,9 +405,9 @@ func (s *Sharding) resolve(ctx context.Context, query string, args ...any) (ftQu tableName = table.Name.Name key := tableName - // If sharding key is set in context, use it to get the sharding config. - if shardingKey, ok := ctx.Value("sharding_key").(string); ok { - key = fmt.Sprintf("%s_%v", tableName, shardingKey) + key, err = s.getConfigKey(ctx, tableName) + if err != nil { + return } r, ok := s.configs[key] if !ok { diff --git a/test/sharding_test.go b/test/sharding_test.go index 4617c4a..53d09eb 100644 --- a/test/sharding_test.go +++ b/test/sharding_test.go @@ -154,32 +154,39 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { // Configure Gorm Sharding middleware, register sharding strategy configuration // Logical table name is "orders" - db.Use(sharding.RegisterWithKeys(map[string]sharding.Config{ - "orders_order_year": { - ShardingKey: "order_year", + shardingConfig, err := sharding.RegisterWithKeys([]sharding.Config{ + { + ShardingKey: "order_year", + TableNames: []string{"orders"}, // Use custom sharding algorithm - ShardingAlgorithm: customShardingAlgorithmWithOrderYear, + ShardingAlgorithm: customShardingAlgorithmWithOrderYear, // Use custom primary key generation function - PrimaryKeyGenerator: sharding.PKCustom, + PrimaryKeyGenerator: sharding.PKCustom, // Custom primary key generation function PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, }, - "orders_user_id": { - ShardingKey: "user_id", - NumberOfShards: 4, + { + ShardingKey: "user_id", + TableNames: []string{"orders"}, + NumberOfShards: 4, // Use custom sharding algorithm - ShardingAlgorithm: customShardingAlgorithmWithUserId, + ShardingAlgorithm: customShardingAlgorithmWithUserId, // Use Snowflake algorithm to generate primary key PrimaryKeyGenerator: sharding.PKSnowflake, }, - "orders_order_id": { - ShardingKey: "order_id", + { + ShardingKey: "order_id", + TableNames: []string{"orders"}, // Use custom sharding algorithm ShardingAlgorithm: customShardingAlgorithmWithOrderId, PrimaryKeyGenerator: sharding.PKCustom, PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, }, - })) + }) + if err != nil { + panic(err) + } + db.Use(shardingConfig) // Insert and query examples based on order_year sharding key strategy InsertOrderByOrderYearKey(db) @@ -199,7 +206,7 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { func InsertOrderByOrderYearKey(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_year") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_year") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -225,7 +232,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { var orders []Order ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_year") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_year") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_year=? and product_id=? and order_id=?", orderYear, 102, "20240101ORDER0002").Find(&orders).Error @@ -239,7 +246,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { func InsertOrderByOrderIdKey(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -265,7 +272,7 @@ func InsertOrderByOrderIdKey(db *gorm.DB) error { func UpdateByOrderIdKey(db *gorm.DB, orderId string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_id=?", orderId).Update("product_id", 102).Error @@ -278,7 +285,7 @@ func UpdateByOrderIdKey(db *gorm.DB, orderId string) error { func DeleteByOrderIdKey(db *gorm.DB, orderId string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Where("order_id=? and product_id=?", orderId, 100).Delete(&Order{}).Error @@ -292,7 +299,7 @@ func FindOrderByOrderIdKey(db *gorm.DB, orderId string) ([]Order, error) { // Query example ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_id=?", orderId).Find(&orders).Error @@ -314,7 +321,7 @@ type OrderByUserId struct { func InsertOrderByUserId(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "user_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "user_id") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -340,7 +347,7 @@ func FindByUserIDKey(db *gorm.DB, userID int64) ([]Order, error) { // Query example ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "user_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "user_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("user_id = ?", userID).Find(&orders).Error