From 97a1fa5127afca4752247da72676155ee83ce69d Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 18 Jan 2025 09:44:00 +0200 Subject: [PATCH] chore: extract resolver --- db.go | 169 +++++++++++++++++++++++++++++++++++--------------- query_base.go | 19 ++---- 2 files changed, 122 insertions(+), 66 deletions(-) diff --git a/db.go b/db.go index abeb7aa97..ddd922eee 100644 --- a/db.go +++ b/db.go @@ -33,14 +33,14 @@ func WithDiscardUnknownColumns() DBOption { } } -func WithReadOnlyReplica(replica *sql.DB) DBOption { +func WithConnResolver(resolver ConnResolver) DBOption { return func(db *DB) { - db.replicas = append(db.replicas, replica) + db.resolver = resolver } } type DB struct { - // Must be a pointer so we copy the state, not the state fields. + // Must be a pointer so we copy the whole state, not individual fields. *noCopyState queryHooks []QueryHook @@ -53,11 +53,8 @@ type DB struct { // for example, it is forbidden to copy atomic.Pointer. type noCopyState struct { *sql.DB - dialect schema.Dialect - - replicas []*sql.DB - healthyReplicas atomic.Pointer[[]*sql.DB] - nextReplica atomic.Int64 + dialect schema.Dialect + resolver ConnResolver flags internal.Flag closed atomic.Bool @@ -78,10 +75,6 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { opt(db) } - if len(db.replicas) > 0 { - go db.monitorReplicas() - } - return db } @@ -95,7 +88,16 @@ func (db *DB) String() string { func (db *DB) Close() error { db.closed.Store(true) - return db.DB.Close() + + firstErr := db.DB.Close() + + if db.resolver != nil { + if err := db.resolver.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr } func (db *DB) DBStats() DBStats { @@ -261,44 +263,6 @@ func (db *DB) HasFeature(feat feature.Feature) bool { return db.dialect.Features().Has(feat) } -// healthyReplica returns a random healthy replica. -func (db *DB) healthyReplica() *sql.DB { - replicas := db.loadHealthyReplicas() - if len(replicas) == 0 { - return db.DB - } - if len(replicas) == 1 { - return replicas[0] - } - i := db.nextReplica.Add(1) - return replicas[int(i)%len(replicas)] -} - -func (db *DB) loadHealthyReplicas() []*sql.DB { - if ptr := db.healthyReplicas.Load(); ptr != nil { - return *ptr - } - return nil -} - -func (db *DB) monitorReplicas() { - for !db.closed.Load() { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - healthy := make([]*sql.DB, 0, len(db.replicas)) - - for _, replica := range db.replicas { - if err := replica.PingContext(ctx); err == nil { - healthy = append(healthy, replica) - } - } - - db.healthyReplicas.Store(&healthy) - time.Sleep(5 * time.Second) - } -} - //------------------------------------------------------------------------------ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { @@ -770,3 +734,106 @@ func (tx Tx) NewDropColumn() *DropColumnQuery { func (db *DB) makeQueryBytes() []byte { return internal.MakeQueryBytes() } + +//------------------------------------------------------------------------------ + +type ConnResolver interface { + ResolveConn(query Query) IConn + Close() error +} + +type ReadWriteConnResolver struct { + replicas []*sql.DB // read-only replicas + healthyReplicas atomic.Pointer[[]*sql.DB] + nextReplica atomic.Int64 + closed atomic.Bool +} + +func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver { + r := new(ReadWriteConnResolver) + + for _, opt := range opts { + opt(r) + } + + if len(r.replicas) > 0 { + go r.monitor() + } + return r +} + +type ReadWriteConnResolverOption func(r *ReadWriteConnResolver) + +func WithReadOnlyReplica(db *sql.DB) ReadWriteConnResolverOption { + return func(r *ReadWriteConnResolver) { + r.replicas = append(r.replicas, db) + } +} + +func (r *ReadWriteConnResolver) Close() error { + r.closed.Store(true) + + var firstErr error + for _, db := range r.replicas { + if err := db.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// healthyReplica returns a random healthy replica. +func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn { + if len(r.replicas) == 0 || !isReadOnlyQuery(query) { + return nil + } + + replicas := r.loadHealthyReplicas() + if len(replicas) == 0 { + return nil + } + if len(replicas) == 1 { + return replicas[0] + } + i := r.nextReplica.Add(1) + return replicas[int(i)%len(replicas)] +} + +func isReadOnlyQuery(query Query) bool { + sel, ok := query.(*SelectQuery) + if !ok { + return false + } + for _, el := range sel.with { + if !isReadOnlyQuery(el.query) { + return false + } + } + return true +} + +func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB { + if ptr := r.healthyReplicas.Load(); ptr != nil { + return *ptr + } + return nil +} + +func (r *ReadWriteConnResolver) monitor() { + const interval = 5 * time.Second + for !r.closed.Load() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + healthy := make([]*sql.DB, 0, len(r.replicas)) + + for _, replica := range r.replicas { + if err := replica.PingContext(ctx); err == nil { + healthy = append(healthy, replica) + } + } + + r.healthyReplicas.Store(&healthy) + time.Sleep(interval) + } +} diff --git a/query_base.go b/query_base.go index 27f557426..b17498742 100644 --- a/query_base.go +++ b/query_base.go @@ -118,23 +118,12 @@ func (q *baseQuery) resolveConn(query Query) IConn { if q.conn != nil { return q.conn } - if len(q.db.replicas) == 0 || !isReadOnlyQuery(query) { - return q.db.DB - } - return q.db.healthyReplica() -} - -func isReadOnlyQuery(query Query) bool { - sel, ok := query.(*SelectQuery) - if !ok { - return false - } - for _, el := range sel.with { - if !isReadOnlyQuery(el.query) { - return false + if q.db.resolver != nil { + if conn := q.db.resolver.ResolveConn(query); conn != nil { + return conn } } - return true + return q.db.DB } func (q *baseQuery) GetModel() Model {