forked from Shopify/ghostferry
-
Notifications
You must be signed in to change notification settings - Fork 0
/
throttler.go
166 lines (137 loc) · 2.96 KB
/
throttler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package ghostferry
import (
"context"
sqlorig "database/sql"
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
)
// Throttler is the contract consumed by WaitForThrottle. PauserThrottler and
// LagThrottler in this file provide implementations of it.
type Throttler interface {
	// Run executes whatever background work the throttler needs. It may
	// return immediately when there is nothing to do.
	Run(context.Context) error

	// Throttled reports whether callers should currently hold back.
	Throttled() bool
	// SetPaused manually pauses (true) or unpauses (false) the throttler.
	SetPaused(bool)

	// Disabled reports whether throttling has been switched off entirely.
	Disabled() bool
	// SetDisabled switches throttling off (true) or back on (false).
	SetDisabled(bool)
}
// WaitForThrottle blocks the caller while t is enabled and throttled.
//
// If t is disabled or not throttled on entry, it returns at once. Otherwise
// it re-checks t every 500ms until t is disabled or no longer throttled. The
// polling loop runs inside metrics.Measure under the name "WaitForThrottle".
func WaitForThrottle(t Throttler) {
	// holding mirrors the original check order: Disabled is consulted first and
	// Throttled only when the throttler is still enabled.
	holding := func() bool {
		return !t.Disabled() && t.Throttled()
	}

	if !holding() {
		return
	}

	metrics.Measure("WaitForThrottle", nil, 1.0, func() {
		// Always sleep at least once before re-checking.
		for stillHolding := true; stillHolding; stillHolding = holding() {
			time.Sleep(500 * time.Millisecond)
		}
	})
}
// ThrottlerBase carries the on/off switch shared by throttler implementations.
// The flag is an int32 manipulated only through sync/atomic, so it can be
// flipped by one goroutine while others read it.
type ThrottlerBase struct {
	disabled int32 // non-zero means disabled
}

// Disabled reports whether throttling has been switched off.
func (t *ThrottlerBase) Disabled() bool {
	current := atomic.LoadInt32(&t.disabled)
	return current != 0
}

// SetDisabled stores 1 to switch throttling off, or 0 to switch it back on.
func (t *ThrottlerBase) SetDisabled(disabled bool) {
	if disabled {
		atomic.StoreInt32(&t.disabled, 1)
		return
	}
	atomic.StoreInt32(&t.disabled, 0)
}
// PauserThrottler is a manually driven throttler: it is throttled exactly
// while paused. Disabled/SetDisabled are promoted from the embedded
// ThrottlerBase. Its zero value is usable and starts unpaused.
type PauserThrottler struct {
	ThrottlerBase
	paused int32 // non-zero means paused; accessed atomically
}

// Throttled reports whether the throttler is currently paused.
func (t *PauserThrottler) Throttled() bool {
	isPaused := atomic.LoadInt32(&t.paused) != 0
	return isPaused
}

// SetPaused stores 1 to pause the throttler, or 0 to resume it.
func (t *PauserThrottler) SetPaused(paused bool) {
	if paused {
		atomic.StoreInt32(&t.paused, 1)
		return
	}
	atomic.StoreInt32(&t.paused, 0)
}

// Run satisfies Throttler. A PauserThrottler has no background work, so Run
// returns nil immediately without waiting on ctx.
func (t *PauserThrottler) Run(ctx context.Context) error {
	return nil
}
// LagThrottlerConfig configures a LagThrottler built by NewLagThrottler.
type LagThrottlerConfig struct {
// Connection describes the database the lag Query runs against.
// NewLagThrottler calls Validate on it and opens it via SqlDB.
Connection *DatabaseConfig
// MaxLag is the threshold: LagThrottler.Throttled reports true while the
// last observed lag is strictly greater than MaxLag. NewLagThrottler
// replaces values <= 0 with 1. Units are whatever Query returns
// (NOTE(review): presumably seconds — confirm with the configured query).
MaxLag int
// Query is required. Its first row's single column is scanned into a
// sql.NullInt64; returning no rows or NULL leaves the previous lag unchanged.
Query string
// UpdateInterval is a time.ParseDuration string (e.g. "1s") giving the delay
// between lag refreshes in Run. NewLagThrottler defaults it to "1s" when empty.
UpdateInterval string
}
// LagThrottler throttles while a user-supplied lag query reports a value above
// MaxLag, or while it has been paused explicitly via SetPaused.
//
// The lag value is refreshed by Run on one goroutine and read by Throttled
// from others, so it is only ever accessed through sync/atomic.
type LagThrottler struct {
	// lag is the most recent value returned by config.Query. It is the first
	// field so that 64-bit atomic access stays 8-byte aligned on 32-bit
	// platforms, as the sync/atomic documentation requires.
	lag int64

	// ThrottlerBase supplies Disabled/SetDisabled. PauserThrottler embeds its
	// own ThrottlerBase too, but this shallower embedding wins under Go's
	// selector rules, so it is the one those methods operate on.
	ThrottlerBase
	PauserThrottler

	config   *LagThrottlerConfig
	DB       *sql.DB
	logger   *logrus.Entry
	interval time.Duration
}

