diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index fb5d20f930e..653ede75e7a 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -30,6 +30,7 @@ import ( const ( PDRedirectorHeader = "PD-Redirector" PDAllowFollowerHandle = "PD-Allow-follower-handle" + ForwardedForHeader = "X-Forwarded-For" ) type runtimeServiceValidator struct { @@ -144,6 +145,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http } r.Header.Set(PDRedirectorHeader, h.s.Name()) + r.Header.Add(ForwardedForHeader, r.RemoteAddr) var clientUrls []string if matchedFlag { diff --git a/server/api/router.go b/server/api/router.go index f1cfd13a60d..5ec74908c0d 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -240,7 +240,7 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { srd := createStreamingRender() regionsAllHandler := newRegionsHandler(svr, srd) - registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus)) regionsHandler := newRegionsHandler(svr, rd) registerFunc(clusterRouter, "/regions/key", regionsHandler.ScanRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) @@ -289,7 +289,7 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(apiRouter, "/leader/transfer/{next_leader}", leaderHandler.TransferLeader, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) statsHandler := newStatsHandler(svr, rd) - registerFunc(clusterRouter, "/stats/region", statsHandler.GetRegionStatus, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/stats/region", statsHandler.GetRegionStatus, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus)) trendHandler := newTrendHandler(svr, rd) registerFunc(apiRouter, "/trend", trendHandler.GetTrend, setMethods(http.MethodGet), setAuditBackend(prometheus)) diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 9a717003b9f..afe2baf81f3 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" @@ -651,6 +652,32 @@ func (suite *redirectorTestSuite) TestNotLeader() { suite.NoError(err) } +func (suite *redirectorTestSuite) TestXForwardedFor() { + leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + suite.NoError(leader.BootstrapCluster()) + tempStdoutFile, _ := os.CreateTemp("/tmp", "pd_tests") + defer os.Remove(tempStdoutFile.Name()) + cfg := &log.Config{} + cfg.File.Filename = tempStdoutFile.Name() + cfg.Level = "info" + lg, p, _ := log.InitLogger(cfg) + log.ReplaceGlobals(lg, p) + + follower := suite.cluster.GetServer(suite.cluster.GetFollower()) + addr := follower.GetAddr() + "/pd/api/v1/regions" + request, err := http.NewRequest(http.MethodGet, addr, nil) + suite.NoError(err) + resp, err := dialClient.Do(request) + suite.NoError(err) + defer resp.Body.Close() + suite.Equal(http.StatusOK, resp.StatusCode) + time.Sleep(1 * time.Second) + b, _ := os.ReadFile(tempStdoutFile.Name()) + l := string(b) + suite.Contains(l, "/pd/api/v1/regions") + suite.NotContains(l, suite.cluster.GetConfig().GetClientURLs()) +} + func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header { resp, err := dialClient.Get(s.GetAddr() + "/pd/api/v1/version") re.NoError(err)