Skip to content

Commit

Permalink
redo IP parsing for changed requirements (#362)
Browse files Browse the repository at this point in the history
  • Loading branch information
irshadaj authored Jan 29, 2024
1 parent fd32507 commit 2a14ff5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
25 changes: 14 additions & 11 deletions cmd/api/src/api/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,6 @@ func ContextMiddleware(next http.Handler) http.Handler {
requestCtx, cancel := context.WithTimeout(request.Context(), requestedWaitDuration.Value)
defer cancel()
// Insert the bh context
var ipAddress string
if ipAddress, err = parseUserIP(request); err != nil {
log.Errorf("requestIP not set: %v", err)
}

requestCtx = ctx.Set(requestCtx, &ctx.Context{
StartTime: startTime,
Expand All @@ -143,7 +139,7 @@ func ContextMiddleware(next http.Handler) http.Handler {
Scheme: getScheme(request),
Host: request.Host,
},
RequestIP: ipAddress,
RequestIP: parseUserIP(request),
})

// Route the request with the embedded context
Expand All @@ -152,16 +148,23 @@ func ContextMiddleware(next http.Handler) http.Handler {
})
}

func parseUserIP(r *http.Request) (string, error) {
func parseUserIP(r *http.Request) string {
res := ""
if ipAddress := r.Header.Get("X-Forwarded-For"); ipAddress != "" {
return strings.Split(ipAddress, ",")[0], nil
} else if parsedUrl, err := url.Parse(r.RemoteAddr); err != nil {
return "", fmt.Errorf("error parsing IP address from RemoteAddr: %s", err)
res += "X-Forwarded-For: " + ipAddress + "; "
} else {
log.Errorf("No data found in X-Forwarded-For")
}

if parsedUrl, err := url.Parse(r.RemoteAddr); err != nil {
log.Errorf("Error parsing IP address from RemoteAddr: %s", err)
} else if hostName := parsedUrl.Hostname(); hostName == "" {
return "", fmt.Errorf("hostname not found in URL: %s", parsedUrl.String())
log.Errorf("Hostname not found in URL: %s", parsedUrl.String())
} else {
return parsedUrl.Hostname(), nil
res += "Remote Address: " + parsedUrl.Hostname()
}

return res
}

func ParseHeaderValues(values string) map[string]string {
Expand Down
42 changes: 29 additions & 13 deletions cmd/api/src/api/middleware/middleware_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,35 +80,51 @@ func TestRequestWaitDuration(t *testing.T) {
require.True(t, requestedWaitDuration.UserSet)
}

func TestParseUserIP_XForwardedFor(t *testing.T) {
func TestParseUserIP_XForwardedForMissing(t *testing.T) {
req, err := http.NewRequest("GET", "/teapot", nil)
require.Nil(t, err)

ip1 := "192.168.1.1:8080"
ip2 := "192.168.1.2"
ip3 := "192.168.1.3"
req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ","))
req.RemoteAddr = "http://www.google.com/0.0.0.0:3000"

ip, err := parseUserIP(req)
require.Nil(t, err)
require.Equal(t, ip1, ip)
res := parseUserIP(req)
require.NotContains(t, res, "X-Forwarded-For")
require.Contains(t, res, "Remote Address")
}

func TestParseUserIP_RemoteAddrError(t *testing.T) {
req, err := http.NewRequest("GET", "/teapot", nil)
require.Nil(t, err)

ip1 := "192.168.1.1:8080"
ip2 := "192.168.1.2"
ip3 := "192.168.1.3"
req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ","))
req.RemoteAddr = "0.0.0.0:3000"

_, err = parseUserIP(req)
require.Contains(t, err.Error(), "error parsing IP address")
res := parseUserIP(req)
require.Contains(t, res, "X-Forwarded-For")
require.Contains(t, res, ip1)
require.Contains(t, res, ip2)
require.NotContains(t, res, "Remote Address")
}

func TestParseUserIP_HostnameError(t *testing.T) {
func TestParseUserIP_Success(t *testing.T) {
req, err := http.NewRequest("GET", "/teapot", nil)
require.Nil(t, err)

_, err = parseUserIP(req)
require.Contains(t, err.Error(), "hostname")
ip1 := "192.168.1.1:8080"
ip2 := "192.168.1.2"
ip3 := "192.168.1.3"
req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ","))

req.RemoteAddr = "http://www.google.com/0.0.0.0:3000"

res := parseUserIP(req)
require.Contains(t, res, "X-Forwarded-For")
require.Contains(t, res, ip1)
require.Contains(t, res, ip2)
require.Contains(t, res, ip3)
require.Contains(t, res, "Remote Address")
}

func TestParsePreferHeaderWait(t *testing.T) {
Expand Down

0 comments on commit 2a14ff5

Please sign in to comment.