// NewLagThrottler validates config, fills in defaults and opens the database
// connection used to run the lag query.
//
// Defaults are written back into config: a non-positive MaxLag becomes 1 and
// an empty UpdateInterval becomes "1s". An error is returned when Query is
// empty, UpdateInterval is not a positive duration, Connection is nil or
// fails Validate, or the connection cannot be created.
func NewLagThrottler(config *LagThrottlerConfig) (*LagThrottler, error) {
	if config.MaxLag <= 0 {
		config.MaxLag = 1
	}
	if config.UpdateInterval == "" {
		config.UpdateInterval = "1s"
	}
	if config.Query == "" {
		return nil, fmt.Errorf("lag Query required")
	}

	interval, err := time.ParseDuration(config.UpdateInterval)
	if err != nil {
		return nil, fmt.Errorf("invalid UpdateInterval: %w", err)
	}
	// A zero or negative interval would make Run re-query the database in a
	// tight loop with no delay.
	if interval <= 0 {
		return nil, fmt.Errorf("invalid UpdateInterval: %q must be positive", config.UpdateInterval)
	}

	// Guard before calling Validate so a missing connection is reported as an
	// error rather than a nil-pointer dereference.
	if config.Connection == nil {
		return nil, fmt.Errorf("connection required")
	}
	if err := config.Connection.Validate(); err != nil {
		return nil, fmt.Errorf("connection invalid: %w", err)
	}

	logger := logrus.WithField("tag", "throttler")
	db, err := config.Connection.SqlDB(logger)
	if err != nil {
		return nil, fmt.Errorf("failed to create connection: %w", err)
	}

	return &LagThrottler{
		config:   config,
		DB:       db,
		logger:   logger,
		interval: interval,
	}, nil
}

// Throttled reports whether callers should wait. That is the case when the
// throttler is paused, or when the last observed lag exceeds MaxLag. It is
// safe to call concurrently with Run.
func (t *LagThrottler) Throttled() bool {
	return t.PauserThrottler.Throttled() || atomic.LoadInt64(&t.lag) > int64(t.config.MaxLag)
}

// Run refreshes the lag once per interval until ctx is done or an update
// fails. It returns ctx.Err() on cancellation; otherwise it returns the error
// from WithRetriesContext. NOTE(review): the arguments suggest up to 5
// attempts spaced t.interval apart — confirm against its definition.
func (t *LagThrottler) Run(ctx context.Context) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(t.interval):
		}

		err := WithRetriesContext(ctx, 5, t.interval, nil, "update lag", func() error {
			return t.updateLag(ctx)
		})
		if err != nil {
			return err
		}
	}
}

// updateLag runs config.Query once and atomically publishes the result.
//
// A query that returns no rows or a NULL value leaves the previous lag in
// place; any other query/scan error is returned.
func (t *LagThrottler) updateLag(ctx context.Context) error {
	var newLag sqlorig.NullInt64
	err := t.DB.QueryRowContext(ctx, t.config.Query).Scan(&newLag)
	if err == sqlorig.ErrNoRows {
		return nil
	}
	if err != nil {
		return err
	}
	if !newLag.Valid {
		return nil
	}

	atomic.StoreInt64(&t.lag, newLag.Int64)
	return nil
}