Skip to content

Commit

Permalink
feat: add mTLS configuration support for Redis client connections (#918)
Browse files Browse the repository at this point in the history
  • Loading branch information
rimvydascivilis authored Jan 21, 2025
1 parent 79fafb5 commit 96d8ebe
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 38 deletions.
37 changes: 34 additions & 3 deletions internal/client/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"net"
"os"
"regexp"
"strconv"
"strings"
Expand All @@ -22,7 +24,13 @@ type Redis struct {
protoWriter *proto.Writer
}

func NewRedisClient(ctx context.Context, address string, username string, password string, Tls bool, replica bool) *Redis {
type TlsConfig struct {
CACertFilePath string `mapstructure:"ca_cert" default:""`
CertFilePath string `mapstructure:"cert" default:""`
KeyFilePath string `mapstructure:"key" default:""`
}

func NewRedisClient(ctx context.Context, address string, username string, password string, Tls bool, tlsConfig TlsConfig, replica bool) *Redis {
r := new(Redis)
var conn net.Conn
var dialer = &net.Dialer{
Expand All @@ -35,7 +43,7 @@ func NewRedisClient(ctx context.Context, address string, username string, passwo
if Tls {
tlsDialer := &tls.Dialer{
NetDialer: dialer,
Config: &tls.Config{InsecureSkipVerify: true},
Config: getTlsConfig(tlsConfig),
}
conn, err = tlsDialer.DialContext(ctxWithDeadline, "tcp", address)
} else {
Expand Down Expand Up @@ -75,12 +83,35 @@ func NewRedisClient(ctx context.Context, address string, username string, passwo
if replica {
replicaInfo := getReplicaAddr(reply, address)
log.Infof("best replica: %s", replicaInfo.BestReplica)
r = NewRedisClient(ctx, replicaInfo.BestReplica, username, password, Tls, false)
r = NewRedisClient(ctx, replicaInfo.BestReplica, username, password, Tls, tlsConfig, false)
}

return r
}

func getTlsConfig(tlsConfig TlsConfig) *tls.Config {
if tlsConfig.CACertFilePath == "" || tlsConfig.CertFilePath == "" || tlsConfig.KeyFilePath == "" {
return &tls.Config{InsecureSkipVerify: true}
}

// Use mutual authentication (mTLS)
cert, err := tls.LoadX509KeyPair(tlsConfig.CertFilePath, tlsConfig.KeyFilePath)
if err != nil {
log.Panicf("load tls cert failed. cert=[%s], key=[%s], err=[%v]", tlsConfig.CertFilePath, tlsConfig.KeyFilePath, err)
}
caCert, err := os.ReadFile(tlsConfig.CACertFilePath)
if err != nil {
log.Panicf("read ca cert failed. ca_cert=[%s], err=[%v]", tlsConfig.CACertFilePath, err)
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
return &tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
}
}

type Replica struct {
Addr string
Offset string
Expand Down
13 changes: 7 additions & 6 deletions internal/client/sentinel.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@ import (
)

type SentinelOptions struct {
MasterName string `mapstructure:"master_name" default:""`
Address string `mapstructure:"address" default:""`
Username string `mapstructure:"username" default:""`
Password string `mapstructure:"password" default:""`
Tls bool `mapstructure:"tls" default:"false"`
MasterName string `mapstructure:"master_name" default:""`
Address string `mapstructure:"address" default:""`
Username string `mapstructure:"username" default:""`
Password string `mapstructure:"password" default:""`
Tls bool `mapstructure:"tls" default:"false"`
TlsConfig TlsConfig `mapstructure:"tls_config" default:"{}"`
}

func FetchAddressFromSentinel(opts *SentinelOptions) string {
log.Infof("fetching master address from sentinel. sentinel address: %s, master name: %s", opts.Address, opts.MasterName)

ctx := context.Background()
c := NewRedisClient(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, false)
c := NewRedisClient(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.TlsConfig, false)
defer c.Close()
c.Send("SENTINEL", "GET-MASTER-ADDR-BY-NAME", opts.MasterName)
hostport := ArrayString(c.Receive())
Expand Down
2 changes: 1 addition & 1 deletion internal/reader/scan_cluster_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type scanClusterReader struct {
}

func NewScanClusterReader(ctx context.Context, opts *ScanReaderOptions) Reader {
addresses, _ := utils.GetRedisClusterNodes(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.PreferReplica)
addresses, _ := utils.GetRedisClusterNodes(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.TlsConfig, opts.PreferReplica)

rd := &scanClusterReader{}
for _, address := range addresses {
Expand Down
29 changes: 15 additions & 14 deletions internal/reader/scan_standalone_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ import (
)

type ScanReaderOptions struct {
Cluster bool `mapstructure:"cluster" default:"false"`
Address string `mapstructure:"address" default:""`
Username string `mapstructure:"username" default:""`
Password string `mapstructure:"password" default:""`
Tls bool `mapstructure:"tls" default:"false"`
Scan bool `mapstructure:"scan" default:"true"`
KSN bool `mapstructure:"ksn" default:"false"`
DBS []int `mapstructure:"dbs"`
PreferReplica bool `mapstructure:"prefer_replica" default:"false"`
Count int `mapstructure:"count" default:"1"`
Cluster bool `mapstructure:"cluster" default:"false"`
Address string `mapstructure:"address" default:""`
Username string `mapstructure:"username" default:""`
Password string `mapstructure:"password" default:""`
Tls bool `mapstructure:"tls" default:"false"`
TlsConfig client.TlsConfig `mapstructure:"tls_config" default:"{}"`
Scan bool `mapstructure:"scan" default:"true"`
KSN bool `mapstructure:"ksn" default:"false"`
DBS []int `mapstructure:"dbs"`
PreferReplica bool `mapstructure:"prefer_replica" default:"false"`
Count int `mapstructure:"count" default:"1"`
}

type dbKey struct {
Expand Down Expand Up @@ -63,7 +64,7 @@ type scanStandaloneReader struct {
func NewScanStandaloneReader(ctx context.Context, opts *ScanReaderOptions) Reader {
r := new(scanStandaloneReader)
// dbs
c := client.NewRedisClient(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.PreferReplica)
c := client.NewRedisClient(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.TlsConfig, opts.PreferReplica)
if len(opts.DBS) != 0 {
r.dbs = opts.DBS
} else if c.IsCluster() { // not use opts.Cluster, because user may use standalone mode to scan a cluster node
Expand Down Expand Up @@ -99,7 +100,7 @@ func (r *scanStandaloneReader) StartRead(ctx context.Context) []chan *entry.Entr
}

func (r *scanStandaloneReader) subscript() {
c := client.NewRedisClient(r.ctx, r.opts.Address, r.opts.Username, r.opts.Password, r.opts.Tls, r.opts.PreferReplica)
c := client.NewRedisClient(r.ctx, r.opts.Address, r.opts.Username, r.opts.Password, r.opts.Tls, r.opts.TlsConfig, r.opts.PreferReplica)
if len(r.dbs) == 0 {
c.Send("psubscribe", "__keyevent@*__:*")
} else {
Expand Down Expand Up @@ -148,7 +149,7 @@ func (r *scanStandaloneReader) subscript() {
}

func (r *scanStandaloneReader) scan() {
c := client.NewRedisClient(r.ctx, r.opts.Address, r.opts.Username, r.opts.Password, r.opts.Tls, r.opts.PreferReplica)
c := client.NewRedisClient(r.ctx, r.opts.Address, r.opts.Username, r.opts.Password, r.opts.Tls, r.opts.TlsConfig, r.opts.PreferReplica)
defer c.Close()
for _, dbId := range r.dbs {
if dbId != 0 {
Expand Down Expand Up @@ -193,7 +194,7 @@ func (r *scanStandaloneReader) scan() {

func (r *scanStandaloneReader) dump() {
nowDbId := 0
r.dumpClient = client.NewRedisClient(r.ctx, r.opts.Address, r.opts.Username, r.opts.Password, r.opts.Tls, r.opts.PreferReplica)
r.dumpClient = client.NewRedisClient(r.ctx, r.opts.Address, r.opts.Username, r.opts.Password, r.opts.Tls, r.opts.TlsConfig, r.opts.PreferReplica)
// Support prefer_replica=true in both Cluster and Standalone mode
if r.opts.PreferReplica {
r.dumpClient.Do("READONLY")
Expand Down
2 changes: 1 addition & 1 deletion internal/reader/sync_cluster_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type syncClusterReader struct {
}

func NewSyncClusterReader(ctx context.Context, opts *SyncReaderOptions) Reader {
addresses, _ := utils.GetRedisClusterNodes(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.PreferReplica)
addresses, _ := utils.GetRedisClusterNodes(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.TlsConfig, opts.PreferReplica)
log.Debugf("get redis cluster nodes:")
for _, address := range addresses {
log.Debugf("%s", address)
Expand Down
5 changes: 3 additions & 2 deletions internal/reader/sync_standalone_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"bufio"
"bytes"
"context"
"errors"
"encoding/json"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -34,6 +34,7 @@ type SyncReaderOptions struct {
Username string `mapstructure:"username" default:""`
Password string `mapstructure:"password" default:""`
Tls bool `mapstructure:"tls" default:"false"`
TlsConfig client.TlsConfig `mapstructure:"tls_config" default:"{}"`
SyncRdb bool `mapstructure:"sync_rdb" default:"true"`
SyncAof bool `mapstructure:"sync_aof" default:"true"`
PreferReplica bool `mapstructure:"prefer_replica" default:"false"`
Expand Down Expand Up @@ -111,7 +112,7 @@ type syncStandaloneReader struct {
func NewSyncStandaloneReader(ctx context.Context, opts *SyncReaderOptions) Reader {
r := new(syncStandaloneReader)
r.opts = opts
r.client = client.NewRedisClient(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.PreferReplica)
r.client = client.NewRedisClient(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.TlsConfig, opts.PreferReplica)
r.stat.Name = "reader_" + strings.Replace(opts.Address, ":", "_", -1)
r.stat.Address = opts.Address
r.stat.Status = kHandShake
Expand Down
4 changes: 2 additions & 2 deletions internal/utils/cluster_nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"RedisShake/internal/log"
)

func GetRedisClusterNodes(ctx context.Context, address string, username string, password string, Tls bool, perferReplica bool) (addresses []string, slots [][]int) {
c := client.NewRedisClient(ctx, address, username, password, Tls, false)
func GetRedisClusterNodes(ctx context.Context, address string, username string, password string, Tls bool, tlsConfig client.TlsConfig, perferReplica bool) (addresses []string, slots [][]int) {
c := client.NewRedisClient(ctx, address, username, password, Tls, tlsConfig, false)
reply := c.DoWithStringReply("cluster", "nodes")
reply = strings.TrimSpace(reply)
slotsCount := 0
Expand Down
2 changes: 1 addition & 1 deletion internal/writer/redis_cluster_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (r *RedisClusterWriter) Close() {
}

func (r *RedisClusterWriter) loadClusterNodes(ctx context.Context, opts *RedisWriterOptions) {
addresses, slots := utils.GetRedisClusterNodes(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, false)
addresses, slots := utils.GetRedisClusterNodes(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.TlsConfig, false)
r.addresses = addresses
for i, address := range addresses {
theOpts := *opts
Expand Down
17 changes: 9 additions & 8 deletions internal/writer/redis_standalone_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ import (
)

type RedisWriterOptions struct {
Cluster bool `mapstructure:"cluster" default:"false"`
Address string `mapstructure:"address" default:""`
Username string `mapstructure:"username" default:""`
Password string `mapstructure:"password" default:""`
Tls bool `mapstructure:"tls" default:"false"`
OffReply bool `mapstructure:"off_reply" default:"false"`
Sentinel client.SentinelOptions `mapstructure:"sentinel"`
Cluster bool `mapstructure:"cluster" default:"false"`
Address string `mapstructure:"address" default:""`
Username string `mapstructure:"username" default:""`
Password string `mapstructure:"password" default:""`
Tls bool `mapstructure:"tls" default:"false"`
TlsConfig client.TlsConfig `mapstructure:"tls_config" default:"{}"`
OffReply bool `mapstructure:"off_reply" default:"false"`
Sentinel client.SentinelOptions `mapstructure:"sentinel"`
}

type redisStandaloneWriter struct {
Expand All @@ -49,7 +50,7 @@ func NewRedisStandaloneWriter(ctx context.Context, opts *RedisWriterOptions) Wri
rw := new(redisStandaloneWriter)
rw.address = opts.Address
rw.stat.Name = "writer_" + strings.Replace(opts.Address, ":", "_", -1)
rw.client = client.NewRedisClient(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, false)
rw.client = client.NewRedisClient(ctx, opts.Address, opts.Username, opts.Password, opts.Tls, opts.TlsConfig, false)
rw.ch = make(chan *entry.Entry, config.Opt.Advanced.PipelineCountLimit)
if opts.OffReply {
log.Infof("turn off the reply of write")
Expand Down

0 comments on commit 96d8ebe

Please sign in to comment.