Skip to content

Commit

Permalink
Merge pull request #1085 from uptrace/feat/read-only-replica
Browse files Browse the repository at this point in the history
feat: allow to specify read-only replica for SELECTs
  • Loading branch information
vmihailenco authored Jan 22, 2025
2 parents dbae5e6 + a95c99a commit 3d8666a
Show file tree
Hide file tree
Showing 19 changed files with 300 additions and 124 deletions.
170 changes: 162 additions & 8 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"strings"
"sync/atomic"
"time"

"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
Expand All @@ -26,32 +27,56 @@ type DBStats struct {

type DBOption func(db *DB)

func WithOptions(opts ...DBOption) DBOption {
return func(db *DB) {
for _, opt := range opts {
opt(db)
}
}
}

func WithDiscardUnknownColumns() DBOption {
return func(db *DB) {
db.flags = db.flags.Set(discardUnknownColumns)
}
}

type DB struct {
*sql.DB
func WithConnResolver(resolver ConnResolver) DBOption {
return func(db *DB) {
db.resolver = resolver
}
}

dialect schema.Dialect
type DB struct {
// Must be a pointer so we copy the whole state, not individual fields.
*noCopyState

queryHooks []QueryHook

fmter schema.Formatter
flags internal.Flag

stats DBStats
}

// noCopyState contains DB fields that must not be copied on clone(),
// for example, it is forbidden to copy atomic.Pointer.
type noCopyState struct {
*sql.DB
dialect schema.Dialect
resolver ConnResolver

flags internal.Flag
closed atomic.Bool
}

func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)

db := &DB{
DB: sqldb,
dialect: dialect,
fmter: schema.NewFormatter(dialect),
noCopyState: &noCopyState{
DB: sqldb,
dialect: dialect,
},
fmter: schema.NewFormatter(dialect),
}

for _, opt := range opts {
Expand All @@ -69,6 +94,22 @@ func (db *DB) String() string {
return b.String()
}

func (db *DB) Close() error {
if db.closed.Swap(true) {
return nil
}

firstErr := db.DB.Close()

if db.resolver != nil {
if err := db.resolver.Close(); err != nil && firstErr == nil {
firstErr = err
}
}

return firstErr
}

func (db *DB) DBStats() DBStats {
return DBStats{
Queries: atomic.LoadUint32(&db.stats.Queries),
Expand Down Expand Up @@ -703,3 +744,116 @@ func (tx Tx) NewDropColumn() *DropColumnQuery {
func (db *DB) makeQueryBytes() []byte {
return internal.MakeQueryBytes()
}

//------------------------------------------------------------------------------

// ConnResolver enables routing queries to multiple databases.
type ConnResolver interface {
ResolveConn(query Query) IConn
Close() error
}

// TODO:
// - make monitoring interval configurable
// - make ping timeout configutable
// - allow adding read/write replicas for multi-master replication
type ReadWriteConnResolver struct {
replicas []*sql.DB // read-only replicas
healthyReplicas atomic.Pointer[[]*sql.DB]
nextReplica atomic.Int64
closed atomic.Bool
}

func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver {
r := new(ReadWriteConnResolver)

for _, opt := range opts {
opt(r)
}

if len(r.replicas) > 0 {
r.healthyReplicas.Store(&r.replicas)
go r.monitor()
}

return r
}

type ReadWriteConnResolverOption func(r *ReadWriteConnResolver)

func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption {
return func(r *ReadWriteConnResolver) {
r.replicas = append(r.replicas, dbs...)
}
}

func (r *ReadWriteConnResolver) Close() error {
if r.closed.Swap(true) {
return nil
}

var firstErr error
for _, db := range r.replicas {
if err := db.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}

// healthyReplica returns a random healthy replica.
func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn {
if len(r.replicas) == 0 || !isReadOnlyQuery(query) {
return nil
}

replicas := r.loadHealthyReplicas()
if len(replicas) == 0 {
return nil
}
if len(replicas) == 1 {
return replicas[0]
}
i := r.nextReplica.Add(1)
return replicas[int(i)%len(replicas)]
}

func isReadOnlyQuery(query Query) bool {
sel, ok := query.(*SelectQuery)
if !ok {
return false
}
for _, el := range sel.with {
if !isReadOnlyQuery(el.query) {
return false
}
}
return true
}

func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB {
if ptr := r.healthyReplicas.Load(); ptr != nil {
return *ptr
}
return nil
}

func (r *ReadWriteConnResolver) monitor() {
const interval = 5 * time.Second
for !r.closed.Load() {
healthy := make([]*sql.DB, 0, len(r.replicas))

for _, replica := range r.replicas {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
err := replica.PingContext(ctx)
cancel()

if err == nil {
healthy = append(healthy, replica)
}
}

r.healthyReplicas.Store(&healthy)
time.Sleep(interval)
}
}
Loading

0 comments on commit 3d8666a

Please sign in to comment.