From 905d8ff01012279ec74d01f3232dbed68c28f46b Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Thu, 28 Sep 2023 15:43:50 +0800 Subject: [PATCH] security: disable plugin in default and persist file in specified dir (#7087) (#7142) close tikv/pd#7094 Signed-off-by: husharp Co-authored-by: husharp Co-authored-by: Hu# Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- Makefile | 4 +++ server/api/admin.go | 5 +++- server/api/admin_test.go | 5 ++-- server/api/plugin.go | 3 ++ server/api/plugin_disable.go | 41 ++++++++++++++++++++++++++ server/api/server_test.go | 23 +++++++++++++++ server/handler.go | 8 +++++ server/replication/replication_mode.go | 5 ++-- server/server.go | 10 ++++++- server/server_test.go | 13 ++++++++ server/util.go | 15 ++++++++++ 11 files changed, 126 insertions(+), 6 deletions(-) create mode 100644 server/api/plugin_disable.go diff --git a/Makefile b/Makefile index c00cfa83b12..240dd5295a5 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,10 @@ ifeq ("$(WITH_RACE)", "1") BUILD_CGO_ENABLED := 1 endif +ifeq ($(PLUGIN), 1) + BUILD_TAGS += with_plugin +endif + LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDReleaseVersion=$(shell git describe --tags --dirty --always)" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDBuildTS=$(shell date -u '+%Y-%m-%d %I:%M:%S')" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDGitHash=$(shell git rev-parse HEAD)" diff --git a/server/api/admin.go b/server/api/admin.go index 334d1882a66..c6cddec4d64 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -71,7 +71,10 @@ func (h *adminHandler) DeleteAllRegionCache(w http.ResponseWriter, r *http.Reque } // Intentionally no swagger mark as it is supposed to be only used in -// server-to-server. For security reason, it only accepts JSON formatted data. +// server-to-server. +// For security reason, +// - it only accepts JSON formatted data. +// - it only accepts file name which is `DrStatusFile`. func (h *adminHandler) SavePersistFile(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 1f2b386eb98..b3fc01951dd 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -29,6 +29,7 @@ import ( "github.com/tikv/pd/pkg/utils/apiutil" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server" + "github.com/tikv/pd/server/replication" ) type adminTestSuite struct { @@ -168,10 +169,10 @@ func (suite *adminTestSuite) TestDropRegions() { func (suite *adminTestSuite) TestPersistFile() { data := []byte("#!/bin/sh\nrm -rf /") re := suite.Require() - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/fun.sh", data, tu.StatusNotOK(re)) + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/"+replication.DrStatusFile, data, tu.StatusNotOK(re)) suite.NoError(err) data = []byte(`{"foo":"bar"}`) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/good.json", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/"+replication.DrStatusFile, data, tu.StatusOK(re)) suite.NoError(err) } diff --git a/server/api/plugin.go b/server/api/plugin.go index 192310cca7e..cc0b0ae6c5f 100644 --- a/server/api/plugin.go +++ b/server/api/plugin.go @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build with_plugin +// +build with_plugin + package api import ( diff --git a/server/api/plugin_disable.go b/server/api/plugin_disable.go new file mode 100644 index 00000000000..2676dbb91e2 --- /dev/null +++ b/server/api/plugin_disable.go @@ -0,0 +1,41 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !with_plugin +// +build !with_plugin + +package api + +import ( + "net/http" + + "github.com/tikv/pd/server" + "github.com/unrolled/render" +) + +type pluginHandler struct{} + +func newPluginHandler(_ *server.Handler, _ *render.Render) *pluginHandler { + return &pluginHandler{} +} + +func (h *pluginHandler) LoadPlugin(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte("load plugin is disabled, please `PLUGIN=1 $(MAKE) pd-server` first")) +} + +func (h *pluginHandler) UnloadPlugin(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte("unload plugin is disabled, please `PLUGIN=1 $(MAKE) pd-server` first")) +} diff --git a/server/api/server_test.go b/server/api/server_test.go index 88253b3a624..2e89ad797c3 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -16,7 +16,9 @@ package api import ( "context" + "fmt" "net/http" + "net/http/httptest" "sort" "sync" "testing" @@ -210,3 +212,24 @@ func (suite *serviceTestSuite) TestServiceLabels() { apiutil.NewAccessPath("/pd/api/v1/metric/query", http.MethodGet)) suite.Equal("QueryMetric", serviceLabel) } + +func (suite *adminTestSuite) TestCleanPath() { + re := suite.Require() + // transfer path to /config + url := fmt.Sprintf("%s/admin/persist-file/../../config", suite.urlPrefix) + cfg := &config.Config{} + err := testutil.ReadGetJSON(re, testDialClient, url, cfg) + suite.NoError(err) + + // handled by router + response := httptest.NewRecorder() + r, _, _ := NewHandler(context.Background(), suite.svr) + request, err := http.NewRequest(http.MethodGet, url, nil) + re.NoError(err) + r.ServeHTTP(response, request) + // handled by `cleanPath` which is in `mux.ServeHTTP` + result := response.Result() + defer result.Body.Close() + re.NotNil(result.Header["Location"]) + re.Contains(result.Header["Location"][0], "/pd/api/v1/config") +} diff --git a/server/handler.go b/server/handler.go index 248ab46e022..19597fda1cc 100644 --- a/server/handler.go +++ b/server/handler.go @@ -21,6 +21,7 @@ import ( "fmt" "net/http" "path" + "path/filepath" "strconv" "strings" "time" @@ -980,6 +981,13 @@ func (h *Handler) PluginLoad(pluginPath string) error { c := cluster.GetCoordinator() ch := make(chan string) h.pluginChMap[pluginPath] = ch + + // make sure path is in data dir + filePath, err := filepath.Abs(pluginPath) + if err != nil || !isPathInDirectory(filePath, h.s.GetConfig().DataDir) { + return errs.ErrFilePathAbs.Wrap(err).FastGenWithCause() + } + c.LoadPlugin(pluginPath, ch) return nil } diff --git a/server/replication/replication_mode.go b/server/replication/replication_mode.go index f1933db16ca..f1e8b5a9c8a 100644 --- a/server/replication/replication_mode.go +++ b/server/replication/replication_mode.go @@ -61,7 +61,8 @@ type FileReplicater interface { ReplicateFileToMember(ctx context.Context, member *pdpb.Member, name string, data []byte) error } -const drStatusFile = "DR_STATE" +// DrStatusFile is the file name that stores the dr status. +const DrStatusFile = "DR_STATE" const persistFileTimeout = time.Second * 10 // ModeManager is used to control how raft logs are synchronized between @@ -331,7 +332,7 @@ func (m *ModeManager) drPersistStatusWithLock(status drAutoSyncStatus) { m.replicatedMembers = m.replicatedMembers[:0] for _, member := range members { - if err := m.fileReplicater.ReplicateFileToMember(ctx, member, drStatusFile, data); err != nil { + if err := m.fileReplicater.ReplicateFileToMember(ctx, member, DrStatusFile, data); err != nil { log.Warn("failed to switch state", zap.String("replicate-mode", modeDRAutoSync), zap.String("new-state", status.State), errs.ZapError(err)) // Throw away the error to make it possible to switch to async when // primary and dr DC are disconnected. This will result in the diff --git a/server/server.go b/server/server.go index 058aaa6f816..fa4f329dead 100644 --- a/server/server.go +++ b/server/server.go @@ -76,6 +76,7 @@ import ( "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/gc" syncer "github.com/tikv/pd/server/region_syncer" + "github.com/tikv/pd/server/replication" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/embed" "go.etcd.io/etcd/pkg/types" @@ -1718,8 +1719,15 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member, // PersistFile saves a file in DataDir. func (s *Server) PersistFile(name string, data []byte) error { + if name != replication.DrStatusFile { + return errors.New("Invalid file name") + } log.Info("persist file", zap.String("name", name), zap.Binary("data", data)) - return os.WriteFile(filepath.Join(s.GetConfig().DataDir, name), data, 0644) // #nosec + path := filepath.Join(s.GetConfig().DataDir, name) + if !isPathInDirectory(path, s.GetConfig().DataDir) { + return errors.New("Invalid file path") + } + return os.WriteFile(path, data, 0644) // #nosec } // SaveTTLConfig save ttl config diff --git a/server/server_test.go b/server/server_test.go index 47ec2dd735c..2d3ec175187 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net/http" + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -306,3 +307,15 @@ func TestAPIService(t *testing.T) { MustWaitLeader(re, []*Server{svr}) re.True(svr.IsAPIServiceMode()) } + +func TestIsPathInDirectory(t *testing.T) { + re := require.New(t) + fileName := "test" + directory := "/root/project" + path := filepath.Join(directory, fileName) + re.True(isPathInDirectory(path, directory)) + + fileName = "../../test" + path = filepath.Join(directory, fileName) + re.False(isPathInDirectory(path, directory)) +} diff --git a/server/util.go b/server/util.go index 9c7a97a9806..654b424465e 100644 --- a/server/util.go +++ b/server/util.go @@ -17,6 +17,7 @@ package server import ( "context" "net/http" + "path/filepath" "strings" "github.com/gorilla/mux" @@ -124,3 +125,17 @@ func combineBuilderServerHTTPService(ctx context.Context, svr *Server, serviceBu userHandlers[pdAPIPrefix] = apiService return userHandlers, nil } + +func isPathInDirectory(path, directory string) bool { + absPath, err := filepath.Abs(path) + if err != nil { + return false + } + + absDir, err := filepath.Abs(directory) + if err != nil { + return false + } + + return strings.HasPrefix(absPath, absDir) +}