diff --git a/Makefile b/Makefile index 99d6d680a7a..c874586ba5f 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,10 @@ ifeq ("$(WITH_RACE)", "1") BUILD_CGO_ENABLED := 1 endif +ifeq ($(PLUGIN), 1) + BUILD_TAGS += with_plugin +endif + LDFLAGS += -X "$(PD_PKG)/server/versioninfo.PDReleaseVersion=$(shell git describe --tags --dirty --always)" LDFLAGS += -X "$(PD_PKG)/server/versioninfo.PDBuildTS=$(shell date -u '+%Y-%m-%d %I:%M:%S')" LDFLAGS += -X "$(PD_PKG)/server/versioninfo.PDGitHash=$(shell git rev-parse HEAD)" diff --git a/server/api/admin.go b/server/api/admin.go index 93dffbd66c5..f29af732bd8 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -133,7 +133,10 @@ func (h *adminHandler) ResetTS(w http.ResponseWriter, r *http.Request) { } // Intentionally no swagger mark as it is supposed to be only used in -// server-to-server. For security reason, it only accepts JSON formatted data. +// server-to-server. +// For security reason, +// - it only accepts JSON formatted data. +// - it only accepts file name which is `DrStatusFile`. func (h *adminHandler) SavePersistFile(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 6051edf1805..962252dadfa 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -29,6 +29,7 @@ import ( tu "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/replication" ) type adminTestSuite struct { @@ -168,10 +169,10 @@ func (suite *adminTestSuite) TestDropRegions() { func (suite *adminTestSuite) TestPersistFile() { data := []byte("#!/bin/sh\nrm -rf /") re := suite.Require() - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/fun.sh", data, tu.StatusNotOK(re)) + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/"+replication.DrStatusFile, data, tu.StatusNotOK(re)) suite.NoError(err) data = []byte(`{"foo":"bar"}`) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/good.json", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/"+replication.DrStatusFile, data, tu.StatusOK(re)) suite.NoError(err) } diff --git a/server/api/plugin.go b/server/api/plugin.go index 16894304e9b..8727af02115 100644 --- a/server/api/plugin.go +++ b/server/api/plugin.go @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build with_plugin +// +build with_plugin + package api import ( diff --git a/server/api/plugin_disable.go b/server/api/plugin_disable.go new file mode 100644 index 00000000000..9c788c6dc48 --- /dev/null +++ b/server/api/plugin_disable.go @@ -0,0 +1,41 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !with_plugin +// +build !with_plugin + +package api + +import ( + "net/http" + + "github.com/tikv/pd/server" + "github.com/unrolled/render" +) + +type pluginHandler struct{} + +func newPluginHandler(_ *server.Handler, _ *render.Render) *pluginHandler { + return &pluginHandler{} +} + +func (h *pluginHandler) LoadPlugin(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte("load plugin is disabled, please `PLUGIN=1 MAKE pd-server` first")) +} + +func (h *pluginHandler) UnloadPlugin(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte("unload plugin is disabled, please `PLUGIN=1 MAKE pd-server` first")) +} diff --git a/server/api/server_test.go b/server/api/server_test.go index 8693b4b87ca..bdf929b17b3 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -16,7 +16,9 @@ package api import ( "context" + "fmt" "net/http" + "net/http/httptest" "sort" "sync" "testing" @@ -211,3 +213,24 @@ func (suite *serviceTestSuite) TestServiceLabels() { apiutil.NewAccessPath("/pd/api/v1/metric/query", http.MethodGet)) suite.Equal("QueryMetric", serviceLabel) } + +func (suite *adminTestSuite) TestCleanPath() { + re := suite.Require() + // transfer path to /config + url := fmt.Sprintf("%s/admin/persist-file/../../config", suite.urlPrefix) + cfg := &config.Config{} + err := testutil.ReadGetJSON(re, testDialClient, url, cfg) + suite.NoError(err) + + // handled by router + response := httptest.NewRecorder() + r, _, _ := NewHandler(context.Background(), suite.svr) + request, err := http.NewRequest(http.MethodGet, url, nil) + re.NoError(err) + r.ServeHTTP(response, request) + // handled by `cleanPath` which is in `mux.ServeHTTP` + result := response.Result() + defer result.Body.Close() + re.NotNil(result.Header["Location"]) + re.Contains(result.Header["Location"][0], "/pd/api/v1/config") +} diff --git a/server/handler.go b/server/handler.go index a871b19f447..f99b01abf99 100644 --- a/server/handler.go +++ b/server/handler.go @@ -21,6 +21,7 @@ import ( "fmt" "net/http" "path" + "path/filepath" "strconv" "strings" "time" @@ -967,6 +968,13 @@ func (h *Handler) PluginLoad(pluginPath string) error { c := cluster.GetCoordinator() ch := make(chan string) h.pluginChMap[pluginPath] = ch + + // make sure path is in data dir + filePath, err := filepath.Abs(pluginPath) + if err != nil || !isPathInDirectory(filePath, h.s.GetConfig().DataDir) { + return errs.ErrFilePathAbs.Wrap(err).FastGenWithCause() + } + c.LoadPlugin(pluginPath, ch) return nil } diff --git a/server/replication/replication_mode.go b/server/replication/replication_mode.go index d276bd8ec18..4d54365d290 100644 --- a/server/replication/replication_mode.go +++ b/server/replication/replication_mode.go @@ -60,7 +60,8 @@ type FileReplicater interface { ReplicateFileToMember(ctx context.Context, member *pdpb.Member, name string, data []byte) error } -const drStatusFile = "DR_STATE" +// DrStatusFile is the file name that stores the dr status. +const DrStatusFile = "DR_STATE" const persistFileTimeout = time.Second * 3 // ModeManager is used to control how raft logs are synchronized between @@ -483,7 +484,7 @@ func (m *ModeManager) tickReplicateStatus() { stateID, ok := m.replicateState.Load(member.GetMemberId()) if !ok || stateID.(uint64) != state.StateID { ctx, cancel := context.WithTimeout(context.Background(), persistFileTimeout) - err := m.fileReplicater.ReplicateFileToMember(ctx, member, drStatusFile, data) + err := m.fileReplicater.ReplicateFileToMember(ctx, member, DrStatusFile, data) if err != nil { log.Warn("failed to switch state", zap.String("replicate-mode", modeDRAutoSync), zap.String("new-state", state.State), errs.ZapError(err)) } else { diff --git a/server/server.go b/server/server.go index 37d0846876c..adaaf5c3391 100644 --- a/server/server.go +++ b/server/server.go @@ -59,6 +59,7 @@ import ( "github.com/tikv/pd/server/keyspace" "github.com/tikv/pd/server/member" syncer "github.com/tikv/pd/server/region_syncer" + "github.com/tikv/pd/server/replication" "github.com/tikv/pd/server/schedule" "github.com/tikv/pd/server/schedule/hbstream" "github.com/tikv/pd/server/schedule/placement" @@ -1678,8 +1679,15 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member, // PersistFile saves a file in DataDir. func (s *Server) PersistFile(name string, data []byte) error { + if name != replication.DrStatusFile { + return errors.New("Invalid file name") + } log.Info("persist file", zap.String("name", name), zap.Binary("data", data)) - return os.WriteFile(filepath.Join(s.GetConfig().DataDir, name), data, 0644) // #nosec + path := filepath.Join(s.GetConfig().DataDir, name) + if !isPathInDirectory(path, s.GetConfig().DataDir) { + return errors.New("Invalid file path") + } + return os.WriteFile(path, data, 0644) // #nosec } // SaveTTLConfig save ttl config diff --git a/server/server_test.go b/server/server_test.go index cd7cf89ab97..2f9b04767bc 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net/http" + "path/filepath" "testing" "github.com/stretchr/testify/suite" @@ -337,3 +338,14 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderBoth() { bodyString := string(bodyBytes) suite.Equal("Hello World\n", bodyString) } + +func (suite *leaderServerTestSuite) TestIsPathInDirectory() { + fileName := "test" + directory := "/root/project" + path := filepath.Join(directory, fileName) + suite.True(isPathInDirectory(path, directory)) + + fileName = "../../test" + path = filepath.Join(directory, fileName) + suite.False(isPathInDirectory(path, directory)) +} diff --git a/server/util.go b/server/util.go index 53f4be1c666..6a060779d70 100644 --- a/server/util.go +++ b/server/util.go @@ -18,6 +18,8 @@ import ( "context" "fmt" "math/rand" + "path/filepath" + "strings" "time" "github.com/pingcap/errors" @@ -157,3 +159,17 @@ func checkBootstrapRequest(clusterID uint64, req *pdpb.BootstrapRequest) error { return nil } + +func isPathInDirectory(path, directory string) bool { + absPath, err := filepath.Abs(path) + if err != nil { + return false + } + + absDir, err := filepath.Abs(directory) + if err != nil { + return false + } + + return strings.HasPrefix(absPath, absDir) +}