Skip to content

Commit

Permalink
api: support record the caller IP when forwarding (tikv#6622)
Browse files Browse the repository at this point in the history
close tikv#6595, close tikv#6598

Signed-off-by: Ryan Leung <[email protected]>

Co-authored-by: ShuNing <[email protected]>
  • Loading branch information
rleungx and nolouch committed Jun 20, 2023
1 parent 77083e6 commit 58d9208
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
2 changes: 2 additions & 0 deletions pkg/utils/apiutil/serverapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
const (
PDRedirectorHeader = "PD-Redirector"
PDAllowFollowerHandle = "PD-Allow-follower-handle"
ForwardedForHeader = "X-Forwarded-For"
)

type runtimeServiceValidator struct {
Expand Down Expand Up @@ -144,6 +145,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http
}

r.Header.Set(PDRedirectorHeader, h.s.Name())
r.Header.Add(ForwardedForHeader, r.RemoteAddr)

var clientUrls []string
if matchedFlag {
Expand Down
4 changes: 2 additions & 2 deletions server/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func createRouter(prefix string, svr *server.Server) *mux.Router {

srd := createStreamingRender()
regionsAllHandler := newRegionsHandler(svr, srd)
registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus))

regionsHandler := newRegionsHandler(svr, rd)
registerFunc(clusterRouter, "/regions/key", regionsHandler.ScanRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
Expand Down Expand Up @@ -288,7 +288,7 @@ func createRouter(prefix string, svr *server.Server) *mux.Router {
registerFunc(apiRouter, "/leader/transfer/{next_leader}", leaderHandler.TransferLeader, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus))

statsHandler := newStatsHandler(svr, rd)
registerFunc(clusterRouter, "/stats/region", statsHandler.GetRegionStatus, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/stats/region", statsHandler.GetRegionStatus, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus))

trendHandler := newTrendHandler(svr, rd)
registerFunc(apiRouter, "/trend", trendHandler.GetTrend, setMethods(http.MethodGet), setAuditBackend(prometheus))
Expand Down
29 changes: 27 additions & 2 deletions tests/server/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() {

func (suite *middlewareTestSuite) TestAuditLocalLogBackend() {
tempStdoutFile, _ := os.CreateTemp("/tmp", "pd_tests")
defer os.Remove(tempStdoutFile.Name())
cfg := &log.Config{}
cfg.File.Filename = tempStdoutFile.Name()
cfg.Level = "info"
Expand All @@ -471,8 +472,6 @@ func (suite *middlewareTestSuite) TestAuditLocalLogBackend() {
suite.Contains(string(b), "audit log")
suite.NoError(err)
suite.Equal(http.StatusOK, resp.StatusCode)

os.Remove(tempStdoutFile.Name())
}

func BenchmarkDoRequestWithLocalLogAudit(b *testing.B) {
Expand Down Expand Up @@ -656,6 +655,32 @@ func (suite *redirectorTestSuite) TestNotLeader() {
suite.NoError(err)
}

func (suite *redirectorTestSuite) TestXForwardedFor() {
leader := suite.cluster.GetServer(suite.cluster.GetLeader())
suite.NoError(leader.BootstrapCluster())
tempStdoutFile, _ := os.CreateTemp("/tmp", "pd_tests")
defer os.Remove(tempStdoutFile.Name())
cfg := &log.Config{}
cfg.File.Filename = tempStdoutFile.Name()
cfg.Level = "info"
lg, p, _ := log.InitLogger(cfg)
log.ReplaceGlobals(lg, p)

follower := suite.cluster.GetServer(suite.cluster.GetFollower())
addr := follower.GetAddr() + "/pd/api/v1/regions"
request, err := http.NewRequest(http.MethodGet, addr, nil)
suite.NoError(err)
resp, err := dialClient.Do(request)
suite.NoError(err)
defer resp.Body.Close()
suite.Equal(http.StatusOK, resp.StatusCode)
time.Sleep(1 * time.Second)
b, _ := os.ReadFile(tempStdoutFile.Name())
l := string(b)
suite.Contains(l, "/pd/api/v1/regions")
suite.NotContains(l, suite.cluster.GetConfig().GetClientURLs())
}

func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header {
resp, err := dialClient.Get(s.GetAddr() + "/pd/api/v1/version")
re.NoError(err)
Expand Down

0 comments on commit 58d9208

Please sign in to comment.