Skip to content

Commit

Permalink
tasks 6 and 7 for #345 (#357)
Browse files Browse the repository at this point in the history
* updated IP address parsing

* added tests for IP parsing, changed IP delimiter to comma
  • Loading branch information
irshadaj authored Jan 26, 2024
1 parent e1d81bd commit 3ac13c1
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 14 deletions.
28 changes: 15 additions & 13 deletions cmd/api/src/api/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func ContextMiddleware(next http.Handler) http.Handler {
requestCtx, cancel := context.WithTimeout(request.Context(), requestedWaitDuration.Value)
defer cancel()
// Insert the bh context
var ipAddress string
if ipAddress, err = parseUserIP(request); err != nil {
log.Errorf("requestIP not set: %v", err)
}

requestCtx = ctx.Set(requestCtx, &ctx.Context{
StartTime: startTime,
Timeout: requestedWaitDuration,
Expand All @@ -138,7 +143,7 @@ func ContextMiddleware(next http.Handler) http.Handler {
Scheme: getScheme(request),
Host: request.Host,
},
RequestIP: parseUserIP(request),
RequestIP: ipAddress,
})

// Route the request with the embedded context
Expand All @@ -147,19 +152,16 @@ func ContextMiddleware(next http.Handler) http.Handler {
})
}

func parseUserIP(r *http.Request) string {
IPAddress := r.Header.Get("X-Real-Ip")
if IPAddress == "" {
IPAddress = r.Header.Get("X-Forwarded-For")
}
if IPAddress == "" {
if parsedUrl, err := url.Parse(r.RemoteAddr); err != nil {
log.Errorf("error parsing IP address from RemoteAddr: %s", err)
} else {
IPAddress = parsedUrl.Hostname()
}
func parseUserIP(r *http.Request) (string, error) {
if ipAddress := r.Header.Get("X-Forwarded-For"); ipAddress != "" {
return strings.Split(ipAddress, ",")[0], nil
} else if parsedUrl, err := url.Parse(r.RemoteAddr); err != nil {
return "", fmt.Errorf("error parsing IP address from RemoteAddr: %s", err)
} else if hostName := parsedUrl.Hostname(); hostName == "" {
return "", fmt.Errorf("hostname not found in URL: %s", parsedUrl.String())
} else {
return parsedUrl.Hostname(), nil
}
return IPAddress
}

func ParseHeaderValues(values string) map[string]string {
Expand Down
32 changes: 32 additions & 0 deletions cmd/api/src/api/middleware/middleware_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/tls"
"net/http"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -79,6 +80,37 @@ func TestRequestWaitDuration(t *testing.T) {
require.True(t, requestedWaitDuration.UserSet)
}

func TestParseUserIP_XForwardedFor(t *testing.T) {
req, err := http.NewRequest("GET", "/teapot", nil)
require.Nil(t, err)

ip1 := "192.168.1.1:8080"
ip2 := "192.168.1.2"
ip3 := "192.168.1.3"
req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ","))

ip, err := parseUserIP(req)
require.Nil(t, err)
require.Equal(t, ip1, ip)
}

func TestParseUserIP_RemoteAddrError(t *testing.T) {
req, err := http.NewRequest("GET", "/teapot", nil)
require.Nil(t, err)
req.RemoteAddr = "0.0.0.0:3000"

_, err = parseUserIP(req)
require.Contains(t, err.Error(), "error parsing IP address")
}

func TestParseUserIP_HostnameError(t *testing.T) {
req, err := http.NewRequest("GET", "/teapot", nil)
require.Nil(t, err)

_, err = parseUserIP(req)
require.Contains(t, err.Error(), "hostname")
}

func TestParsePreferHeaderWait(t *testing.T) {
_, err := parsePreferHeaderWait("wait=1.5", 30*time.Second)
require.NotNil(t, err)
Expand Down
3 changes: 2 additions & 1 deletion cmd/api/src/ctx/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ const (
func NewAuditLogFromContext(ctx Context, idResolver auth.IdentityResolver) (model.AuditLog, error) {
if ctx.AuditCtx.Model == nil {
return model.AuditLog{}, fmt.Errorf("model cannot be nil when creating a new audit log")
} else if ctx.AuditCtx.Action != model.AuditStatusFailure && ctx.AuditCtx.Action != model.AuditStatusSuccess {
return model.AuditLog{}, fmt.Errorf("invalid action specified in audit log: %s", ctx.AuditCtx.Action)
}
//TODO: Add a check for empty status to prevent nil pointer references
authContext := ctx.AuthCtx

if !authContext.Authenticated() {
Expand Down

0 comments on commit 3ac13c1

Please sign in to comment.