From 507416b315e5b7caac7036e50a13e1080046779f Mon Sep 17 00:00:00 2001 From: secwall Date: Mon, 21 Oct 2024 10:01:37 +0200 Subject: [PATCH] Fix zookeeper ip address change with tls --- internal/dcs/zk.go | 4 +-- internal/dcs/zk_host_provider.go | 10 ++++++-- internal/dcs/zk_tls.go | 44 ++------------------------------ 3 files changed, 12 insertions(+), 46 deletions(-) diff --git a/internal/dcs/zk.go b/internal/dcs/zk.go index 3fa35506..3148326b 100644 --- a/internal/dcs/zk.go +++ b/internal/dcs/zk.go @@ -72,7 +72,7 @@ func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *log.Logg var ec <-chan zk.Event var err error var operation func() error - hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, logger) + hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, !config.UseSSL, logger) if config.UseSSL { if config.CACert == "" || config.KeyFile == "" || config.CertFile == "" { return nil, fmt.Errorf("zookeeper ssl not configured, fill ca_cert/key_file/cert_file in config or disable use_ssl flag") @@ -82,7 +82,7 @@ func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *log.Logg return nil, err } baseDialer := net.Dialer{Timeout: config.SessionTimeout} - dialer, err := GetTLSDialer(config.Hosts, &baseDialer, tlsConfig) + dialer, err := GetTLSDialer(&baseDialer, tlsConfig) if err != nil { return nil, err } diff --git a/internal/dcs/zk_host_provider.go b/internal/dcs/zk_host_provider.go index 9a50a475..1a2d3d7d 100644 --- a/internal/dcs/zk_host_provider.go +++ b/internal/dcs/zk_host_provider.go @@ -19,6 +19,7 @@ type zkhost struct { type RandomHostProvider struct { ctx context.Context hosts sync.Map + useAddrs bool hostsKeys []string tried map[string]struct{} logger *log.Logger @@ -28,7 +29,7 @@ type RandomHostProvider struct { resolver *net.Resolver } -func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, logger *log.Logger) *RandomHostProvider { +func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, useAddrs bool, logger *log.Logger) *RandomHostProvider { return &RandomHostProvider{ ctx: ctx, lookupTTL: config.LookupTTL, @@ -38,6 +39,7 @@ func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig tried: make(map[string]struct{}), hosts: sync.Map{}, resolver: &net.Resolver{}, + useAddrs: useAddrs, } } @@ -147,7 +149,11 @@ func (rhp *RandomHostProvider) Next() (server string, retryStart bool) { zhost := host.(zkhost) if len(zhost.resolved) > 0 { - ret = zhost.resolved[rand.Intn(len(zhost.resolved))] + if rhp.useAddrs { + ret = zhost.resolved[rand.Intn(len(zhost.resolved))] + } else { + ret = selected + } } } diff --git a/internal/dcs/zk_tls.go b/internal/dcs/zk_tls.go index b3d00c18..9fdadea7 100644 --- a/internal/dcs/zk_tls.go +++ b/internal/dcs/zk_tls.go @@ -3,7 +3,6 @@ package dcs import ( "crypto/tls" "crypto/x509" - "errors" "net" "os" "time" @@ -11,24 +10,6 @@ import ( "github.com/go-zookeeper/zk" ) -// TODO: if pr https://github.com/go-zookeeper/zk/pull/106 will be merged -// remove this file and use same functions from go-zookeeper/zk -func addrsByHostname(server string) ([]string, error) { - res := []string{} - host, port, err := net.SplitHostPort(server) - if err != nil { - return nil, err - } - addrs, err := net.LookupHost(host) - if err != nil { - return nil, err - } - for _, addr := range addrs { - res = append(res, net.JoinHostPort(addr, port)) - } - return res, nil -} - func CreateTLSConfig(rootCAFile, certFile, keyFile string) (*tls.Config, error) { rootCABytes, err := os.ReadFile(rootCAFile) if err != nil { @@ -52,29 +33,8 @@ func CreateTLSConfig(rootCAFile, certFile, keyFile string) (*tls.Config, error) }, nil } -func GetTLSDialer(servers []string, dialer *net.Dialer, tlsConfig *tls.Config) (zk.Dialer, error) { - if len(servers) == 0 { - return nil, errors.New("zk: server list must not be empty") - } - srvs := zk.FormatServers(servers) - - addrToHostname := map[string]string{} - - for _, server := range srvs { - ips, err := addrsByHostname(server) - if err != nil { - return nil, err - } - for _, ip := range ips { - addrToHostname[ip] = server - } - } - +func GetTLSDialer(dialer *net.Dialer, tlsConfig *tls.Config) (zk.Dialer, error) { return func(network, address string, _ time.Duration) (net.Conn, error) { - server, ok := addrToHostname[address] - if !ok { - server = address - } - return tls.DialWithDialer(dialer, network, server, tlsConfig) + return tls.DialWithDialer(dialer, network, address, tlsConfig) }, nil }