Skip to content

Commit

Permalink
chore: extract resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jan 18, 2025
1 parent 702e525 commit 97a1fa5
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 66 deletions.
169 changes: 118 additions & 51 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ func WithDiscardUnknownColumns() DBOption {
}
}

func WithReadOnlyReplica(replica *sql.DB) DBOption {
func WithConnResolver(resolver ConnResolver) DBOption {
return func(db *DB) {
db.replicas = append(db.replicas, replica)
db.resolver = resolver
}
}

type DB struct {
// Must be a pointer so we copy the state, not the state fields.
// Must be a pointer so we copy the whole state, not individual fields.
*noCopyState

queryHooks []QueryHook
Expand All @@ -53,11 +53,8 @@ type DB struct {
// for example, it is forbidden to copy atomic.Pointer.
type noCopyState struct {
*sql.DB
dialect schema.Dialect

replicas []*sql.DB
healthyReplicas atomic.Pointer[[]*sql.DB]
nextReplica atomic.Int64
dialect schema.Dialect
resolver ConnResolver

flags internal.Flag
closed atomic.Bool
Expand All @@ -78,10 +75,6 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
opt(db)
}

if len(db.replicas) > 0 {
go db.monitorReplicas()
}

return db
}

Expand All @@ -95,7 +88,16 @@ func (db *DB) String() string {

func (db *DB) Close() error {
db.closed.Store(true)
return db.DB.Close()

firstErr := db.DB.Close()

if db.resolver != nil {
if err := db.resolver.Close(); err != nil && firstErr == nil {
firstErr = err
}
}

return firstErr
}

func (db *DB) DBStats() DBStats {
Expand Down Expand Up @@ -261,44 +263,6 @@ func (db *DB) HasFeature(feat feature.Feature) bool {
return db.dialect.Features().Has(feat)
}

// healthyReplica returns a random healthy replica.
func (db *DB) healthyReplica() *sql.DB {
replicas := db.loadHealthyReplicas()
if len(replicas) == 0 {
return db.DB
}
if len(replicas) == 1 {
return replicas[0]
}
i := db.nextReplica.Add(1)
return replicas[int(i)%len(replicas)]
}

func (db *DB) loadHealthyReplicas() []*sql.DB {
if ptr := db.healthyReplicas.Load(); ptr != nil {
return *ptr
}
return nil
}

func (db *DB) monitorReplicas() {
for !db.closed.Load() {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

healthy := make([]*sql.DB, 0, len(db.replicas))

for _, replica := range db.replicas {
if err := replica.PingContext(ctx); err == nil {
healthy = append(healthy, replica)
}
}

db.healthyReplicas.Store(&healthy)
time.Sleep(5 * time.Second)
}
}

//------------------------------------------------------------------------------

func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
Expand Down Expand Up @@ -770,3 +734,106 @@ func (tx Tx) NewDropColumn() *DropColumnQuery {
func (db *DB) makeQueryBytes() []byte {
return internal.MakeQueryBytes()
}

//------------------------------------------------------------------------------

type ConnResolver interface {
ResolveConn(query Query) IConn
Close() error
}

type ReadWriteConnResolver struct {
replicas []*sql.DB // read-only replicas
healthyReplicas atomic.Pointer[[]*sql.DB]
nextReplica atomic.Int64
closed atomic.Bool
}

func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver {
r := new(ReadWriteConnResolver)

for _, opt := range opts {
opt(r)
}

if len(r.replicas) > 0 {
go r.monitor()
}
return r
}

type ReadWriteConnResolverOption func(r *ReadWriteConnResolver)

func WithReadOnlyReplica(db *sql.DB) ReadWriteConnResolverOption {
return func(r *ReadWriteConnResolver) {
r.replicas = append(r.replicas, db)
}
}

func (r *ReadWriteConnResolver) Close() error {
r.closed.Store(true)

var firstErr error
for _, db := range r.replicas {
if err := db.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}

// healthyReplica returns a random healthy replica.
func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn {
if len(r.replicas) == 0 || !isReadOnlyQuery(query) {
return nil
}

replicas := r.loadHealthyReplicas()
if len(replicas) == 0 {
return nil
}
if len(replicas) == 1 {
return replicas[0]
}
i := r.nextReplica.Add(1)
return replicas[int(i)%len(replicas)]
}

func isReadOnlyQuery(query Query) bool {
sel, ok := query.(*SelectQuery)
if !ok {
return false
}
for _, el := range sel.with {
if !isReadOnlyQuery(el.query) {
return false
}
}
return true
}

func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB {
if ptr := r.healthyReplicas.Load(); ptr != nil {
return *ptr
}
return nil
}

func (r *ReadWriteConnResolver) monitor() {
const interval = 5 * time.Second
for !r.closed.Load() {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

healthy := make([]*sql.DB, 0, len(r.replicas))

for _, replica := range r.replicas {
if err := replica.PingContext(ctx); err == nil {
healthy = append(healthy, replica)
}
}

r.healthyReplicas.Store(&healthy)
time.Sleep(interval)
}
}
19 changes: 4 additions & 15 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,23 +118,12 @@ func (q *baseQuery) resolveConn(query Query) IConn {
if q.conn != nil {
return q.conn
}
if len(q.db.replicas) == 0 || !isReadOnlyQuery(query) {
return q.db.DB
}
return q.db.healthyReplica()
}

func isReadOnlyQuery(query Query) bool {
sel, ok := query.(*SelectQuery)
if !ok {
return false
}
for _, el := range sel.with {
if !isReadOnlyQuery(el.query) {
return false
if q.db.resolver != nil {
if conn := q.db.resolver.ResolveConn(query); conn != nil {
return conn
}
}
return true
return q.db.DB
}

func (q *baseQuery) GetModel() Model {
Expand Down

0 comments on commit 97a1fa5

Please sign in to comment.