Skip to content

Commit

Permalink
ctl: fix https client panic (#8239) (#8247)
Browse files Browse the repository at this point in the history
ref #7300, close #8237

fix panic when call pd-ctl cluster with tls

Signed-off-by: ti-chi-bot <[email protected]>
Signed-off-by: okJiang <[email protected]>

Co-authored-by: okJiang <[email protected]>
Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 29, 2024
1 parent f347d6e commit 632ee6e
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 44 deletions.
72 changes: 28 additions & 44 deletions tools/pd-ctl/pdctl/command/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,15 @@ var PDCli pd.Client

func requirePDClient(cmd *cobra.Command, _ []string) error {
var (
caPath string
err error
tlsConfig *tls.Config
err error
)
caPath, err = cmd.Flags().GetString("cacert")
if err == nil && len(caPath) != 0 {
var certPath, keyPath string
certPath, err = cmd.Flags().GetString("cert")
if err != nil {
return err
}
keyPath, err = cmd.Flags().GetString("key")
if err != nil {
return err
}
return initNewPDClientWithTLS(cmd, caPath, certPath, keyPath)
tlsConfig, err = parseTLSConfig(cmd)
if err != nil {
return err
}
return initNewPDClient(cmd)

return initNewPDClient(cmd, pd.WithTLSConfig(tlsConfig))
}

// shouldInitPDClient checks whether we should create a new PD client according to the cluster information.
Expand Down Expand Up @@ -111,44 +103,36 @@ func initNewPDClient(cmd *cobra.Command, opts ...pd.ClientOption) error {
return nil
}

func initNewPDClientWithTLS(cmd *cobra.Command, caPath, certPath, keyPath string) error {
tlsConfig, err := initTLSConfig(caPath, certPath, keyPath)
if err != nil {
return err
}
initNewPDClient(cmd, pd.WithTLSConfig(tlsConfig))
return nil
}

// TODO: replace dialClient with the PD HTTP client completely.
var dialClient = &http.Client{
Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, pdControlCallerID),
}

// RequireHTTPSClient creates a HTTPS client if the related flags are set
func RequireHTTPSClient(cmd *cobra.Command, args []string) error {
func parseTLSConfig(cmd *cobra.Command) (*tls.Config, error) {
caPath, err := cmd.Flags().GetString("cacert")
if err == nil && len(caPath) != 0 {
certPath, err := cmd.Flags().GetString("cert")
if err != nil {
return err
}
keyPath, err := cmd.Flags().GetString("key")
if err != nil {
return err
}
err = initHTTPSClient(caPath, certPath, keyPath)
if err != nil {
cmd.Println(err)
return err
}
if err != nil || len(caPath) == 0 {
return nil, err
}
certPath, err := cmd.Flags().GetString("cert")
if err != nil {
return nil, err
}
keyPath, err := cmd.Flags().GetString("key")
if err != nil {
return nil, err
}
return nil
}

func initHTTPSClient(caPath, certPath, keyPath string) error {
tlsConfig, err := initTLSConfig(caPath, certPath, keyPath)
if err != nil {
return nil, err
}

return tlsConfig, nil
}

// RequireHTTPSClient creates a HTTPS client if the related flags are set
func RequireHTTPSClient(cmd *cobra.Command, _ []string) error {
tlsConfig, err := parseTLSConfig(cmd)
if err != nil || tlsConfig == nil {
return err
}
dialClient = &http.Client{
Expand Down
58 changes: 58 additions & 0 deletions tools/pd-ctl/pdctl/command/global_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2024 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package command

import (
"os"
"os/exec"
"testing"

"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)

func TestParseTLSConfig(t *testing.T) {
re := require.New(t)

rootCmd := &cobra.Command{
Use: "pd-ctl",
Short: "Placement Driver control",
SilenceErrors: true,
}
certPath := "../../tests/cert"
rootCmd.Flags().String("cacert", certPath+"/ca.pem", "path of file that contains list of trusted SSL CAs")
rootCmd.Flags().String("cert", certPath+"/client.pem", "path of file that contains X509 certificate in PEM format")
rootCmd.Flags().String("key", certPath+"/client-key.pem", "path of file that contains X509 key in PEM format")

// generate certs
if err := os.Mkdir(certPath, 0755); err != nil {
t.Fatal(err)
}
certScript := "../../tests/cert_opt.sh"
if err := exec.Command(certScript, "generate", certPath).Run(); err != nil {
t.Fatal(err)
}
defer func() {
if err := exec.Command(certScript, "cleanup", certPath).Run(); err != nil {
t.Fatal(err)
}
if err := os.RemoveAll(certPath); err != nil {
t.Fatal(err)
}
}()

tlsConfig, err := parseTLSConfig(rootCmd)
re.NoError(err)
re.NotNil(tlsConfig)
}
1 change: 1 addition & 0 deletions tools/pd-ctl/pdctl/ctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (

func init() {
cobra.EnablePrefixMatching = true
cobra.EnableTraverseRunHooks = true
}

// GetRootCmd is exposed for integration tests. But it can be embedded into another suite, too.
Expand Down
84 changes: 84 additions & 0 deletions tools/pd-ctl/tests/health/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,21 @@ package health_test
import (
"context"
"encoding/json"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/require"
"github.com/tikv/pd/pkg/utils/grpcutil"
"github.com/tikv/pd/server/api"
"github.com/tikv/pd/server/cluster"
"github.com/tikv/pd/server/config"
pdTests "github.com/tikv/pd/tests"
ctl "github.com/tikv/pd/tools/pd-ctl/pdctl"
"github.com/tikv/pd/tools/pd-ctl/tests"
"go.etcd.io/etcd/pkg/transport"
)

func TestHealth(t *testing.T) {
Expand Down Expand Up @@ -68,3 +75,80 @@ func TestHealth(t *testing.T) {
re.NoError(json.Unmarshal(output, &h))
re.Equal(healths, h)
}

func TestHealthTLS(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
certPath := "../cert"
certScript := "../cert_opt.sh"
// generate certs
if err := os.Mkdir(certPath, 0755); err != nil {
t.Fatal(err)
}
if err := exec.Command(certScript, "generate", certPath).Run(); err != nil {
t.Fatal(err)
}
defer func() {
if err := exec.Command(certScript, "cleanup", certPath).Run(); err != nil {
t.Fatal(err)
}
if err := os.RemoveAll(certPath); err != nil {
t.Fatal(err)
}
}()

tlsInfo := transport.TLSInfo{
KeyFile: filepath.Join(certPath, "pd-server-key.pem"),
CertFile: filepath.Join(certPath, "pd-server.pem"),
TrustedCAFile: filepath.Join(certPath, "ca.pem"),
}
tc, err := pdTests.NewTestCluster(ctx, 1, func(conf *config.Config, _ string) {
conf.Security.TLSConfig = grpcutil.TLSConfig{
KeyPath: tlsInfo.KeyFile,
CertPath: tlsInfo.CertFile,
CAPath: tlsInfo.TrustedCAFile,
}
conf.AdvertiseClientUrls = strings.ReplaceAll(conf.AdvertiseClientUrls, "http", "https")
conf.ClientUrls = strings.ReplaceAll(conf.ClientUrls, "http", "https")
conf.AdvertisePeerUrls = strings.ReplaceAll(conf.AdvertisePeerUrls, "http", "https")
conf.PeerUrls = strings.ReplaceAll(conf.PeerUrls, "http", "https")
conf.InitialCluster = strings.ReplaceAll(conf.InitialCluster, "http", "https")
})
re.NoError(err)
defer tc.Destroy()
err = tc.RunInitialServers()
re.NoError(err)
tc.WaitLeader()
cmd := ctl.GetRootCmd()

client := tc.GetEtcdClient()
members, err := cluster.GetMembers(client)
re.NoError(err)
healthMembers := cluster.CheckHealth(tc.GetHTTPClient(), members)
healths := []api.Health{}
for _, member := range members {
h := api.Health{
Name: member.Name,
MemberID: member.MemberId,
ClientUrls: member.ClientUrls,
Health: false,
}
if _, ok := healthMembers[member.GetMemberId()]; ok {
h.Health = true
}
healths = append(healths, h)
}

pdAddr := tc.GetConfig().GetClientURL()
pdAddr = strings.ReplaceAll(pdAddr, "http", "https")
args := []string{"-u", pdAddr, "health",
"--cacert=../cert/ca.pem",
"--cert=../cert/client.pem",
"--key=../cert/client-key.pem"}
output, err := tests.ExecuteCommand(cmd, args...)
re.NoError(err)
h := make([]api.Health, len(healths))
re.NoError(json.Unmarshal(output, &h))
re.Equal(healths, h)
}

0 comments on commit 632ee6e

Please sign in to comment.