Skip to content

Commit

Permalink
support mTLS between nodes (#893)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Nov 28, 2024
1 parent 2a4227f commit cca6da7
Show file tree
Hide file tree
Showing 13 changed files with 322 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cmd/gorse-in-one/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ var oneCommand = &cobra.Command{
// Start worker
workerJobs, _ := cmd.PersistentFlags().GetInt("recommend-jobs")
w := worker.NewWorker(conf.Master.Host, conf.Master.Port, conf.Master.Host,
0, workerJobs, "", managedMode)
0, workerJobs, "", managedMode, nil)
go func() {
w.SetOneMode(m.Settings)
w.Serve()
Expand Down
24 changes: 23 additions & 1 deletion cmd/gorse-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/spf13/cobra"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/cmd/version"
"github.com/zhenghaoz/gorse/protocol"
"github.com/zhenghaoz/gorse/server"
"go.uber.org/zap"
)
Expand All @@ -46,7 +47,25 @@ var serverCommand = &cobra.Command{
httpPort, _ := cmd.PersistentFlags().GetInt("http-port")
httpHost, _ := cmd.PersistentFlags().GetString("http-host")
cachePath, _ := cmd.PersistentFlags().GetString("cache-path")
s := server.NewServer(masterHost, masterPort, httpHost, httpPort, cachePath)
caFile, _ := cmd.PersistentFlags().GetString("ssl-ca")
certFile, _ := cmd.PersistentFlags().GetString("ssl-cert")
keyFile, _ := cmd.PersistentFlags().GetString("ssl-key")
var tlsConfig *protocol.TLSConfig
if caFile != "" && certFile != "" && keyFile != "" {
tlsConfig = &protocol.TLSConfig{
SSLCA: caFile,
SSLCert: certFile,
SSLKey: keyFile,
}
} else if caFile == "" && certFile == "" && keyFile == "" {
tlsConfig = nil
} else {
log.Logger().Fatal("incomplete SSL configuration",
zap.String("ssl_ca", caFile),
zap.String("ssl_cert", certFile),
zap.String("ssl_key", keyFile))
}
s := server.NewServer(masterHost, masterPort, httpHost, httpPort, cachePath, tlsConfig)

// stop server
done := make(chan struct{})
Expand Down Expand Up @@ -74,6 +93,9 @@ func init() {
serverCommand.PersistentFlags().String("http-host", "127.0.0.1", "port for RESTful APIs and Prometheus metrics export")
serverCommand.PersistentFlags().Bool("debug", false, "use debug log mode")
serverCommand.PersistentFlags().String("cache-path", "server_cache.data", "path of cache file")
serverCommand.PersistentFlags().String("ssl-ca", "", "path of SSL CA")
serverCommand.PersistentFlags().String("ssl-cert", "", "path of SSL certificate")
serverCommand.PersistentFlags().String("ssl-key", "", "path of SSL key")
}

func main() {
Expand Down
24 changes: 23 additions & 1 deletion cmd/gorse-worker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/spf13/cobra"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/cmd/version"
"github.com/zhenghaoz/gorse/protocol"
"github.com/zhenghaoz/gorse/worker"
"go.uber.org/zap"
)
Expand All @@ -45,7 +46,25 @@ var workerCommand = &cobra.Command{
log.SetLogger(cmd.PersistentFlags(), debug)
// create worker
cachePath, _ := cmd.PersistentFlags().GetString("cache-path")
w := worker.NewWorker(masterHost, masterPort, httpHost, httpPort, workingJobs, cachePath, managedModel)
caFile, _ := cmd.PersistentFlags().GetString("ssl-ca")
certFile, _ := cmd.PersistentFlags().GetString("ssl-cert")
keyFile, _ := cmd.PersistentFlags().GetString("ssl-key")
var tlsConfig *protocol.TLSConfig
if caFile != "" && certFile != "" && keyFile != "" {
tlsConfig = &protocol.TLSConfig{
SSLCA: caFile,
SSLCert: certFile,
SSLKey: keyFile,
}
} else if caFile == "" && certFile == "" && keyFile == "" {
tlsConfig = nil
} else {
log.Logger().Fatal("incomplete SSL configuration",
zap.String("ssl_ca", caFile),
zap.String("ssl_cert", certFile),
zap.String("ssl_key", keyFile))
}
w := worker.NewWorker(masterHost, masterPort, httpHost, httpPort, workingJobs, cachePath, managedModel, tlsConfig)
w.Serve()
},
}
Expand All @@ -61,6 +80,9 @@ func init() {
workerCommand.PersistentFlags().Bool("managed", false, "enable managed mode")
workerCommand.PersistentFlags().IntP("jobs", "j", 1, "number of working jobs.")
workerCommand.PersistentFlags().String("cache-path", "worker_cache.data", "path of cache file")
workerCommand.PersistentFlags().String("ssl-ca", "", "path of SSL CA")
workerCommand.PersistentFlags().String("ssl-cert", "", "path to SSL certificate")
workerCommand.PersistentFlags().String("ssl-key", "", "path to SSL key")
}

func main() {
Expand Down
8 changes: 8 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ type MySQLConfig struct {
type MasterConfig struct {
Port int `mapstructure:"port" validate:"gte=0"` // master port
Host string `mapstructure:"host"` // master host
SSLMode bool `mapstructure:"ssl_mode"` // enable SSL mode
SSLCA string `mapstructure:"ssl_ca"` // SSL CA file
SSLCert string `mapstructure:"ssl_cert"` // SSL certificate file
SSLKey string `mapstructure:"ssl_key"` // SSL key file
HttpPort int `mapstructure:"http_port" validate:"gte=0"` // HTTP port
HttpHost string `mapstructure:"http_host"` // HTTP host
HttpCorsDomains []string `mapstructure:"http_cors_domains"` // add allowed cors domains
Expand Down Expand Up @@ -569,6 +573,10 @@ func LoadConfig(path string, oneModel bool) (*Config, error) {
{"database.data_table_prefix", "GORSE_DATA_TABLE_PREFIX"},
{"master.port", "GORSE_MASTER_PORT"},
{"master.host", "GORSE_MASTER_HOST"},
{"master.ssl_mode", "GORSE_MASTER_SSL_MODE"},
{"master.ssl_ca", "GORSE_MASTER_SSL_CA"},
{"master.ssl_cert", "GORSE_MASTER_SSL_CERT"},
{"master.ssl_key", "GORSE_MASTER_SSL_KEY"},
{"master.http_port", "GORSE_MASTER_HTTP_PORT"},
{"master.http_host", "GORSE_MASTER_HTTP_HOST"},
{"master.n_jobs", "GORSE_MASTER_JOBS"},
Expand Down
12 changes: 12 additions & 0 deletions config/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ port = 8086
# gRPC host of the master node. The default values is "0.0.0.0".
host = "0.0.0.0"

# Enable SSL for the gRPC communication. The default value is false.
ssl_mode = false

# SSL certification authority for the gRPC communication.
ssl_ca = ""

# SSL certification for the gRPC communication.
ssl_cert = ""

# SSL certification key for the gRPC communication.
ssl_key = ""

# HTTP port of the master node. The default values is 8088.
http_port = 8088

Expand Down
16 changes: 16 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func TestUnmarshal(t *testing.T) {
data, err := os.ReadFile("config.toml")
assert.NoError(t, err)
text := string(data)
text = strings.Replace(text, "ssl_mode = false", "ssl_mode = true", -1)
text = strings.Replace(text, "ssl_ca = \"\"", "ssl_ca = \"ca.pem\"", -1)
text = strings.Replace(text, "ssl_cert = \"\"", "ssl_cert = \"cert.pem\"", -1)
text = strings.Replace(text, "ssl_key = \"\"", "ssl_key = \"key.pem\"", -1)
text = strings.Replace(text, "dashboard_user_name = \"\"", "dashboard_user_name = \"admin\"", -1)
text = strings.Replace(text, "dashboard_password = \"\"", "dashboard_password = \"password\"", -1)
text = strings.Replace(text, "admin_api_key = \"\"", "admin_api_key = \"super_api_key\"", -1)
Expand Down Expand Up @@ -69,6 +73,10 @@ func TestUnmarshal(t *testing.T) {
// [master]
assert.Equal(t, 8086, config.Master.Port)
assert.Equal(t, "0.0.0.0", config.Master.Host)
assert.Equal(t, true, config.Master.SSLMode)
assert.Equal(t, "ca.pem", config.Master.SSLCA)
assert.Equal(t, "cert.pem", config.Master.SSLCert)
assert.Equal(t, "key.pem", config.Master.SSLKey)
assert.Equal(t, 8088, config.Master.HttpPort)
assert.Equal(t, "0.0.0.0", config.Master.HttpHost)
assert.Equal(t, []string{".*"}, config.Master.HttpCorsDomains)
Expand Down Expand Up @@ -181,6 +189,10 @@ func TestBindEnv(t *testing.T) {
{"GORSE_CACHE_TABLE_PREFIX", "gorse_cache_"},
{"GORSE_MASTER_PORT", "123"},
{"GORSE_MASTER_HOST", "<master_host>"},
{"GORSE_MASTER_SSL_MODE", "true"},
{"GORSE_MASTER_SSL_CA", "ca.pem"},
{"GORSE_MASTER_SSL_CERT", "cert.pem"},
{"GORSE_MASTER_SSL_KEY", "key.pem"},
{"GORSE_MASTER_HTTP_PORT", "456"},
{"GORSE_MASTER_HTTP_HOST", "<master_http_host>"},
{"GORSE_MASTER_JOBS", "789"},
Expand Down Expand Up @@ -209,6 +221,10 @@ func TestBindEnv(t *testing.T) {
assert.Equal(t, "gorse_data_", config.Database.DataTablePrefix)
assert.Equal(t, 123, config.Master.Port)
assert.Equal(t, "<master_host>", config.Master.Host)
assert.Equal(t, true, config.Master.SSLMode)
assert.Equal(t, "ca.pem", config.Master.SSLCA)
assert.Equal(t, "cert.pem", config.Master.SSLCert)
assert.Equal(t, "key.pem", config.Master.SSLKey)
assert.Equal(t, 456, config.Master.HttpPort)
assert.Equal(t, "<master_http_host>", config.Master.HttpHost)
assert.Equal(t, 789, config.Master.NumJobs)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ require (
github.com/klauspost/cpuid/v2 v2.2.3
github.com/lafikl/consistent v0.0.0-20220512074542-bdd3606bfc3e
github.com/lib/pq v1.10.6
github.com/madflojo/testcerts v1.3.0
github.com/mailru/go-clickhouse/v2 v2.0.1-0.20221121001540-b259988ad8e5
github.com/mitchellh/mapstructure v1.5.0
github.com/orcaman/concurrent-map v1.0.0
Expand Down Expand Up @@ -63,6 +64,7 @@ require (
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e
golang.org/x/oauth2 v0.22.0
google.golang.org/grpc v1.67.1
google.golang.org/grpc/security/advancedtls v1.0.0
google.golang.org/protobuf v1.35.1
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v2 v2.4.0
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,8 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs=
github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/madflojo/testcerts v1.3.0 h1:H6r7WlzfeLqzcuOglfAlnj5Rkt5iQoH1ctTi7FsLOdE=
github.com/madflojo/testcerts v1.3.0/go.mod h1:MW8sh39gLnkKh4K0Nc55AyHEDl9l/FBLDUsQhpmkuo0=
github.com/magiconair/properties v1.8.6 h1:5ibWZ6iY0NctNGWo87LalDlEZ6R41TqbbDamhfG/Qzo=
github.com/magiconair/properties v1.8.6/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
Expand Down Expand Up @@ -1128,6 +1130,10 @@ google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnD
google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v0.0.0-20200910201057-6591123024b3/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
google.golang.org/grpc/examples v0.0.0-20201112215255-90f1b3ee835b h1:NuxyvVZoDfHZwYW9LD4GJiF5/nhiSyP4/InTrvw9Ibk=
google.golang.org/grpc/examples v0.0.0-20201112215255-90f1b3ee835b/go.mod h1:IBqQ7wSUJ2Ep09a8rMWFsg4fmI2r38zwsq8a0GgxXpM=
google.golang.org/grpc/security/advancedtls v1.0.0 h1:/KQ7VP/1bs53/aopk9QhuPyFAp9Dm9Ejix3lzYkCrDA=
google.golang.org/grpc/security/advancedtls v1.0.0/go.mod h1:o+s4go+e1PJ2AjuQMY5hU82W7lDlefjJA6FqEHRVHWk=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
Expand Down
20 changes: 18 additions & 2 deletions master/master.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,28 @@ func (m *Master) Serve() {
go func() {
log.Logger().Info("start rpc server",
zap.String("host", m.Config.Master.Host),
zap.Int("port", m.Config.Master.Port))
zap.Int("port", m.Config.Master.Port),
zap.Bool("ssl_mode", m.Config.Master.SSLMode),
zap.String("ssl_ca", m.Config.Master.SSLCA),
zap.String("ssl_cert", m.Config.Master.SSLCert),
zap.String("ssl_key", m.Config.Master.SSLKey))
opts := []grpc.ServerOption{grpc.MaxSendMsgSize(math.MaxInt)}
if m.Config.Master.SSLMode {
c, err := protocol.NewServerCreds(&protocol.TLSConfig{
SSLCA: m.Config.Master.SSLCA,
SSLCert: m.Config.Master.SSLCert,
SSLKey: m.Config.Master.SSLKey,
})
if err != nil {
log.Logger().Fatal("failed to load server TLS", zap.Error(err))
}
opts = append(opts, grpc.Creds(c))
}
lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", m.Config.Master.Host, m.Config.Master.Port))
if err != nil {
log.Logger().Fatal("failed to listen", zap.Error(err))
}
m.grpcServer = grpc.NewServer(grpc.MaxSendMsgSize(math.MaxInt))
m.grpcServer = grpc.NewServer(opts...)
protocol.RegisterMasterServer(m.grpcServer, m)
if err = m.grpcServer.Serve(lis); err != nil {
log.Logger().Fatal("failed to start rpc server", zap.Error(err))
Expand Down
84 changes: 84 additions & 0 deletions master/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ import (
"context"
"encoding/json"
"net"
"os"
"path/filepath"
"testing"
"time"

"github.com/jellydator/ttlcache/v3"
"github.com/madflojo/testcerts"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/base/progress"
"github.com/zhenghaoz/gorse/config"
Expand Down Expand Up @@ -86,6 +89,26 @@ func (m *mockMasterRPC) Start(t *testing.T) {
assert.NoError(t, err)
}

func (m *mockMasterRPC) StartTLS(t *testing.T, o *protocol.TLSConfig) {
m.ttlCache = ttlcache.New(ttlcache.WithTTL[string, *Node](time.Second))
m.ttlCache.OnEviction(m.nodeDown)
go m.ttlCache.Start()

listen, err := net.Listen("tcp", ":0")
assert.NoError(t, err)
m.addr <- listen.Addr().String()
creds, err := protocol.NewServerCreds(&protocol.TLSConfig{
SSLCA: o.SSLCA,
SSLCert: o.SSLCert,
SSLKey: o.SSLKey,
})
assert.NoError(t, err)
m.grpcServer = grpc.NewServer(grpc.Creds(creds))
protocol.RegisterMasterServer(m.grpcServer, m)
err = m.grpcServer.Serve(listen)
assert.NoError(t, err)
}

func (m *mockMasterRPC) Stop() {
m.grpcServer.Stop()
}
Expand Down Expand Up @@ -155,3 +178,64 @@ func TestRPC(t *testing.T) {

rpcServer.Stop()
}

func generateToTempFile(t *testing.T) (string, string, string) {
// Generate Certificate Authority
ca := testcerts.NewCA()
// Create a signed Certificate and Key
certs, err := ca.NewKeyPair()
assert.NoError(t, err)
// Write certificates to a file
caFile := filepath.Join(t.TempDir(), "ca.pem")
certFile := filepath.Join(t.TempDir(), "cert.pem")
keyFile := filepath.Join(t.TempDir(), "key.pem")
pem := ca.PublicKey()
err = os.WriteFile(caFile, pem, 0640)
assert.NoError(t, err)
err = certs.ToFile(certFile, keyFile)
assert.NoError(t, err)
return caFile, certFile, keyFile
}

func TestSSL(t *testing.T) {
caFile, certFile, keyFile := generateToTempFile(t)
o := &protocol.TLSConfig{
SSLCA: caFile,
SSLCert: certFile,
SSLKey: keyFile,
}
rpcServer := newMockMasterRPC(t)
go rpcServer.StartTLS(t, o)
address := <-rpcServer.addr

// success
c, err := protocol.NewClientCreds(o)
assert.NoError(t, err)
conn, err := grpc.Dial(address, grpc.WithTransportCredentials(c))
assert.NoError(t, err)
client := protocol.NewMasterClient(conn)
_, err = client.GetMeta(context.Background(), &protocol.NodeInfo{NodeType: protocol.NodeType_ServerNode, NodeName: "server1", HttpPort: 1234})
assert.NoError(t, err)

// insecure
conn, err = grpc.Dial(address, grpc.WithInsecure())
assert.NoError(t, err)
client = protocol.NewMasterClient(conn)
_, err = client.GetMeta(context.Background(), &protocol.NodeInfo{NodeType: protocol.NodeType_ServerNode, NodeName: "server1", HttpPort: 1234})
assert.Error(t, err)

// certificate mismatch
caFile2, certFile2, keyFile2 := generateToTempFile(t)
o2 := &protocol.TLSConfig{
SSLCA: caFile2,
SSLCert: certFile2,
SSLKey: keyFile2,
}
c, err = protocol.NewClientCreds(o2)
assert.NoError(t, err)
conn, err = grpc.Dial(address, grpc.WithTransportCredentials(c))
assert.NoError(t, err)
client = protocol.NewMasterClient(conn)
_, err = client.GetMeta(context.Background(), &protocol.NodeInfo{NodeType: protocol.NodeType_ServerNode, NodeName: "server1", HttpPort: 1234})
assert.Error(t, err)
}
Loading

0 comments on commit cca6da7

Please sign in to comment.