Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support rr_dns lookup host #113

Merged
merged 12 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 143 additions & 77 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ import (
"net/http/pprof"
"net/url"
"os"
"os/signal"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

"github.com/dustin/go-humanize"
Expand Down Expand Up @@ -71,7 +73,7 @@ var (
globalConsoleDisplay bool
globalErrorsOnly bool
globalStatusCodes []int
globalConnStats []*ConnStats
globalConnStats atomic.Pointer[[]*ConnStats]
log2 *logrus.Logger
globalHostBalance string
)
Expand Down Expand Up @@ -305,14 +307,19 @@ func getHealthCheckURL(endpoint, healthCheckPath string, healthCheckPort int) (s
}

// healthCheck - background routine which checks if a backend is up or down.
func (b *Backend) healthCheck() {
func (b *Backend) healthCheck(ctxt context.Context) {
ticker := time.NewTicker(b.healthCheckDuration)
defer ticker.Stop()
for {
err := b.doHealthCheck()
if err != nil {
console.Fatalln(err)
select {
case <-ctxt.Done():
return
case <-ticker.C:
err := b.doHealthCheck()
if err != nil {
console.Errorln(err)
}
jiuker marked this conversation as resolved.
Show resolved Hide resolved
}

time.Sleep(b.healthCheckDuration)
}
}

Expand Down Expand Up @@ -393,7 +400,7 @@ func (b *Backend) updateCallStats(t shortTraceMsg) {
b.Stats.MinLatency = time.Duration(int64(math.Min(float64(b.Stats.MinLatency), float64(t.CallStats.Latency))))
b.Stats.Rx += int64(t.CallStats.Rx)
b.Stats.Tx += int64(t.CallStats.Tx)
for _, c := range globalConnStats {
for _, c := range *globalConnStats.Load() {
if c == nil {
continue
}
Expand All @@ -410,36 +417,96 @@ func (b *Backend) updateCallStats(t shortTraceMsg) {
}

type multisite struct {
sites []*site
sites atomic.Pointer[[]*site]
healthCanceler context.CancelFunc
}

func (m *multisite) renewSite(ctx *cli.Context, healthCheckPath string, healthReadCheckPath string, healthCheckPort int, healthCheckDuration, healthCheckTimeout time.Duration) {
jiuker marked this conversation as resolved.
Show resolved Hide resolved
ctxt, cancel := context.WithCancel(context.Background())
var sites []*site
for i, siteStrs := range ctx.Args() {
if i == len(ctx.Args())-1 {
healthCheckPath = healthReadCheckPath
}
site := configureSite(ctxt, ctx, i+1, strings.Split(siteStrs, ","), healthCheckPath, healthCheckPort, healthCheckDuration, healthCheckTimeout)
sites = append(sites, site)
}
m.sites.Store(&sites)
// cancel the previous health checker
if m.healthCanceler != nil {
m.healthCanceler()
}
m.healthCanceler = cancel
}

// displayUI starts the live table UI in a background goroutine when show
// is true and does nothing otherwise.
//
// The goroutine clears the screen once and then calls populate every
// 500ms. It never returns, so the ticker is deliberately not stopped.
func (m *multisite) displayUI(show bool) {
	if show {
		render := func() {
			// Clear screen before we start the table UI
			clearScreen()

			refresh := time.NewTicker(500 * time.Millisecond)
			for {
				<-refresh.C
				m.populate()
			}
		}
		go render()
	}
}

func (m *multisite) populate(cellText [][]string) {
for i, site := range m.sites {
func (m *multisite) populate() {
sites := *m.sites.Load()

dspOrder := []col{colGreen} // Header
for i := 0; i < len(sites); i++ {
for range sites[i].backends {
dspOrder = append(dspOrder, colGrey)
}
}
var printColors []*color.Color
for _, c := range dspOrder {
printColors = append(printColors, getPrintCol(c))
}

tbl := console.NewTable(printColors, []bool{
false, false, false, false, false, false,
false, false, false, false, false,
}, 0)

cellText := make([][]string, len(dspOrder))
cellText[0] = headers
for i, site := range sites {
for j, b := range site.backends {
b.Stats.Lock()
minLatency := "0s"
maxLatency := "0s"
if b.Stats.MaxLatency > 0 {
minLatency = fmt.Sprintf("%2s", b.Stats.MinLatency.Round(time.Microsecond))
maxLatency = fmt.Sprintf("%2s", b.Stats.MaxLatency.Round(time.Microsecond))
}
cellText[i*len(site.backends)+j][0] = humanize.Ordinal(b.siteNumber)
cellText[i*len(site.backends)+j][1] = b.endpoint
cellText[i*len(site.backends)+j][2] = b.getServerStatus()
cellText[i*len(site.backends)+j][3] = strconv.FormatInt(b.Stats.TotCalls, 10)
cellText[i*len(site.backends)+j][4] = strconv.FormatInt(b.Stats.TotCallFailures, 10)
cellText[i*len(site.backends)+j][5] = humanize.IBytes(uint64(b.Stats.Rx))
cellText[i*len(site.backends)+j][6] = humanize.IBytes(uint64(b.Stats.Tx))
cellText[i*len(site.backends)+j][7] = b.Stats.CumDowntime.Round(time.Microsecond).String()
cellText[i*len(site.backends)+j][8] = b.Stats.LastDowntime.Round(time.Microsecond).String()
cellText[i*len(site.backends)+j][9] = minLatency
cellText[i*len(site.backends)+j][10] = maxLatency
cellText[i*len(site.backends)+j+1] = []string{
humanize.Ordinal(b.siteNumber),
b.endpoint,
b.getServerStatus(),
strconv.FormatInt(b.Stats.TotCalls, 10),
strconv.FormatInt(b.Stats.TotCallFailures, 10),
humanize.IBytes(uint64(b.Stats.Rx)),
humanize.IBytes(uint64(b.Stats.Tx)),
b.Stats.CumDowntime.Round(time.Microsecond).String(),
b.Stats.LastDowntime.Round(time.Microsecond).String(),
minLatency,
maxLatency,
}
b.Stats.Unlock()
}
}
console.RewindLines(len(cellText) + 2)
tbl.DisplayTable(cellText)
}

func (m *multisite) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Server", "SideKick") // indicate sidekick is serving
for _, s := range m.sites {
for _, s := range *m.sites.Load() {
if s.Online() {
if r.URL.Path == healthPath {
// Health check endpoint should return success
Expand Down Expand Up @@ -635,7 +702,7 @@ func newProxyDialContext(dialTimeout time.Duration) DialContext {
// tlsClientSessionCacheSize is the cache size for TLS client sessions.
const tlsClientSessionCacheSize = 100

func clientTransport(ctx *cli.Context, enableTLS bool) http.RoundTripper {
func clientTransport(ctx *cli.Context, enableTLS bool, hostName string) http.RoundTripper {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContextWithDNSCache(dnsCache, newProxyDialContext(10*time.Second)),
Expand Down Expand Up @@ -666,6 +733,7 @@ func clientTransport(ctx *cli.Context, enableTLS bool) http.RoundTripper {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
ClientSessionCache: tls.NewLRUClientSessionCache(tlsClientSessionCacheSize),
ServerName: hostName,
}
}

Expand Down Expand Up @@ -766,7 +834,7 @@ func IsLoopback(addr string) bool {
return net.ParseIP(host).IsLoopback()
}

func configureSite(ctx *cli.Context, siteNum int, siteStrs []string, healthCheckPath string, healthCheckPort int, healthCheckDuration, healthCheckTimeout time.Duration) *site {
func configureSite(ctxt context.Context, ctx *cli.Context, siteNum int, siteStrs []string, healthCheckPath string, healthCheckPort int, healthCheckDuration, healthCheckTimeout time.Duration) *site {
var endpoints []string

if ellipses.HasEllipses(siteStrs...) {
Expand All @@ -790,6 +858,25 @@ func configureSite(ctx *cli.Context, siteNum int, siteStrs []string, healthCheck
var backends []*Backend
var prevScheme string
var transport http.RoundTripper
var connStats []*ConnStats
var hostName string
if len(endpoints) == 1 && ctx.GlobalBool("rr-dns-mode") {
harshavardhana marked this conversation as resolved.
Show resolved Hide resolved
jiuker marked this conversation as resolved.
Show resolved Hide resolved
// guess it is LB config address
target, err := url.Parse(endpoints[0])
if err != nil {
console.Fatalln(fmt.Errorf("Unable to parse input arg %s: %s", endpoints[0], err))
}
hostName = target.Hostname()
ips, err := net.LookupHost(hostName)
if err != nil {
console.Fatalln(fmt.Errorf("Unable to lookup host %s", hostName))
}
// set the new endpoints
endpoints = []string{}
for _, ip := range ips {
endpoints = append(endpoints, strings.Replace(target.String(), hostName, ip, 1))
}
}
for _, endpoint := range endpoints {
endpoint = strings.TrimSuffix(endpoint, slashSeparator)
target, err := url.Parse(endpoint)
Expand All @@ -815,7 +902,7 @@ func configureSite(ctx *cli.Context, siteNum int, siteStrs []string, healthCheck
endpoint, ctx.App.Name))
}
if transport == nil {
transport = clientTransport(ctx, target.Scheme == "https")
transport = clientTransport(ctx, target.Scheme == "https", hostName)
}
// this is only used if r.RemoteAddr is localhost which means that
// sidekick endpoint being accessed is 127.0.0.x
Expand Down Expand Up @@ -843,12 +930,12 @@ func configureSite(ctx *cli.Context, siteNum int, siteStrs []string, healthCheck
backend := &Backend{siteNum, endpoint, proxy, &http.Client{
Transport: proxy.Transport,
}, 0, healthCheckURL, healthCheckDuration, healthCheckTimeout, &stats}
go backend.healthCheck()
go backend.healthCheck(ctxt)
proxy.ErrorHandler = backend.ErrorHandler
backends = append(backends, backend)
globalConnStats = append(globalConnStats, newConnStats(endpoint))
connStats = append(connStats, newConnStats(endpoint))
}

globalConnStats.Store(&connStats)
return &site{
backends: backends,
}
Expand Down Expand Up @@ -922,16 +1009,6 @@ func sidekickMain(ctx *cli.Context) {
healthReadCheckPath = slashSeparator + healthReadCheckPath
}

var sites []*site
for i, siteStrs := range ctx.Args() {
if i == len(ctx.Args())-1 {
healthCheckPath = healthReadCheckPath
}

site := configureSite(ctx, i+1, strings.Split(siteStrs, ","), healthCheckPath, healthCheckPort, healthCheckDuration, healthCheckTimeout)
sites = append(sites, site)
}

if globalConsoleDisplay {
console.SetColor("LogMsgType", color.New(color.FgHiMagenta))
console.SetColor("TraceMsgType", color.New(color.FgYellow))
Expand Down Expand Up @@ -960,42 +1037,9 @@ func sidekickMain(ctx *cli.Context) {
console.Fatalln(err)
}

m := &multisite{sites}
if !globalConsoleDisplay {
dspOrder := []col{colGreen} // Header
for i := 0; i < len(sites); i++ {
for range sites[i].backends {
dspOrder = append(dspOrder, colGrey)
}
}
var printColors []*color.Color
for _, c := range dspOrder {
printColors = append(printColors, getPrintCol(c))
}

tbl := console.NewTable(printColors, []bool{
false, false, false, false, false, false,
false, false, false, false, false,
}, 0)

cellText := make([][]string, len(dspOrder))
for i := range dspOrder {
cellText[i] = make([]string, len(headers))
}
cellText[0] = headers

go func() {
// Clear screen before we start the table UI
clearScreen()

ticker := time.NewTicker(500 * time.Millisecond)
for range ticker.C {
m.populate(cellText[1:])
console.RewindLines(len(cellText) + 2)
tbl.DisplayTable(cellText)
}
}()
}
m := &multisite{}
m.renewSite(ctx, healthCheckPath, healthReadCheckPath, healthCheckPort, healthCheckDuration, healthCheckTimeout)
m.displayUI(!globalConsoleDisplay)

router.PathPrefix(slashSeparator).Handler(m)
server := &http.Server{
Expand All @@ -1017,8 +1061,26 @@ func sidekickMain(ctx *cli.Context) {
}
server.TLSConfig = tlsConfig
}
if err := server.ListenAndServe(); err != nil {
console.Fatalln(err)
go func() {
if err := server.ListenAndServe(); err != nil {
console.Fatalln(err)
}
}()
osSignalChannel := make(chan os.Signal, 1)
signal.Notify(
osSignalChannel,
syscall.SIGTERM,
syscall.SIGINT,
syscall.SIGHUP,
)
for signal := range osSignalChannel {
switch signal {
case syscall.SIGHUP:
m.renewSite(ctx, healthCheckPath, healthReadCheckPath, healthCheckPort, healthCheckDuration, healthCheckTimeout)
default:
console.Infof("caught signal '%s'\n", signal)
os.Exit(1)
}
}
}

Expand Down Expand Up @@ -1065,6 +1127,10 @@ func main() {
Name: "insecure, i",
Usage: "disable TLS certificate verification",
},
cli.BoolFlag{
Name: "rr-dns-mode",
Usage: "enable round-robin DNS mode",
},
cli.BoolFlag{
Name: "log, l",
Usage: "enable logging",
Expand Down
2 changes: 1 addition & 1 deletion metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (c *sidekickCollector) Describe(ch chan<- *prometheus.Desc) {

// Collect is called by the Prometheus registry when collecting metrics.
func (c *sidekickCollector) Collect(ch chan<- prometheus.Metric) {
for _, c := range globalConnStats {
for _, c := range *globalConnStats.Load() {
if c == nil {
continue
}
Expand Down
Loading