Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: husharp <[email protected]>
  • Loading branch information
HuSharp committed Sep 14, 2023
1 parent e295e62 commit ff6cb60
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 6 deletions.
7 changes: 7 additions & 0 deletions pkg/mcs/scheduling/server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,13 @@ func (o *PersistConfig) SetHaltScheduling(halt bool, source string) {
o.SetScheduleConfig(v)
}

// SetEnableSchedulePlugin set EnableSchedulePlugin.
func (o *PersistConfig) SetEnableSchedulePlugin(enable bool) {
v := o.GetScheduleConfig().Clone()
v.EnableSchedulePlugin = enable
o.SetScheduleConfig(v)
}

// CheckRegionKeys return error if the smallest region's keys is less than mergeKeys
func (o *PersistConfig) CheckRegionKeys(keys, mergeKeys uint64) error {
return o.GetStoreConfig().CheckRegionKeys(keys, mergeKeys)
Expand Down
4 changes: 2 additions & 2 deletions pkg/replication/replication_mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type FileReplicater interface {
ReplicateFileToMember(ctx context.Context, member *pdpb.Member, name string, data []byte) error
}

const drStatusFile = "DR_STATE"
const DrStatusFile = "DR_STATE"
const persistFileTimeout = time.Second * 3

// ModeManager is used to control how raft logs are synchronized between
Expand Down Expand Up @@ -489,7 +489,7 @@ func (m *ModeManager) tickReplicateStatus() {
stateID, ok := m.replicateState.Load(member.GetMemberId())
if !ok || stateID.(uint64) != state.StateID {
ctx, cancel := context.WithTimeout(context.Background(), persistFileTimeout)
err := m.fileReplicater.ReplicateFileToMember(ctx, member, drStatusFile, data)
err := m.fileReplicater.ReplicateFileToMember(ctx, member, DrStatusFile, data)
if err != nil {
log.Warn("failed to switch state", zap.String("replicate-mode", modeDRAutoSync), zap.String("new-state", state.State), errs.ZapError(err))
} else {
Expand Down
8 changes: 8 additions & 0 deletions pkg/schedule/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const (
defaultEnablePlacementRules = true
defaultEnableWitness = false
defaultHaltScheduling = false
defaultEnableSchedulePlugin = false

defaultRegionScoreFormulaVersion = "v2"
defaultLeaderSchedulePolicy = "count"
Expand Down Expand Up @@ -269,6 +270,9 @@ type ScheduleConfig struct {
// HaltScheduling is the option to halt the scheduling. Once it's on, PD will halt the scheduling,
// and any other scheduling configs will be ignored.
HaltScheduling bool `toml:"halt-scheduling" json:"halt-scheduling,string,omitempty"`

// EnableSchedulePlugin is the option to enable plugin.
EnableSchedulePlugin bool `toml:"enable-schedule-plugin" json:"enable-schedule-plugin,string"`
}

// Clone returns a cloned scheduling configuration.
Expand Down Expand Up @@ -367,6 +371,10 @@ func (c *ScheduleConfig) Adjust(meta *configutil.ConfigMetaData, reloading bool)
c.HaltScheduling = defaultHaltScheduling
}

if !meta.IsDefined("enable-schedule-plugin") {
c.EnableSchedulePlugin = defaultEnableSchedulePlugin
}

adjustSchedulers(&c.Schedulers, DefaultSchedulers)

for k, b := range c.migrateConfigurationMap() {
Expand Down
1 change: 1 addition & 0 deletions pkg/schedule/config/config_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ type SharedConfigProvider interface {
IsWitnessAllowed() bool
IsPlacementRulesCacheEnabled() bool
SetHaltScheduling(bool, string)
SetEnableSchedulePlugin(bool)

// for test purpose
SetPlacementRulesCacheEnabled(bool)
Expand Down
1 change: 1 addition & 0 deletions pkg/utils/configutil/configutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ type SecurityConfig struct {
// RedactInfoLog indicates that whether enabling redact log
RedactInfoLog bool `toml:"redact-info-log" json:"redact-info-log"`
Encryption encryption.Config `toml:"encryption" json:"encryption"`
GrantPlugin bool `toml:"grant-plugin" json:"grant-plugin"`
}

// PrintConfigCheckMsg prints the message about configuration checks.
Expand Down
5 changes: 4 additions & 1 deletion server/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ func (h *adminHandler) DeleteAllRegionCache(w http.ResponseWriter, r *http.Reque
}

// Intentionally no swagger mark as it is supposed to be only used in
// server-to-server. For security reason, it only accepts JSON formatted data.
// server-to-server.
// For security reason,
// - it only accepts JSON formatted data.
// - it only accepts file name which is `DrStatusFile`.
func (h *adminHandler) SavePersistFile(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions server/api/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ func (suite *adminTestSuite) TestDropRegions() {
func (suite *adminTestSuite) TestPersistFile() {
data := []byte("#!/bin/sh\nrm -rf /")
re := suite.Require()
err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/fun.sh", data, tu.StatusNotOK(re))
err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/DR_STATE", data, tu.StatusNotOK(re))
suite.NoError(err)
data = []byte(`{"foo":"bar"}`)
err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/good.json", data, tu.StatusOK(re))
err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/DR_STATE", data, tu.StatusOK(re))
suite.NoError(err)
}

Expand Down
8 changes: 8 additions & 0 deletions server/api/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ func newPluginHandler(handler *server.Handler, rd *render.Render) *pluginHandler
// @Failure 500 {string} string "PD server failed to proceed the request."
// @Router /plugin [post]
func (h *pluginHandler) LoadPlugin(w http.ResponseWriter, r *http.Request) {
if !h.GetScheduleConfig().EnableSchedulePlugin {
h.rd.JSON(w, http.StatusInternalServerError, errors.New("load plugin failed, please enable plugin first"))
return
}
h.processPluginCommand(w, r, schedule.PluginLoad)
}

Expand All @@ -62,6 +66,10 @@ func (h *pluginHandler) LoadPlugin(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {string} string "PD server failed to proceed the request."
// @Router /plugin [delete]
func (h *pluginHandler) UnloadPlugin(w http.ResponseWriter, r *http.Request) {
if !h.GetScheduleConfig().EnableSchedulePlugin {
h.rd.JSON(w, http.StatusInternalServerError, errors.New("unload plugin failed, please enable plugin first"))
return
}
h.processPluginCommand(w, r, schedule.PluginUnload)
}

Expand Down
23 changes: 23 additions & 0 deletions server/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ package api

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"sort"
"sync"
"testing"
Expand Down Expand Up @@ -210,3 +212,24 @@ func (suite *serviceTestSuite) TestServiceLabels() {
apiutil.NewAccessPath("/pd/api/v1/metric/query", http.MethodGet))
suite.Equal("QueryMetric", serviceLabel)
}

func (suite *adminTestSuite) TestCleanPath() {
re := suite.Require()
// transfer path to /config
url := fmt.Sprintf("%s/admin/persist-file/../../config", suite.urlPrefix)
cfg := &config.Config{}
err := testutil.ReadGetJSON(re, testDialClient, url, cfg)
suite.NoError(err)

// handled by router
response := httptest.NewRecorder()
r, _, _ := NewHandler(context.Background(), suite.svr)
request, err := http.NewRequest(http.MethodGet, url, nil)
re.NoError(err)
r.ServeHTTP(response, request)
// handled by `cleanPath` which is in `mux.ServeHTTP`
result := response.Result()
defer result.Body.Close()
re.NotNil(result.Header["Location"])
re.Contains(result.Header["Location"][0], "/pd/api/v1/config")
}
7 changes: 7 additions & 0 deletions server/config/persist_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,13 @@ func (o *PersistOptions) IsSchedulingHalted() bool {
return o.GetScheduleConfig().HaltScheduling
}

// SetEnableSchedulePlugin to set the option for witness. It's only used to test.
func (o *PersistOptions) SetEnableSchedulePlugin(enable bool) {
v := o.GetScheduleConfig().Clone()
v.EnableSchedulePlugin = enable
o.SetScheduleConfig(v)
}

// GetRegionMaxSize returns the max region size in MB
func (o *PersistOptions) GetRegionMaxSize() uint64 {
return o.GetStoreConfig().GetRegionMaxSize()
Expand Down
11 changes: 11 additions & 0 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/http"
"net/url"
"path"
"path/filepath"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -997,6 +998,16 @@ func (h *Handler) PluginLoad(pluginPath string) error {
c := cluster.GetCoordinator()
ch := make(chan string)
h.pluginChMap[pluginPath] = ch

// make sure path is in data dir
filePath, err := filepath.Abs(pluginPath)
if err != nil {
return errs.ErrFilePathAbs.Wrap(err).FastGenWithCause()
}
if !isPathInDirectory(filePath, h.s.GetConfig().DataDir) {
return errs.ErrLoadPlugin.Wrap(err).FastGenWithCause()
}

c.LoadPlugin(pluginPath, ch)
return nil
}
Expand Down
10 changes: 9 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
errorspkg "errors"
"fmt"
"github.com/tikv/pd/pkg/replication"
"math/rand"
"net/http"
"os"
Expand Down Expand Up @@ -1868,8 +1869,15 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member,

// PersistFile saves a file in DataDir.
func (s *Server) PersistFile(name string, data []byte) error {
if name != replication.DrStatusFile {
return errors.New("Invalid file name")
}
log.Info("persist file", zap.String("name", name), zap.Binary("data", data))
return os.WriteFile(filepath.Join(s.GetConfig().DataDir, name), data, 0644) // #nosec
path := filepath.Join(s.GetConfig().DataDir, name)
if !isPathInDirectory(path, s.GetConfig().DataDir) {
return errors.New("Invalid file path")
}
return os.WriteFile(path, data, 0644) // #nosec
}

// SaveTTLConfig save ttl config
Expand Down
13 changes: 13 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"io"
"net/http"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -307,3 +308,15 @@ func TestAPIService(t *testing.T) {
MustWaitLeader(re, []*Server{svr})
re.True(svr.IsAPIServiceMode())
}

func TestIsPathInDirectory(t *testing.T) {
re := require.New(t)
fileName := "test"
directory := "/root/project"
path := filepath.Join(directory, fileName)
re.True(isPathInDirectory(path, directory))

fileName = "../../test"
path = filepath.Join(directory, fileName)
re.False(isPathInDirectory(path, directory))
}
15 changes: 15 additions & 0 deletions server/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package server
import (
"context"
"net/http"
"path/filepath"
"strings"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -124,3 +125,17 @@ func combineBuilderServerHTTPService(ctx context.Context, svr *Server, serviceBu
userHandlers[pdAPIPrefix] = apiService
return userHandlers, nil
}

func isPathInDirectory(path, directory string) bool {
absPath, err := filepath.Abs(path)
if err != nil {
return false
}

absDir, err := filepath.Abs(directory)
if err != nil {
return false
}

return strings.HasPrefix(absPath, absDir)
}

0 comments on commit ff6cb60

Please sign in to comment.