Skip to content

Commit

Permalink
单表支持多分表策略优化1:分表策略的key拼接封装到组件内,不对用户开放2.上下文中sharding_key改为const3.修复一个panic
Browse files Browse the repository at this point in the history
  • Loading branch information
pandreame committed Oct 31, 2024
1 parent 6fd67a8 commit c5a8a24
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 33 deletions.
14 changes: 9 additions & 5 deletions conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package sharding
import (
"context"
"database/sql"
"fmt"
"time"

"gorm.io/gorm"
Expand Down Expand Up @@ -38,14 +37,17 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any)

if table != "" {
key := table
if shardingKey, ok := ctx.Value("sharding_key").(string); ok {
key = fmt.Sprintf("%s_%v", table, shardingKey)
key, err = pool.sharding.getConfigKey(ctx, table)
if err != nil {
return nil, err
}
if r, ok := pool.sharding.configs[key]; ok {
if r.DoubleWrite {
pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) {
result, _ := pool.ConnPool.ExecContext(ctx, ftQuery, args...)
rowsAffected, _ = result.RowsAffected()
if result != nil {
rowsAffected, _ = result.RowsAffected()
}
return pool.sharding.Explain(ftQuery, args...), rowsAffected
}, pool.sharding.Error)
}
Expand All @@ -55,7 +57,9 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any)
var result sql.Result
result, err = pool.ConnPool.ExecContext(ctx, stQuery, args...)
pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) {
rowsAffected, _ = result.RowsAffected()
if result != nil {
rowsAffected, _ = result.RowsAffected()
}
return pool.sharding.Explain(stQuery, args...), rowsAffected
}, pool.sharding.Error)

Expand Down
68 changes: 60 additions & 8 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,20 @@ import (
)

var (
ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =")
ErrInvalidID = errors.New("invalid id format")
ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ")
ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =")
ErrInvalidID = errors.New("invalid id format")
ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ")
ErrShardingKeyNotExistInContext = errors.New("the value passed in the context is not the sharding key")
ErrMissingTableName = errors.New("table name is required")
)

var (
ShardingIgnoreStoreKey = "sharding_ignore"
)

// ContextKeyForShardingKey is the context key for sharding key.
const ContextKeyForShardingKey = "sharding_key"

