From ff6cb60dd258c1e4b984c4fa7da8a53740b403f9 Mon Sep 17 00:00:00 2001 From: husharp Date: Thu, 14 Sep 2023 16:43:01 +0800 Subject: [PATCH] fix Signed-off-by: husharp --- pkg/mcs/scheduling/server/config/config.go | 7 +++++++ pkg/replication/replication_mode.go | 4 ++-- pkg/schedule/config/config.go | 8 ++++++++ pkg/schedule/config/config_provider.go | 1 + pkg/utils/configutil/configutil.go | 1 + server/api/admin.go | 5 ++++- server/api/admin_test.go | 4 ++-- server/api/plugin.go | 8 ++++++++ server/api/server_test.go | 23 ++++++++++++++++++++++ server/config/persist_options.go | 7 +++++++ server/handler.go | 11 +++++++++++ server/server.go | 10 +++++++++- server/server_test.go | 13 ++++++++++++ server/util.go | 15 ++++++++++++++ 14 files changed, 111 insertions(+), 6 deletions(-) diff --git a/pkg/mcs/scheduling/server/config/config.go b/pkg/mcs/scheduling/server/config/config.go index 82c15632b3d5..29ada45512e8 100644 --- a/pkg/mcs/scheduling/server/config/config.go +++ b/pkg/mcs/scheduling/server/config/config.go @@ -606,6 +606,13 @@ func (o *PersistConfig) SetHaltScheduling(halt bool, source string) { o.SetScheduleConfig(v) } +// SetEnableSchedulePlugin set EnableSchedulePlugin. +func (o *PersistConfig) SetEnableSchedulePlugin(enable bool) { + v := o.GetScheduleConfig().Clone() + v.EnableSchedulePlugin = enable + o.SetScheduleConfig(v) +} + // CheckRegionKeys return error if the smallest region's keys is less than mergeKeys func (o *PersistConfig) CheckRegionKeys(keys, mergeKeys uint64) error { return o.GetStoreConfig().CheckRegionKeys(keys, mergeKeys) diff --git a/pkg/replication/replication_mode.go b/pkg/replication/replication_mode.go index 5a52f562e600..e56813f28a27 100644 --- a/pkg/replication/replication_mode.go +++ b/pkg/replication/replication_mode.go @@ -60,7 +60,7 @@ type FileReplicater interface { ReplicateFileToMember(ctx context.Context, member *pdpb.Member, name string, data []byte) error } -const drStatusFile = "DR_STATE" +const DrStatusFile = "DR_STATE" const persistFileTimeout = time.Second * 3 // ModeManager is used to control how raft logs are synchronized between @@ -489,7 +489,7 @@ func (m *ModeManager) tickReplicateStatus() { stateID, ok := m.replicateState.Load(member.GetMemberId()) if !ok || stateID.(uint64) != state.StateID { ctx, cancel := context.WithTimeout(context.Background(), persistFileTimeout) - err := m.fileReplicater.ReplicateFileToMember(ctx, member, drStatusFile, data) + err := m.fileReplicater.ReplicateFileToMember(ctx, member, DrStatusFile, data) if err != nil { log.Warn("failed to switch state", zap.String("replicate-mode", modeDRAutoSync), zap.String("new-state", state.State), errs.ZapError(err)) } else { diff --git a/pkg/schedule/config/config.go b/pkg/schedule/config/config.go index c8fa62b8aff8..3c2d60f4b938 100644 --- a/pkg/schedule/config/config.go +++ b/pkg/schedule/config/config.go @@ -57,6 +57,7 @@ const ( defaultEnablePlacementRules = true defaultEnableWitness = false defaultHaltScheduling = false + defaultEnableSchedulePlugin = false defaultRegionScoreFormulaVersion = "v2" defaultLeaderSchedulePolicy = "count" @@ -269,6 +270,9 @@ type ScheduleConfig struct { // HaltScheduling is the option to halt the scheduling. Once it's on, PD will halt the scheduling, // and any other scheduling configs will be ignored. HaltScheduling bool `toml:"halt-scheduling" json:"halt-scheduling,string,omitempty"` + + // EnableSchedulePlugin is the option to enable plugin. + EnableSchedulePlugin bool `toml:"enable-schedule-plugin" json:"enable-schedule-plugin,string"` } // Clone returns a cloned scheduling configuration. @@ -367,6 +371,10 @@ func (c *ScheduleConfig) Adjust(meta *configutil.ConfigMetaData, reloading bool) c.HaltScheduling = defaultHaltScheduling } + if !meta.IsDefined("enable-schedule-plugin") { + c.EnableSchedulePlugin = defaultEnableSchedulePlugin + } + adjustSchedulers(&c.Schedulers, DefaultSchedulers) for k, b := range c.migrateConfigurationMap() { diff --git a/pkg/schedule/config/config_provider.go b/pkg/schedule/config/config_provider.go index 00f11a5950f1..f402de86408e 100644 --- a/pkg/schedule/config/config_provider.go +++ b/pkg/schedule/config/config_provider.go @@ -115,6 +115,7 @@ type SharedConfigProvider interface { IsWitnessAllowed() bool IsPlacementRulesCacheEnabled() bool SetHaltScheduling(bool, string) + SetEnableSchedulePlugin(bool) // for test purpose SetPlacementRulesCacheEnabled(bool) diff --git a/pkg/utils/configutil/configutil.go b/pkg/utils/configutil/configutil.go index 978edce77640..7a92ff1ebecc 100644 --- a/pkg/utils/configutil/configutil.go +++ b/pkg/utils/configutil/configutil.go @@ -81,6 +81,7 @@ type SecurityConfig struct { // RedactInfoLog indicates that whether enabling redact log RedactInfoLog bool `toml:"redact-info-log" json:"redact-info-log"` Encryption encryption.Config `toml:"encryption" json:"encryption"` + GrantPlugin bool `toml:"grant-plugin" json:"grant-plugin"` } // PrintConfigCheckMsg prints the message about configuration checks. diff --git a/server/api/admin.go b/server/api/admin.go index c81193f1468d..7a1dfb0f1e82 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -111,7 +111,10 @@ func (h *adminHandler) DeleteAllRegionCache(w http.ResponseWriter, r *http.Reque } // Intentionally no swagger mark as it is supposed to be only used in -// server-to-server. For security reason, it only accepts JSON formatted data. +// server-to-server. +// For security reason, +// - it only accepts JSON formatted data. +// - it only accepts file name which is `DrStatusFile`. func (h *adminHandler) SavePersistFile(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 1f2b386eb987..dde5387e826e 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -168,10 +168,10 @@ func (suite *adminTestSuite) TestDropRegions() { func (suite *adminTestSuite) TestPersistFile() { data := []byte("#!/bin/sh\nrm -rf /") re := suite.Require() - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/fun.sh", data, tu.StatusNotOK(re)) + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/DR_STATE", data, tu.StatusNotOK(re)) suite.NoError(err) data = []byte(`{"foo":"bar"}`) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/good.json", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/DR_STATE", data, tu.StatusOK(re)) suite.NoError(err) } diff --git a/server/api/plugin.go b/server/api/plugin.go index fd75cc6bb2b8..041cfd5c06e9 100644 --- a/server/api/plugin.go +++ b/server/api/plugin.go @@ -48,6 +48,10 @@ func newPluginHandler(handler *server.Handler, rd *render.Render) *pluginHandler // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /plugin [post] func (h *pluginHandler) LoadPlugin(w http.ResponseWriter, r *http.Request) { + if !h.GetScheduleConfig().EnableSchedulePlugin { + h.rd.JSON(w, http.StatusInternalServerError, errors.New("load plugin failed, please enable plugin first")) + return + } h.processPluginCommand(w, r, schedule.PluginLoad) } @@ -62,6 +66,10 @@ func (h *pluginHandler) LoadPlugin(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /plugin [delete] func (h *pluginHandler) UnloadPlugin(w http.ResponseWriter, r *http.Request) { + if !h.GetScheduleConfig().EnableSchedulePlugin { + h.rd.JSON(w, http.StatusInternalServerError, errors.New("unload plugin failed, please enable plugin first")) + return + } h.processPluginCommand(w, r, schedule.PluginUnload) } diff --git a/server/api/server_test.go b/server/api/server_test.go index 88253b3a6242..2e89ad797c34 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -16,7 +16,9 @@ package api import ( "context" + "fmt" "net/http" + "net/http/httptest" "sort" "sync" "testing" @@ -210,3 +212,24 @@ func (suite *serviceTestSuite) TestServiceLabels() { apiutil.NewAccessPath("/pd/api/v1/metric/query", http.MethodGet)) suite.Equal("QueryMetric", serviceLabel) } + +func (suite *adminTestSuite) TestCleanPath() { + re := suite.Require() + // transfer path to /config + url := fmt.Sprintf("%s/admin/persist-file/../../config", suite.urlPrefix) + cfg := &config.Config{} + err := testutil.ReadGetJSON(re, testDialClient, url, cfg) + suite.NoError(err) + + // handled by router + response := httptest.NewRecorder() + r, _, _ := NewHandler(context.Background(), suite.svr) + request, err := http.NewRequest(http.MethodGet, url, nil) + re.NoError(err) + r.ServeHTTP(response, request) + // handled by `cleanPath` which is in `mux.ServeHTTP` + result := response.Result() + defer result.Body.Close() + re.NotNil(result.Header["Location"]) + re.Contains(result.Header["Location"][0], "/pd/api/v1/config") +} diff --git a/server/config/persist_options.go b/server/config/persist_options.go index 1ea0b79424f8..4c8b7ed2e9fe 100644 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -982,6 +982,13 @@ func (o *PersistOptions) IsSchedulingHalted() bool { return o.GetScheduleConfig().HaltScheduling } +// SetEnableSchedulePlugin to set the option for witness. It's only used to test. +func (o *PersistOptions) SetEnableSchedulePlugin(enable bool) { + v := o.GetScheduleConfig().Clone() + v.EnableSchedulePlugin = enable + o.SetScheduleConfig(v) +} + // GetRegionMaxSize returns the max region size in MB func (o *PersistOptions) GetRegionMaxSize() uint64 { return o.GetStoreConfig().GetRegionMaxSize() diff --git a/server/handler.go b/server/handler.go index a90f8e3f04f3..e3bf8f20ada5 100644 --- a/server/handler.go +++ b/server/handler.go @@ -21,6 +21,7 @@ import ( "net/http" "net/url" "path" + "path/filepath" "strconv" "strings" "time" @@ -997,6 +998,16 @@ func (h *Handler) PluginLoad(pluginPath string) error { c := cluster.GetCoordinator() ch := make(chan string) h.pluginChMap[pluginPath] = ch + + // make sure path is in data dir + filePath, err := filepath.Abs(pluginPath) + if err != nil { + return errs.ErrFilePathAbs.Wrap(err).FastGenWithCause() + } + if !isPathInDirectory(filePath, h.s.GetConfig().DataDir) { + return errs.ErrLoadPlugin.Wrap(err).FastGenWithCause() + } + c.LoadPlugin(pluginPath, ch) return nil } diff --git a/server/server.go b/server/server.go index 7c19d8ff7c5c..8faeceb0acf4 100644 --- a/server/server.go +++ b/server/server.go @@ -19,6 +19,7 @@ import ( "context" errorspkg "errors" "fmt" + "github.com/tikv/pd/pkg/replication" "math/rand" "net/http" "os" @@ -1868,8 +1869,15 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member, // PersistFile saves a file in DataDir. func (s *Server) PersistFile(name string, data []byte) error { + if name != replication.DrStatusFile { + return errors.New("Invalid file name") + } log.Info("persist file", zap.String("name", name), zap.Binary("data", data)) - return os.WriteFile(filepath.Join(s.GetConfig().DataDir, name), data, 0644) // #nosec + path := filepath.Join(s.GetConfig().DataDir, name) + if !isPathInDirectory(path, s.GetConfig().DataDir) { + return errors.New("Invalid file path") + } + return os.WriteFile(path, data, 0644) // #nosec } // SaveTTLConfig save ttl config diff --git a/server/server_test.go b/server/server_test.go index 2d0e23c7682c..62cf5b168fc7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net/http" + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -307,3 +308,15 @@ func TestAPIService(t *testing.T) { MustWaitLeader(re, []*Server{svr}) re.True(svr.IsAPIServiceMode()) } + +func TestIsPathInDirectory(t *testing.T) { + re := require.New(t) + fileName := "test" + directory := "/root/project" + path := filepath.Join(directory, fileName) + re.True(isPathInDirectory(path, directory)) + + fileName = "../../test" + path = filepath.Join(directory, fileName) + re.False(isPathInDirectory(path, directory)) +} diff --git a/server/util.go b/server/util.go index 9c7a97a98066..654b424465e3 100644 --- a/server/util.go +++ b/server/util.go @@ -17,6 +17,7 @@ package server import ( "context" "net/http" + "path/filepath" "strings" "github.com/gorilla/mux" @@ -124,3 +125,17 @@ func combineBuilderServerHTTPService(ctx context.Context, svr *Server, serviceBu userHandlers[pdAPIPrefix] = apiService return userHandlers, nil } + +func isPathInDirectory(path, directory string) bool { + absPath, err := filepath.Abs(path) + if err != nil { + return false + } + + absDir, err := filepath.Abs(directory) + if err != nil { + return false + } + + return strings.HasPrefix(absPath, absDir) +}