Skip to content

Commit

Permalink
apiutil, middleware: strengthen the robustness of GetIPPortFromHTTPRe…
Browse files Browse the repository at this point in the history
…quest function (tikv#6958)

close tikv#6957

- Improve `GetIPPortFromHTTPRequest` to ensure it could handle different host addresses.
- Make middleware set the forwarded header correctly.

Signed-off-by: JmPotato <[email protected]>
  • Loading branch information
JmPotato authored and rleungx committed Dec 1, 2023
1 parent a57ddc9 commit f7d858e
Show file tree
Hide file tree
Showing 11 changed files with 204 additions and 44 deletions.
2 changes: 1 addition & 1 deletion pkg/audit/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestLocalLogBackendUsingFile(t *testing.T) {
b, _ := os.ReadFile(fname)
output := strings.SplitN(string(b), "]", 4)
re.Equal(
fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, "+
fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, Port:, "+
"StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n",
time.Unix(info.StartTimeStamp, 0).String()),
output[3],
Expand Down
43 changes: 29 additions & 14 deletions pkg/utils/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ var (
)

const (
// PDRedirectorHeader is used to mark which PD redirected this request.
PDRedirectorHeader = "PD-Redirector"
// PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD.
PDAllowFollowerHandleHeader = "PD-Allow-follower-handle"
// XForwardedForHeader is used to mark the client IP.
XForwardedForHeader = "X-Forwarded-For"
// XForwardedPortHeader is used to mark the client port.
XForwardedPortHeader = "X-Forwarded-Port"
// XRealIPHeader is used to mark the real client IP.
XRealIPHeader = "X-Real-Ip"

// ErrRedirectFailed is the error message for redirect failed.
ErrRedirectFailed = "redirect failed"
// ErrRedirectToNotLeader is the error message for redirect to not leader.
Expand Down Expand Up @@ -101,26 +112,30 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) {
}
}

// GetIPAddrFromHTTPRequest returns http client IP from context.
// GetIPPortFromHTTPRequest returns http client host IP and port from context.
// Because `X-Forwarded-For ` header has been written into RFC 7239(Forwarded HTTP Extension),
// so `X-Forwarded-For` has the higher priority than `X-Real-IP`.
// And both of them have the higher priority than `RemoteAddr`
func GetIPAddrFromHTTPRequest(r *http.Request) string {
ips := strings.Split(r.Header.Get("X-Forwarded-For"), ",")
if len(strings.Trim(ips[0], " ")) > 0 {
return ips[0]
}

ip := r.Header.Get("X-Real-Ip")
if ip != "" {
return ip
func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) {
forwardedIPs := strings.Split(r.Header.Get(XForwardedForHeader), ",")
if forwardedIP := strings.Trim(forwardedIPs[0], " "); len(forwardedIP) > 0 {
ip = forwardedIP
// Try to get the port from "X-Forwarded-Port" header.
forwardedPorts := strings.Split(r.Header.Get(XForwardedPortHeader), ",")
if forwardedPort := strings.Trim(forwardedPorts[0], " "); len(forwardedPort) > 0 {
port = forwardedPort
}
} else if realIP := r.Header.Get(XRealIPHeader); len(realIP) > 0 {
ip = realIP
} else {
ip = r.RemoteAddr
}

ip, _, err := net.SplitHostPort(r.RemoteAddr)
splitIP, splitPort, err := net.SplitHostPort(ip)
if err != nil {
return ""
// Ensure we could get an IP address at least.
return ip, port
}
return ip
return splitIP, splitPort
}

// GetComponentNameOnHTTP returns component name from Request Header
Expand Down
139 changes: 139 additions & 0 deletions pkg/utils/apiutil/apiutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package apiutil
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"

Expand Down Expand Up @@ -68,3 +69,141 @@ func TestJsonRespondErrorBadInput(t *testing.T) {
re.Equal(400, result.StatusCode)
}
}

func TestGetIPPortFromHTTPRequest(t *testing.T) {
t.Parallel()
re := require.New(t)

testCases := []struct {
r *http.Request
ip string
port string
err error
}{
// IPv4 "X-Forwarded-For" with port
{
r: &http.Request{
Header: map[string][]string{
XForwardedForHeader: {"127.0.0.1:5299"},
},
},
ip: "127.0.0.1",
port: "5299",
},
// IPv4 "X-Forwarded-For" without port
{
r: &http.Request{
Header: map[string][]string{
XForwardedForHeader: {"127.0.0.1"},
XForwardedPortHeader: {"5299"},
},
},
ip: "127.0.0.1",
port: "5299",
},
// IPv4 "X-Real-IP" with port
{
r: &http.Request{
Header: map[string][]string{
XRealIPHeader: {"127.0.0.1:5299"},
},
},
ip: "127.0.0.1",
port: "5299",
},
// IPv4 "X-Real-IP" without port
{
r: &http.Request{
Header: map[string][]string{
XForwardedForHeader: {"127.0.0.1"},
XForwardedPortHeader: {"5299"},
},
},
ip: "127.0.0.1",
port: "5299",
},
// IPv4 RemoteAddr with port
{
r: &http.Request{
RemoteAddr: "127.0.0.1:5299",
},
ip: "127.0.0.1",
port: "5299",
},
// IPv4 RemoteAddr without port
{
r: &http.Request{
RemoteAddr: "127.0.0.1",
},
ip: "127.0.0.1",
port: "",
},
// IPv6 "X-Forwarded-For" with port
{
r: &http.Request{
Header: map[string][]string{
XForwardedForHeader: {"[::1]:5299"},
},
},
ip: "::1",
port: "5299",
},
// IPv6 "X-Forwarded-For" without port
{
r: &http.Request{
Header: map[string][]string{
XForwardedForHeader: {"::1"},
},
},
ip: "::1",
port: "",
},
// IPv6 "X-Real-IP" with port
{
r: &http.Request{
Header: map[string][]string{
XRealIPHeader: {"[::1]:5299"},
},
},
ip: "::1",
port: "5299",
},
// IPv6 "X-Real-IP" without port
{
r: &http.Request{
Header: map[string][]string{
XForwardedForHeader: {"::1"},
},
},
ip: "::1",
port: "",
},
// IPv6 RemoteAddr with port
{
r: &http.Request{
RemoteAddr: "[::1]:5299",
},
ip: "::1",
port: "5299",
},
// IPv6 RemoteAddr without port
{
r: &http.Request{
RemoteAddr: "::1",
},
ip: "::1",
port: "",
},
// Abnormal case
{
r: &http.Request{},
ip: "",
port: "",
},
}
for idx, testCase := range testCases {
ip, port := GetIPPortFromHTTPRequest(testCase.r)
re.Equal(testCase.ip, ip, "case %d", idx)
re.Equal(testCase.port, port, "case %d", idx)
}
}
24 changes: 13 additions & 11 deletions pkg/utils/apiutil/serverapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@ import (
"go.uber.org/zap"
)

// HTTP headers.
const (
PDRedirectorHeader = "PD-Redirector"
PDAllowFollowerHandle = "PD-Allow-follower-handle"
ForwardedForHeader = "X-Forwarded-For"
)

type runtimeServiceValidator struct {
s *server.Server
group apiutil.APIServiceGroup
Expand Down Expand Up @@ -130,22 +123,31 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri

func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r)
allowFollowerHandle := len(r.Header.Get(PDAllowFollowerHandle)) > 0
allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0
isLeader := h.s.GetMember().IsLeader()
if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag {
next(w, r)
return
}

// Prevent more than one redirection.
if name := r.Header.Get(PDRedirectorHeader); len(name) != 0 {
if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 {
log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect))
http.Error(w, apiutil.ErrRedirectToNotLeader, http.StatusInternalServerError)
return
}

