Skip to content

Commit

Permalink
Bed 4345: removing API request timeouts (#622)
Browse files Browse the repository at this point in the history
* removed netTimeoutSeconds

* added context deadline checks in signature verification

* log unhandled error, remove unused params from middlewareLogging
  • Loading branch information
irshadaj authored May 22, 2024
1 parent a2b26df commit 96380dd
Show file tree
Hide file tree
Showing 14 changed files with 148 additions and 67 deletions.
2 changes: 1 addition & 1 deletion cmd/api/src/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func (s authenticator) ValidateRequestSignature(tokenID uuid.UUID, request *http
}
}

if digestNow, err := NewRequestSignature(sha256.New, authToken.Key, requestDate.Format(time.RFC3339), request.Method, request.RequestURI, teeReader); err != nil {
if digestNow, err := NewRequestSignature(request.Context(), sha256.New, authToken.Key, requestDate.Format(time.RFC3339), request.Method, request.RequestURI, teeReader); err != nil {
if readCloser != nil {
readCloser.Close()
}
Expand Down
16 changes: 8 additions & 8 deletions cmd/api/src/api/auth_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func TestValidateRequestSignature(t *testing.T) {
require.NoError(t, err)

req.Header.Add(headers.RequestDate.String(), time.Now().Format(time.RFC3339))
signature, err := NewRequestSignature(sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
signature, err := NewRequestSignature(context.Background(), sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
require.NoError(t, err)
req.Header.Add(headers.Signature.String(), base64.StdEncoding.EncodeToString(signature))

Expand All @@ -235,7 +235,7 @@ func TestValidateRequestSignature(t *testing.T) {
require.NoError(t, err)

req.Header.Add(headers.RequestDate.String(), time.Now().Format(time.RFC3339))
signature, err := NewRequestSignature(sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
signature, err := NewRequestSignature(context.Background(), sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
require.NoError(t, err)
req.Header.Add(headers.Signature.String(), base64.StdEncoding.EncodeToString(signature))

Expand All @@ -260,7 +260,7 @@ func TestValidateRequestSignature(t *testing.T) {
require.NoError(t, err)

req.Header.Add(headers.RequestDate.String(), time.Now().Format(time.RFC3339))
signature, err := NewRequestSignature(sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
signature, err := NewRequestSignature(context.Background(), sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
require.NoError(t, err)
req.Header.Add(headers.Signature.String(), base64.StdEncoding.EncodeToString(signature))

Expand Down Expand Up @@ -290,7 +290,7 @@ func TestValidateRequestSignature(t *testing.T) {

badRequestDate := time.Now().Add(-1 * time.Hour).Format(time.RFC3339)
req.Header.Add(headers.RequestDate.String(), badRequestDate)
signature, err := NewRequestSignature(sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
signature, err := NewRequestSignature(context.Background(), sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
require.NoError(t, err)
req.Header.Add(headers.Signature.String(), base64.StdEncoding.EncodeToString(signature))

Expand All @@ -317,7 +317,7 @@ func TestValidateRequestSignature(t *testing.T) {

req.ContentLength = int64(len(payload))
req.Header.Add(headers.RequestDate.String(), time.Now().Format(time.RFC3339))
signature, err := NewRequestSignature(sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, bytes.NewBuffer(payload))
signature, err := NewRequestSignature(context.Background(), sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, bytes.NewBuffer(payload))
require.NoError(t, err)
req.Header.Add(headers.Signature.String(), base64.StdEncoding.EncodeToString(signature))

Expand Down Expand Up @@ -359,7 +359,7 @@ func TestValidateRequestSignature(t *testing.T) {

req.ContentLength = int64(len(payload) - 1)
req.Header.Add(headers.RequestDate.String(), time.Now().Format(time.RFC3339))
signature, err := NewRequestSignature(sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, bytes.NewBuffer(payload))
signature, err := NewRequestSignature(context.Background(), sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, bytes.NewBuffer(payload))
require.NoError(t, err)
req.Header.Add(headers.Signature.String(), base64.StdEncoding.EncodeToString(signature))

Expand Down Expand Up @@ -392,7 +392,7 @@ func TestValidateRequestSignature(t *testing.T) {

datetime := time.Now().Format(time.RFC3339)
req.Header.Add(headers.RequestDate.String(), datetime)
signature, err := NewRequestSignature(sha256.New, "badtoken", datetime, http.MethodGet, req.RequestURI, nil)
signature, err := NewRequestSignature(context.Background(), sha256.New, "badtoken", datetime, http.MethodGet, req.RequestURI, nil)
require.NoError(t, err)
req.Header.Add(headers.Signature.String(), base64.StdEncoding.EncodeToString(signature))

Expand Down Expand Up @@ -420,7 +420,7 @@ func TestValidateRequestSignature(t *testing.T) {

datetime := time.Now().Format(time.RFC3339)
req.Header.Add(headers.RequestDate.String(), datetime)
signature, err := NewRequestSignature(sha256.New, "token", datetime, req.Method, req.RequestURI, nil)
signature, err := NewRequestSignature(context.Background(), sha256.New, "token", datetime, req.Method, req.RequestURI, nil)
require.NoError(t, err)
req.Header.Add(headers.Signature.String(), base64.StdEncoding.EncodeToString(signature))

Expand Down
64 changes: 64 additions & 0 deletions cmd/api/src/api/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package api_test

import (
"context"
"crypto/sha256"
"github.com/specterops/bloodhound/headers"
"github.com/specterops/bloodhound/src/api"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"net/http"
"testing"
"time"
)

func Test_NewRequestSignature(t *testing.T) {
t.Run("returns error on context timeout", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

req, err := http.NewRequest(http.MethodGet, "http://teapotsrus.dev", nil)
require.NoError(t, err)

req.Header.Add(headers.RequestDate.String(), time.Now().Format(time.RFC3339))

goCtx, cancel := context.WithDeadline(context.Background(), time.Now())
defer cancel()
time.Sleep(1 * time.Microsecond)
_, err = api.NewRequestSignature(goCtx, sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "context deadline exceeded")
})

t.Run("returns error on empty hmac signature", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

req, err := http.NewRequest(http.MethodGet, "http://teapotsrus.dev", nil)
require.NoError(t, err)

req.Header.Add(headers.RequestDate.String(), time.Now().Format(time.RFC3339))

goCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer cancel()
_, err = api.NewRequestSignature(goCtx, nil, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "hasher must not be nil")
})

t.Run("success", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

req, err := http.NewRequest(http.MethodGet, "http://teapotsrus.dev", nil)
require.NoError(t, err)

req.Header.Add(headers.RequestDate.String(), time.Now().Format(time.RFC3339))

goCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer cancel()
signature, err := api.NewRequestSignature(goCtx, sha256.New, "token", time.Now().Format(time.RFC3339), req.Method, req.RequestURI, nil)
require.Nil(t, err)
require.NotEmpty(t, signature)
})
}
18 changes: 12 additions & 6 deletions cmd/api/src/api/middleware/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ import (
"github.com/specterops/bloodhound/log"
"github.com/specterops/bloodhound/src/api"
"github.com/specterops/bloodhound/src/auth"
"github.com/specterops/bloodhound/src/config"
"github.com/specterops/bloodhound/src/ctx"
"github.com/specterops/bloodhound/src/database"
)

// PanicHandler is a middleware func that sets up a defer-recovery trap to capture any unhandled panics that bubble
Expand Down Expand Up @@ -115,13 +113,13 @@ func setSignedRequestFields(request *http.Request, logEvent log.Event) {

// LoggingMiddleware is a middleware func that outputs a log for each request-response lifecycle. It includes timestamped
// information organized into fields suitable for searching or parsing.
func LoggingMiddleware(cfg config.Configuration, idResolver auth.IdentityResolver, db *database.BloodhoundDB) func(http.Handler) http.Handler {
func LoggingMiddleware(idResolver auth.IdentityResolver) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
var (
logEvent = log.WithLevel(log.LevelInfo)
requestContext = ctx.FromRequest(request)
deadline = time.Now().Add(time.Duration(cfg.NetTimeoutSeconds) * time.Second)
deadline time.Time

loggedResponse = &responseRecorder{
delegate: response,
Expand All @@ -133,17 +131,25 @@ func LoggingMiddleware(cfg config.Configuration, idResolver auth.IdentityResolve
}
)

// assign a deadline, but only if a valid timeout has been supplied via the prefer header
timeout, err := RequestWaitDuration(request)
if err != nil {
log.Errorf("Error parsing prefer header for timeout: %w", err)
} else if err == nil && timeout > 0 {
deadline = time.Now().Add(timeout * time.Second)
}

// Wrap the request body so that we can tell how much was read
request.Body = loggedRequestBody

// Defer the log statement and then serve the request
defer func() {
logEvent.Msgf("%s %s", request.Method, request.URL.RequestURI())

if time.Now().After(deadline) {
if !deadline.IsZero() && time.Now().After(deadline) {
log.Warnf(
"%s %s took longer than the configured timeout of %d seconds",
request.Method, request.URL.RequestURI(), cfg.NetTimeoutSeconds,
request.Method, request.URL.RequestURI(), timeout.Seconds(),
)
}
}()
Expand Down
4 changes: 2 additions & 2 deletions cmd/api/src/api/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func getScheme(request *http.Request) string {
}
}

func requestWaitDuration(request *http.Request) (time.Duration, error) {
func RequestWaitDuration(request *http.Request) (time.Duration, error) {
var (
requestedWaitDuration time.Duration
err error
Expand Down Expand Up @@ -114,7 +114,7 @@ func ContextMiddleware(next http.Handler) http.Handler {
requestID = newUUID.String()
}

if requestedWaitDuration, err := requestWaitDuration(request); err != nil {
if requestedWaitDuration, err := RequestWaitDuration(request); err != nil {
// If there is a failure or other expectation mismatch with the client, respond right away with the relevant
// error information
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, fmt.Sprintf("Prefer header has an invalid value: %v", err), request), response)
Expand Down
4 changes: 2 additions & 2 deletions cmd/api/src/api/middleware/middleware_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestRequestWaitDuration_Failure(t *testing.T) {
req.Header.Set(headers.Prefer.String(), "wait=1.5")
req.URL.RawQuery = q.Encode()

_, err = requestWaitDuration(req)
_, err = RequestWaitDuration(req)
require.NotNil(t, err)
}

Expand All @@ -74,7 +74,7 @@ func TestRequestWaitDuration(t *testing.T) {
req.Header.Set(headers.Prefer.String(), "wait=1")
req.URL.RawQuery = q.Encode()

requestedWaitDuration, err := requestWaitDuration(req)
requestedWaitDuration, err := RequestWaitDuration(req)
require.Nil(t, err)
require.Equal(t, 1*time.Second, requestedWaitDuration)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/api/src/api/registration/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func RegisterFossGlobalMiddleware(routerInst *router.Router, cfg config.Configur

// Set up logging. This must be done after ContextMiddleware is initialized so the context can be accessed in the log logic
if cfg.EnableAPILogging {
routerInst.UsePrerouting(middleware.LoggingMiddleware(cfg, identityResolver, db))
routerInst.UsePrerouting(middleware.LoggingMiddleware(identityResolver))
}

routerInst.UsePostrouting(
Expand Down
29 changes: 21 additions & 8 deletions cmd/api/src/api/signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package api

import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
Expand All @@ -35,7 +36,7 @@ const ErrorTemplateHMACSignature string = "unable to compute hmac signature: %w"

// tee takes a source reader and two writers. The function reads from the source until exhaustion. Each read is written
// serially to both writers.
func tee(reader io.Reader, outA, outB io.Writer) error {
func tee(ctx context.Context, reader io.Reader, outA, outB io.Writer) error {
// Ignore readers that are nil to begin with. This covers the case where a request is being signed but contains
// no body.
if reader == nil {
Expand All @@ -44,16 +45,18 @@ func tee(reader io.Reader, outA, outB io.Writer) error {

// Internal read buffer for splitting out to the other writers
buffer := make([]byte, 4096)
outputs := io.MultiWriter(outA, outB)

for {
read, err := reader.Read(buffer)

if read > 0 {
if _, err := outA.Write(buffer[:read]); err != nil {
return err
}
// check context after read
if err := ctx.Err(); err != nil {
return err
}

if _, err := outB.Write(buffer[:read]); err != nil {
if read > 0 {
if _, err := outputs.Write(buffer[:read]); err != nil {
return err
}
}
Expand All @@ -65,6 +68,11 @@ func tee(reader io.Reader, outA, outB io.Writer) error {

return nil
}

// check context after writes before next read
if err := ctx.Err(); err != nil {
return err
}
}
}

Expand Down Expand Up @@ -128,7 +136,7 @@ func (s *SelfDestructingTempFile) Name() string {

// NewRequestSignature generates the BloodHound request signature using the provided hash function.
// NOTE: The given io.Reader will be read to EOF. Consider using io.TeeReader so that the body may be read again after the signature has been created.
func NewRequestSignature(hasher func() hash.Hash, key string, datetime string, requestMethod string, requestURI string, body io.Reader) ([]byte, error) {
func NewRequestSignature(ctx context.Context, hasher func() hash.Hash, key string, datetime string, requestMethod string, requestURI string, body io.Reader) ([]byte, error) {
if hasher == nil {
return nil, fmt.Errorf(ErrorTemplateHMACSignature, fmt.Errorf("hasher must not be nil"))
}
Expand Down Expand Up @@ -164,6 +172,11 @@ func NewRequestSignature(hasher func() hash.Hash, key string, datetime string, r
// digester.
digester = hmac.New(hasher, digester.Sum(nil))

// check context before processing body
if err := ctx.Err(); err != nil {
return nil, err
}

if body != nil {
if _, err := io.Copy(digester, body); err != nil {
return nil, fmt.Errorf(ErrorTemplateHMACSignature, err)
Expand All @@ -186,7 +199,7 @@ func SignRequestAtTime(hasher func() hash.Hash, id string, token string, datetim
tee = io.TeeReader(request.Body, &buffer)
}

if signature, err := NewRequestSignature(hasher, token, datetimeFormatted, request.Method, request.URL.Path, tee); err != nil {
if signature, err := NewRequestSignature(request.Context(), hasher, token, datetimeFormatted, request.Method, request.URL.Path, tee); err != nil {
return err
} else {
// Overwrite the request body reader if the request body wasn't nil
Expand Down
Loading

0 comments on commit 96380dd

Please sign in to comment.