From f7d858e8a25f5498d9a17e7e8333e037e3eb2e9b Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 15 Aug 2023 14:30:31 +0800 Subject: [PATCH] apiutil, middleware: strengthen the robustness of GetIPPortFromHTTPRequest function (#6958) close tikv/pd#6957 - Improve `GetIPPortFromHTTPRequest` to ensure it could handle different host addresses. - Make middleware set the forwarded header correctly. Signed-off-by: JmPotato --- pkg/audit/audit_test.go | 2 +- pkg/utils/apiutil/apiutil.go | 43 ++++--- pkg/utils/apiutil/apiutil_test.go | 139 ++++++++++++++++++++++ pkg/utils/apiutil/serverapi/middleware.go | 24 ++-- pkg/utils/requestutil/request_info.go | 9 +- server/apiv2/middlewares/redirector.go | 7 +- server/server.go | 2 +- server/server_test.go | 9 +- server/testutil.go | 2 +- tests/server/api/api_test.go | 8 +- tools/pd-ctl/pdctl/command/log_command.go | 3 +- 11 files changed, 204 insertions(+), 44 deletions(-) diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index 42d742ed243..20f8c9344f7 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -103,7 +103,7 @@ func TestLocalLogBackendUsingFile(t *testing.T) { b, _ := os.ReadFile(fname) output := strings.SplitN(string(b), "]", 4) re.Equal( - fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, "+ + fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, Port:, "+ "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", time.Unix(info.StartTimeStamp, 0).String()), output[3], diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index dce063a99f9..269a256cff3 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -47,6 +47,17 @@ var ( ) const ( + // PDRedirectorHeader is used to mark which PD redirected this request. + PDRedirectorHeader = "PD-Redirector" + // PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD. + PDAllowFollowerHandleHeader = "PD-Allow-follower-handle" + // XForwardedForHeader is used to mark the client IP. + XForwardedForHeader = "X-Forwarded-For" + // XForwardedPortHeader is used to mark the client port. + XForwardedPortHeader = "X-Forwarded-Port" + // XRealIPHeader is used to mark the real client IP. + XRealIPHeader = "X-Real-Ip" + // ErrRedirectFailed is the error message for redirect failed. ErrRedirectFailed = "redirect failed" // ErrRedirectToNotLeader is the error message for redirect to not leader. @@ -101,26 +112,30 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) { } } -// GetIPAddrFromHTTPRequest returns http client IP from context. +// GetIPPortFromHTTPRequest returns http client host IP and port from context. // Because `X-Forwarded-For ` header has been written into RFC 7239(Forwarded HTTP Extension), // so `X-Forwarded-For` has the higher priority than `X-Real-IP`. // And both of them have the higher priority than `RemoteAddr` -func GetIPAddrFromHTTPRequest(r *http.Request) string { - ips := strings.Split(r.Header.Get("X-Forwarded-For"), ",") - if len(strings.Trim(ips[0], " ")) > 0 { - return ips[0] - } - - ip := r.Header.Get("X-Real-Ip") - if ip != "" { - return ip +func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) { + forwardedIPs := strings.Split(r.Header.Get(XForwardedForHeader), ",") + if forwardedIP := strings.Trim(forwardedIPs[0], " "); len(forwardedIP) > 0 { + ip = forwardedIP + // Try to get the port from "X-Forwarded-Port" header. + forwardedPorts := strings.Split(r.Header.Get(XForwardedPortHeader), ",") + if forwardedPort := strings.Trim(forwardedPorts[0], " "); len(forwardedPort) > 0 { + port = forwardedPort + } + } else if realIP := r.Header.Get(XRealIPHeader); len(realIP) > 0 { + ip = realIP + } else { + ip = r.RemoteAddr } - - ip, _, err := net.SplitHostPort(r.RemoteAddr) + splitIP, splitPort, err := net.SplitHostPort(ip) if err != nil { - return "" + // Ensure we could get an IP address at least. + return ip, port } - return ip + return splitIP, splitPort } // GetComponentNameOnHTTP returns component name from Request Header diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index bbbb3b860fb..a4e7b97aa4d 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -17,6 +17,7 @@ package apiutil import ( "bytes" "io" + "net/http" "net/http/httptest" "testing" @@ -68,3 +69,141 @@ func TestJsonRespondErrorBadInput(t *testing.T) { re.Equal(400, result.StatusCode) } } + +func TestGetIPPortFromHTTPRequest(t *testing.T) { + t.Parallel() + re := require.New(t) + + testCases := []struct { + r *http.Request + ip string + port string + err error + }{ + // IPv4 "X-Forwarded-For" with port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"127.0.0.1:5299"}, + }, + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 "X-Forwarded-For" without port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"127.0.0.1"}, + XForwardedPortHeader: {"5299"}, + }, + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 "X-Real-IP" with port + { + r: &http.Request{ + Header: map[string][]string{ + XRealIPHeader: {"127.0.0.1:5299"}, + }, + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 "X-Real-IP" without port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"127.0.0.1"}, + XForwardedPortHeader: {"5299"}, + }, + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 RemoteAddr with port + { + r: &http.Request{ + RemoteAddr: "127.0.0.1:5299", + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 RemoteAddr without port + { + r: &http.Request{ + RemoteAddr: "127.0.0.1", + }, + ip: "127.0.0.1", + port: "", + }, + // IPv6 "X-Forwarded-For" with port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"[::1]:5299"}, + }, + }, + ip: "::1", + port: "5299", + }, + // IPv6 "X-Forwarded-For" without port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"::1"}, + }, + }, + ip: "::1", + port: "", + }, + // IPv6 "X-Real-IP" with port + { + r: &http.Request{ + Header: map[string][]string{ + XRealIPHeader: {"[::1]:5299"}, + }, + }, + ip: "::1", + port: "5299", + }, + // IPv6 "X-Real-IP" without port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"::1"}, + }, + }, + ip: "::1", + port: "", + }, + // IPv6 RemoteAddr with port + { + r: &http.Request{ + RemoteAddr: "[::1]:5299", + }, + ip: "::1", + port: "5299", + }, + // IPv6 RemoteAddr without port + { + r: &http.Request{ + RemoteAddr: "::1", + }, + ip: "::1", + port: "", + }, + // Abnormal case + { + r: &http.Request{}, + ip: "", + port: "", + }, + } + for idx, testCase := range testCases { + ip, port := GetIPPortFromHTTPRequest(testCase.r) + re.Equal(testCase.ip, ip, "case %d", idx) + re.Equal(testCase.port, port, "case %d", idx) + } +} diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 653ede75e7a..7d403ecef13 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -26,13 +26,6 @@ import ( "go.uber.org/zap" ) -// HTTP headers. -const ( - PDRedirectorHeader = "PD-Redirector" - PDAllowFollowerHandle = "PD-Allow-follower-handle" - ForwardedForHeader = "X-Forwarded-For" -) - type runtimeServiceValidator struct { s *server.Server group apiutil.APIServiceGroup @@ -130,7 +123,7 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r) - allowFollowerHandle := len(r.Header.Get(PDAllowFollowerHandle)) > 0 + allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0 isLeader := h.s.GetMember().IsLeader() if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag { next(w, r) @@ -138,14 +131,23 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http } // Prevent more than one redirection. - if name := r.Header.Get(PDRedirectorHeader); len(name) != 0 { + if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect)) http.Error(w, apiutil.ErrRedirectToNotLeader, http.StatusInternalServerError) return } - r.Header.Set(PDRedirectorHeader, h.s.Name()) - r.Header.Add(ForwardedForHeader, r.RemoteAddr) + r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name()) + forwardedIP, forwardedPort := apiutil.GetIPPortFromHTTPRequest(r) + if len(forwardedIP) > 0 { + r.Header.Add(apiutil.XForwardedForHeader, forwardedIP) + } else { + // Fallback if GetIPPortFromHTTPRequest failed to get the IP. + r.Header.Add(apiutil.XForwardedForHeader, r.RemoteAddr) + } + if len(forwardedPort) > 0 { + r.Header.Add(apiutil.XForwardedPortHeader, forwardedPort) + } var clientUrls []string if matchedFlag { diff --git a/pkg/utils/requestutil/request_info.go b/pkg/utils/requestutil/request_info.go index 73a7e299e16..40724bb790f 100644 --- a/pkg/utils/requestutil/request_info.go +++ b/pkg/utils/requestutil/request_info.go @@ -31,24 +31,27 @@ type RequestInfo struct { Method string Component string IP string + Port string URLParam string BodyParam string StartTimeStamp int64 } func (info *RequestInfo) String() string { - s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", - info.ServiceLabel, info.Method, info.Component, info.IP, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) + s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", + info.ServiceLabel, info.Method, info.Component, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) return s } // GetRequestInfo returns request info needed from http.Request func GetRequestInfo(r *http.Request) RequestInfo { + ip, port := apiutil.GetIPPortFromHTTPRequest(r) return RequestInfo{ ServiceLabel: apiutil.GetRouteName(r), Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path), Component: apiutil.GetComponentNameOnHTTP(r), - IP: apiutil.GetIPAddrFromHTTPRequest(r), + IP: ip, + Port: port, URLParam: getURLParam(r), BodyParam: getBodyParam(r), StartTimeStamp: time.Now().Unix(), diff --git a/server/apiv2/middlewares/redirector.go b/server/apiv2/middlewares/redirector.go index 5539dd089dc..285f096e823 100644 --- a/server/apiv2/middlewares/redirector.go +++ b/server/apiv2/middlewares/redirector.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/utils/apiutil" - "github.com/tikv/pd/pkg/utils/apiutil/serverapi" "github.com/tikv/pd/server" "go.uber.org/zap" ) @@ -31,7 +30,7 @@ import ( func Redirector() gin.HandlerFunc { return func(c *gin.Context) { svr := c.MustGet(ServerContextKey).(*server.Server) - allowFollowerHandle := len(c.Request.Header.Get(serverapi.PDAllowFollowerHandle)) > 0 + allowFollowerHandle := len(c.Request.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0 isLeader := svr.GetMember().IsLeader() if !svr.IsClosed() && (allowFollowerHandle || isLeader) { c.Next() @@ -39,13 +38,13 @@ func Redirector() gin.HandlerFunc { } // Prevent more than one redirection. - if name := c.Request.Header.Get(serverapi.PDRedirectorHeader); len(name) != 0 { + if name := c.Request.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", svr.Name()), errs.ZapError(errs.ErrRedirect)) c.AbortWithStatusJSON(http.StatusInternalServerError, errs.ErrRedirect.FastGenByArgs().Error()) return } - c.Request.Header.Set(serverapi.PDRedirectorHeader, svr.Name()) + c.Request.Header.Set(apiutil.PDRedirectorHeader, svr.Name()) leader := svr.GetMember().GetLeader() if leader == nil { diff --git a/server/server.go b/server/server.go index dd46b814f1d..36919805d40 100644 --- a/server/server.go +++ b/server/server.go @@ -1750,7 +1750,7 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member, } url := clientUrls[0] + filepath.Join("/pd/api/v1/admin/persist-file", name) req, _ := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) - req.Header.Set("PD-Allow-follower-handle", "true") + req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true") res, err := s.httpClient.Do(req) if err != nil { log.Warn("failed to replicate file", zap.String("name", name), zap.String("member", member.GetName()), errs.ZapError(err)) diff --git a/server/server_test.go b/server/server_test.go index 47ec2dd735c..2d0e23c7682 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/mcs/utils" + "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/assertutil" "github.com/tikv/pd/pkg/utils/etcdutil" "github.com/tikv/pd/pkg/utils/testutil" @@ -218,7 +219,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderForwarded() { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil) suite.NoError(err) - req.Header.Add("X-Forwarded-For", "127.0.0.2") + req.Header.Add(apiutil.XForwardedForHeader, "127.0.0.2") resp, err := http.DefaultClient.Do(req) suite.NoError(err) suite.Equal(http.StatusOK, resp.StatusCode) @@ -248,7 +249,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderXReal() { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil) suite.NoError(err) - req.Header.Add("X-Real-Ip", "127.0.0.2") + req.Header.Add(apiutil.XRealIPHeader, "127.0.0.2") resp, err := http.DefaultClient.Do(req) suite.NoError(err) suite.Equal(http.StatusOK, resp.StatusCode) @@ -278,8 +279,8 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderBoth() { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil) suite.NoError(err) - req.Header.Add("X-Forwarded-For", "127.0.0.2") - req.Header.Add("X-Real-Ip", "127.0.0.3") + req.Header.Add(apiutil.XForwardedForHeader, "127.0.0.2") + req.Header.Add(apiutil.XRealIPHeader, "127.0.0.3") resp, err := http.DefaultClient.Do(req) suite.NoError(err) suite.Equal(http.StatusOK, resp.StatusCode) diff --git a/server/testutil.go b/server/testutil.go index 506139e20f1..cc1a380bfb8 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -143,7 +143,7 @@ func CreateMockHandler(re *require.Assertions, ip string) HandlerBuilder { mux.HandleFunc("/pd/apis/mock/v1/hello", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") // test getting ip - clientIP := apiutil.GetIPAddrFromHTTPRequest(r) + clientIP, _ := apiutil.GetIPPortFromHTTPRequest(r) re.Equal(ip, clientIP) }) info := apiutil.APIServiceGroup{ diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index afe2baf81f3..4533073f077 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -34,7 +34,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" - "github.com/tikv/pd/pkg/utils/apiutil/serverapi" + "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/utils/typeutil" "github.com/tikv/pd/server" @@ -609,10 +609,10 @@ func (suite *redirectorTestSuite) TestAllowFollowerHandle() { addr := follower.GetAddr() + "/pd/api/v1/version" request, err := http.NewRequest(http.MethodGet, addr, nil) suite.NoError(err) - request.Header.Add(serverapi.PDAllowFollowerHandle, "true") + request.Header.Add(apiutil.PDAllowFollowerHandleHeader, "true") resp, err := dialClient.Do(request) suite.NoError(err) - suite.Equal("", resp.Header.Get(serverapi.PDRedirectorHeader)) + suite.Equal("", resp.Header.Get(apiutil.PDRedirectorHeader)) defer resp.Body.Close() suite.Equal(http.StatusOK, resp.StatusCode) _, err = io.ReadAll(resp.Body) @@ -643,7 +643,7 @@ func (suite *redirectorTestSuite) TestNotLeader() { // Request to follower with redirectorHeader will fail. request.RequestURI = "" - request.Header.Set(serverapi.PDRedirectorHeader, "pd") + request.Header.Set(apiutil.PDRedirectorHeader, "pd") resp1, err := dialClient.Do(request) suite.NoError(err) defer resp1.Body.Close() diff --git a/tools/pd-ctl/pdctl/command/log_command.go b/tools/pd-ctl/pdctl/command/log_command.go index ec22884ecad..56c4438a6c3 100644 --- a/tools/pd-ctl/pdctl/command/log_command.go +++ b/tools/pd-ctl/pdctl/command/log_command.go @@ -20,6 +20,7 @@ import ( "net/http" "github.com/spf13/cobra" + "github.com/tikv/pd/pkg/utils/apiutil" ) var ( @@ -55,7 +56,7 @@ func logCommandFunc(cmd *cobra.Command, args []string) { cmd.Printf("Failed to parse address %v: %s\n", args[1], err) return } - _, err = doRequestSingleEndpoint(cmd, url, logPrefix, http.MethodPost, http.Header{"Content-Type": {"application/json"}, "PD-Allow-follower-handle": {"true"}}, + _, err = doRequestSingleEndpoint(cmd, url, logPrefix, http.MethodPost, http.Header{"Content-Type": {"application/json"}, apiutil.PDAllowFollowerHandleHeader: {"true"}}, WithBody(bytes.NewBuffer(data))) if err != nil { cmd.Printf("Failed to set %v log level: %s\n", args[1], err)