Skip to content

Commit

Permalink
Fix zookeeper ip address change with tls (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
secwall authored Oct 22, 2024
1 parent 493f11e commit 80517a9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 46 deletions.
4 changes: 2 additions & 2 deletions internal/dcs/zk.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *log.Logg
var ec <-chan zk.Event
var err error
var operation func() error
hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, logger)
hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, !config.UseSSL, logger)
if config.UseSSL {
if config.CACert == "" || config.KeyFile == "" || config.CertFile == "" {
return nil, fmt.Errorf("zookeeper ssl not configured, fill ca_cert/key_file/cert_file in config or disable use_ssl flag")
Expand All @@ -82,7 +82,7 @@ func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *log.Logg
return nil, err
}
baseDialer := net.Dialer{Timeout: config.SessionTimeout}
dialer, err := GetTLSDialer(config.Hosts, &baseDialer, tlsConfig)
dialer, err := GetTLSDialer(&baseDialer, tlsConfig)
if err != nil {
return nil, err
}
Expand Down
10 changes: 8 additions & 2 deletions internal/dcs/zk_host_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type zkhost struct {
type RandomHostProvider struct {
ctx context.Context
hosts sync.Map
useAddrs bool
hostsKeys []string
tried map[string]struct{}
logger *log.Logger
Expand All @@ -28,7 +29,7 @@ type RandomHostProvider struct {
resolver *net.Resolver
}

func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, logger *log.Logger) *RandomHostProvider {
func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, useAddrs bool, logger *log.Logger) *RandomHostProvider {
return &RandomHostProvider{
ctx: ctx,
lookupTTL: config.LookupTTL,
Expand All @@ -38,6 +39,7 @@ func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig
tried: make(map[string]struct{}),
hosts: sync.Map{},
resolver: &net.Resolver{},
useAddrs: useAddrs,
}
}

Expand Down Expand Up @@ -147,7 +149,11 @@ func (rhp *RandomHostProvider) Next() (server string, retryStart bool) {
zhost := host.(zkhost)

if len(zhost.resolved) > 0 {
ret = zhost.resolved[rand.Intn(len(zhost.resolved))]
if rhp.useAddrs {
ret = zhost.resolved[rand.Intn(len(zhost.resolved))]
} else {
ret = selected
}
}
}

Expand Down
44 changes: 2 additions & 42 deletions internal/dcs/zk_tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,13 @@ package dcs
import (
"crypto/tls"
"crypto/x509"
"errors"
"net"
"os"
"time"

"github.com/go-zookeeper/zk"
)

// TODO: if pr https://github.com/go-zookeeper/zk/pull/106 will be merged
// remove this file and use same functions from go-zookeeper/zk
func addrsByHostname(server string) ([]string, error) {
res := []string{}
host, port, err := net.SplitHostPort(server)
if err != nil {
return nil, err
}
addrs, err := net.LookupHost(host)
if err != nil {
return nil, err
}
for _, addr := range addrs {
res = append(res, net.JoinHostPort(addr, port))
}
return res, nil
}

func CreateTLSConfig(rootCAFile, certFile, keyFile string) (*tls.Config, error) {
rootCABytes, err := os.ReadFile(rootCAFile)
if err != nil {
Expand All @@ -52,29 +33,8 @@ func CreateTLSConfig(rootCAFile, certFile, keyFile string) (*tls.Config, error)
}, nil
}

func GetTLSDialer(servers []string, dialer *net.Dialer, tlsConfig *tls.Config) (zk.Dialer, error) {
if len(servers) == 0 {
return nil, errors.New("zk: server list must not be empty")
}
srvs := zk.FormatServers(servers)

addrToHostname := map[string]string{}

for _, server := range srvs {
ips, err := addrsByHostname(server)
if err != nil {
return nil, err
}
for _, ip := range ips {
addrToHostname[ip] = server
}
}

func GetTLSDialer(dialer *net.Dialer, tlsConfig *tls.Config) (zk.Dialer, error) {
return func(network, address string, _ time.Duration) (net.Conn, error) {
server, ok := addrToHostname[address]
if !ok {
server = address
}
return tls.DialWithDialer(dialer, network, server, tlsConfig)
return tls.DialWithDialer(dialer, network, address, tlsConfig)
}, nil
}

0 comments on commit 80517a9

Please sign in to comment.