type Sharding struct {
*gorm.DB
ConnPool *ConnPool
Expand All @@ -47,6 +52,10 @@ type Config struct {
// For example, for a product order table, you may want to split the rows by `user_id`.
ShardingKey string

// logical table name.Suport multiple table names with same sharding key.
// For example, for user and order table, you may want to shard by `user_id`.
TableNames []string

// NumberOfShards specifies how many tables you want to sharding.
NumberOfShards uint

Expand Down Expand Up @@ -112,10 +121,53 @@ func Register(config Config, tables ...any) *Sharding {
}

// enables sharding for a single table with flexible support for multiple partition keys.
func RegisterWithKeys(configs map[string]Config) *Sharding {
func RegisterWithKeys(configs []Config) (*Sharding, error) {
mapConfig := make(map[string]Config, len(configs))
for _, config := range configs {
for _, tableName := range config.TableNames {
configKey, err := generateConfigsKey(tableName, config.ShardingKey)
if err != nil {
return nil, err
}
mapConfig[configKey] = config
}
}
return &Sharding{
configs: configs,
configs: mapConfig,
}, nil
}

// generates the key for the sharding config.
func generateConfigsKey(tableName, shardingKey string) (string, error) {
// Table name cannot be empty
if tableName == "" {
return "", ErrMissingTableName
}
if shardingKey == "" {
return "", ErrMissingShardingKey
}
return fmt.Sprintf("%s_%s", tableName, shardingKey), nil
}

// get the configs key for using it to get the sharding config.
func (s *Sharding) getConfigKey(ctx context.Context, tableName string) (string, error) {
configKey := tableName
if shardingKey, ok := ctx.Value(ContextKeyForShardingKey).(string); ok {
// If sharding key is set in context, use it to get the sharding config.
configKey = fmt.Sprintf("%s_%s", tableName, shardingKey)
} else {
// If sharding key is not set in context, use the table name as the key.
return configKey, nil
}

// check if the sharding key exists in the configs.
_, exis := s.configs[configKey]
if !exis {
return "", ErrShardingKeyNotExistInContext
}

// If sharding key is not set in context, use the table name as the key.
return configKey, nil
}

func (s *Sharding) compile() error {
Expand Down Expand Up @@ -353,9 +405,9 @@ func (s *Sharding) resolve(ctx context.Context, query string, args ...any) (ftQu

tableName = table.Name.Name
key := tableName
// If sharding key is set in context, use it to get the sharding config.
if shardingKey, ok := ctx.Value("sharding_key").(string); ok {
key = fmt.Sprintf("%s_%v", tableName, shardingKey)
key, err = s.getConfigKey(ctx, tableName)
if err != nil {
return
}
r, ok := s.configs[key]
if !ok {
Expand Down
47 changes: 27 additions & 20 deletions test/sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,32 +154,39 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) {

// Configure Gorm Sharding middleware, register sharding strategy configuration
// Logical table name is "orders"
db.Use(sharding.RegisterWithKeys(map[string]sharding.Config{
"orders_order_year": {
ShardingKey: "order_year",
shardingConfig, err := sharding.RegisterWithKeys([]sharding.Config{
{
ShardingKey: "order_year",
TableNames: []string{"orders"},
// Use custom sharding algorithm
ShardingAlgorithm: customShardingAlgorithmWithOrderYear,
ShardingAlgorithm: customShardingAlgorithmWithOrderYear,
// Use custom primary key generation function
PrimaryKeyGenerator: sharding.PKCustom,
PrimaryKeyGenerator: sharding.PKCustom,
// Custom primary key generation function
PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn,
},
"orders_user_id": {
ShardingKey: "user_id",
NumberOfShards: 4,
{
ShardingKey: "user_id",
TableNames: []string{"orders"},
NumberOfShards: 4,
// Use custom sharding algorithm
ShardingAlgorithm: customShardingAlgorithmWithUserId,
ShardingAlgorithm: customShardingAlgorithmWithUserId,
// Use Snowflake algorithm to generate primary key
PrimaryKeyGenerator: sharding.PKSnowflake,
},
"orders_order_id": {
ShardingKey: "order_id",
{
ShardingKey: "order_id",
TableNames: []string{"orders"},
// Use custom sharding algorithm
ShardingAlgorithm: customShardingAlgorithmWithOrderId,
PrimaryKeyGenerator: sharding.PKCustom,
PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn,
},
}))
})
if err != nil {
panic(err)
}
db.Use(shardingConfig)

// Insert and query examples based on order_year sharding key strategy
InsertOrderByOrderYearKey(db)
Expand All @@ -199,7 +206,7 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) {
func InsertOrderByOrderYearKey(db *gorm.DB) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "sharding_key", "order_year")
ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_year")
db = db.WithContext(ctx)
// Randomly 2024 or 2025
orderYear := rand.Intn(2) + 2024
Expand All @@ -225,7 +232,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) {
var orders []Order
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "sharding_key", "order_year")
ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_year")
db = db.WithContext(ctx)
db = db.Table("orders")
err := db.Model(&Order{}).Where("order_year=? and product_id=? and order_id=?", orderYear, 102, "20240101ORDER0002").Find(&orders).Error
Expand All @@ -239,7 +246,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) {
func InsertOrderByOrderIdKey(db *gorm.DB) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "sharding_key", "order_id")
ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id")
db = db.WithContext(ctx)
// Randomly 2024 or 2025
orderYear := rand.Intn(2) + 2024
Expand All @@ -265,7 +272,7 @@ func InsertOrderByOrderIdKey(db *gorm.DB) error {
func UpdateByOrderIdKey(db *gorm.DB, orderId string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "sharding_key", "order_id")
ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id")
db = db.WithContext(ctx)
db = db.Table("orders")
err := db.Model(&Order{}).Where("order_id=?", orderId).Update("product_id", 102).Error
Expand All @@ -278,7 +285,7 @@ func UpdateByOrderIdKey(db *gorm.DB, orderId string) error {
func DeleteByOrderIdKey(db *gorm.DB, orderId string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "sharding_key", "order_id")
ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id")
db = db.WithContext(ctx)
db = db.Table("orders")
err := db.Where("order_id=? and product_id=?", orderId, 100).Delete(&Order{}).Error
Expand All @@ -292,7 +299,7 @@ func FindOrderByOrderIdKey(db *gorm.DB, orderId string) ([]Order, error) {
// Query example
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "sharding_key", "order_id")
ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id")
db = db.WithContext(ctx)
db = db.Table("orders")
err := db.Model(&Order{}).Where("order_id=?", orderId).Find(&orders).Error
Expand All @@ -314,7 +321,7 @@ type OrderByUserId struct {
func InsertOrderByUserId(db *gorm.DB) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "sharding_key", "user_id")
ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "user_id")
db = db.WithContext(ctx)
// Randomly 2024 or 2025
orderYear := rand.Intn(2) + 2024
Expand All @@ -340,7 +347,7 @@ func FindByUserIDKey(db *gorm.DB, userID int64) ([]Order, error) {
// Query example
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "sharding_key", "user_id")
ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "user_id")
db = db.WithContext(ctx)
db = db.Table("orders")
err := db.Model(&Order{}).Where("user_id = ?", userID).Find(&orders).Error
Expand Down

0 comments on commit c5a8a24

Please sign in to comment.