r.Header.Set(PDRedirectorHeader, h.s.Name())
r.Header.Add(ForwardedForHeader, r.RemoteAddr)
r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name())
forwardedIP, forwardedPort := apiutil.GetIPPortFromHTTPRequest(r)
if len(forwardedIP) > 0 {
r.Header.Add(apiutil.XForwardedForHeader, forwardedIP)
} else {
// Fallback if GetIPPortFromHTTPRequest failed to get the IP.
r.Header.Add(apiutil.XForwardedForHeader, r.RemoteAddr)
}
if len(forwardedPort) > 0 {
r.Header.Add(apiutil.XForwardedPortHeader, forwardedPort)
}

var clientUrls []string
if matchedFlag {
Expand Down
9 changes: 6 additions & 3 deletions pkg/utils/requestutil/request_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,27 @@ type RequestInfo struct {
Method string
Component string
IP string
Port string
URLParam string
BodyParam string
StartTimeStamp int64
}

func (info *RequestInfo) String() string {
s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, StartTime:%s, URLParam:%s, BodyParam:%s}",
info.ServiceLabel, info.Method, info.Component, info.IP, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam)
s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}",
info.ServiceLabel, info.Method, info.Component, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam)
return s
}

// GetRequestInfo returns request info needed from http.Request
func GetRequestInfo(r *http.Request) RequestInfo {
ip, port := apiutil.GetIPPortFromHTTPRequest(r)
return RequestInfo{
ServiceLabel: apiutil.GetRouteName(r),
Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path),
Component: apiutil.GetComponentNameOnHTTP(r),
IP: apiutil.GetIPAddrFromHTTPRequest(r),
IP: ip,
Port: port,
URLParam: getURLParam(r),
BodyParam: getBodyParam(r),
StartTimeStamp: time.Now().Unix(),
Expand Down
7 changes: 3 additions & 4 deletions server/apiv2/middlewares/redirector.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/apiutil/serverapi"
"github.com/tikv/pd/server"
"go.uber.org/zap"
)
Expand All @@ -31,21 +30,21 @@ import (
func Redirector() gin.HandlerFunc {
return func(c *gin.Context) {
svr := c.MustGet(ServerContextKey).(*server.Server)
allowFollowerHandle := len(c.Request.Header.Get(serverapi.PDAllowFollowerHandle)) > 0
allowFollowerHandle := len(c.Request.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0
isLeader := svr.GetMember().IsLeader()
if !svr.IsClosed() && (allowFollowerHandle || isLeader) {
c.Next()
return
}

// Prevent more than one redirection.
if name := c.Request.Header.Get(serverapi.PDRedirectorHeader); len(name) != 0 {
if name := c.Request.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 {
log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", svr.Name()), errs.ZapError(errs.ErrRedirect))
c.AbortWithStatusJSON(http.StatusInternalServerError, errs.ErrRedirect.FastGenByArgs().Error())
return
}

c.Request.Header.Set(serverapi.PDRedirectorHeader, svr.Name())
c.Request.Header.Set(apiutil.PDRedirectorHeader, svr.Name())

leader := svr.GetMember().GetLeader()
if leader == nil {
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,7 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member,
}
url := clientUrls[0] + filepath.Join("/pd/api/v1/admin/persist-file", name)
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data))
req.Header.Set("PD-Allow-follower-handle", "true")
req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true")
res, err := s.httpClient.Do(req)
if err != nil {
log.Warn("failed to replicate file", zap.String("name", name), zap.String("member", member.GetName()), errs.ZapError(err))
Expand Down
9 changes: 5 additions & 4 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/tikv/pd/pkg/mcs/utils"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/assertutil"
"github.com/tikv/pd/pkg/utils/etcdutil"
"github.com/tikv/pd/pkg/utils/testutil"
Expand Down Expand Up @@ -218,7 +219,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderForwarded() {

req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil)
suite.NoError(err)
req.Header.Add("X-Forwarded-For", "127.0.0.2")
req.Header.Add(apiutil.XForwardedForHeader, "127.0.0.2")
resp, err := http.DefaultClient.Do(req)
suite.NoError(err)
suite.Equal(http.StatusOK, resp.StatusCode)
Expand Down Expand Up @@ -248,7 +249,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderXReal() {

req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil)
suite.NoError(err)
req.Header.Add("X-Real-Ip", "127.0.0.2")
req.Header.Add(apiutil.XRealIPHeader, "127.0.0.2")
resp, err := http.DefaultClient.Do(req)
suite.NoError(err)
suite.Equal(http.StatusOK, resp.StatusCode)
Expand Down Expand Up @@ -278,8 +279,8 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderBoth() {

req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil)
suite.NoError(err)
req.Header.Add("X-Forwarded-For", "127.0.0.2")
req.Header.Add("X-Real-Ip", "127.0.0.3")
req.Header.Add(apiutil.XForwardedForHeader, "127.0.0.2")
req.Header.Add(apiutil.XRealIPHeader, "127.0.0.3")
resp, err := http.DefaultClient.Do(req)
suite.NoError(err)
suite.Equal(http.StatusOK, resp.StatusCode)
Expand Down
2 changes: 1 addition & 1 deletion server/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func CreateMockHandler(re *require.Assertions, ip string) HandlerBuilder {
mux.HandleFunc("/pd/apis/mock/v1/hello", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello World")
// test getting ip
clientIP := apiutil.GetIPAddrFromHTTPRequest(r)
clientIP, _ := apiutil.GetIPPortFromHTTPRequest(r)
re.Equal(ip, clientIP)
})
info := apiutil.APIServiceGroup{
Expand Down
Loading

0 comments on commit f7d858e

Please sign in to comment.