Skip to content

Commit

Permalink
Merge pull request #134 from xmidt-org/tid-validation
Browse files Browse the repository at this point in the history
added validation for transaction uuid
  • Loading branch information
schmidtw authored Sep 22, 2023
2 parents 3855006 + fc2dc49 commit 166b282
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 6 deletions.
2 changes: 1 addition & 1 deletion messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (msg *Message) From() string {
}

func (msg *Message) IsTransactionPart() bool {
return msg.Type.SupportsTransaction() && len(msg.TransactionUUID) > 0
return msg.Type.RequiresTransaction() && len(msg.TransactionUUID) > 0
}

func (msg *Message) TransactionKey() string {
Expand Down
2 changes: 1 addition & 1 deletion messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func testMessageRoutable(t *testing.T, original Message) {
assert.Equal(original.Source, original.From())
assert.Equal(original.TransactionUUID, original.TransactionKey())
assert.Equal(
original.Type.SupportsTransaction() && len(original.TransactionUUID) > 0,
original.Type.RequiresTransaction() && len(original.TransactionUUID) > 0,
original.IsTransactionPart(),
)

Expand Down
6 changes: 3 additions & 3 deletions messagetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ const (
lastMessageType
)

// SupportsTransaction tests if messages of this type are allowed to participate in transactions.
// RequiresTransaction tests if messages of this type are allowed to participate in transactions.
// If this method returns false, the TransactionUUID field should be ignored (but passed through
// where applicable).
func (mt MessageType) SupportsTransaction() bool {
// where applicable). If this method returns true, TransactionUUID must be included in request.
func (mt MessageType) RequiresTransaction() bool {
switch mt {
case SimpleRequestResponseMessageType, CreateMessageType, RetrieveMessageType, UpdateMessageType, DeleteMessageType:
return true
Expand Down
2 changes: 1 addition & 1 deletion messagetype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestMessageTypeSupportsTransaction(t *testing.T) {
)

for messageType, expected := range expectedSupportsTransaction {
assert.Equal(expected, messageType.SupportsTransaction())
assert.Equal(expected, messageType.RequiresTransaction())
}
}

Expand Down
10 changes: 10 additions & 0 deletions wrphttp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package wrphttp

import (
"fmt"
"net/http"

gokithttp "github.com/go-kit/kit/transport/http"
Expand Down Expand Up @@ -128,6 +129,15 @@ func (wh *wrpHandler) ServeHTTP(httpResponse http.ResponseWriter, httpRequest *h
return
}

if entity.Message.Type.RequiresTransaction() && entity.Message.TransactionUUID == "" {
wrappedErr := httpError{
err: fmt.Errorf("%s", string(entity.Bytes)),
code: http.StatusBadRequest,
}
wh.errorEncoder(ctx, wrappedErr, httpResponse)
return
}

for _, mf := range wh.before {
ctx = mf(ctx, &entity.Message)
}
Expand Down
54 changes: 54 additions & 0 deletions wrphttp/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package wrphttp

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -270,6 +271,58 @@ func testWRPHandlerDecodeError(t *testing.T) {
wrpHandler.AssertExpectations(t)
}

func testTransactionUUIDError(t *testing.T) {
var (
assert = assert.New(t)
require = require.New(t)

msg = wrp.Message{
Type: wrp.SimpleRequestResponseMessageType,
}
msgBytes, _ = json.Marshal(msg)
entity = &Entity{
Message: msg,
Bytes: msgBytes,
}

expectedError = httpError{
err: fmt.Errorf("%s", string(entity.Bytes)),
code: http.StatusBadRequest,
}
decoder = func(_ context.Context, _ *http.Request) (*Entity, error) {
return entity, nil
}

errorEncoderCalled = false
httpResponse = httptest.NewRecorder()
)

httpRequest := httptest.NewRequest("POST", "/", nil)

errorEncoder := func(_ context.Context, actualErr error, _ http.ResponseWriter) {
errorEncoderCalled = true
assert.Equal(expectedError, actualErr)

var actualErrorHTTP httpError
if assert.ErrorAs(actualErr, &actualErrorHTTP,
fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain",
actualErr, actualErrorHTTP)) {
assert.Equal(expectedError.code, actualErrorHTTP.code)
}
}

wrpHandler := new(MockHandler)
httpHandler := NewHTTPHandler(wrpHandler,
WithDecoder(decoder),
WithErrorEncoder(errorEncoder))

require.NotNil(httpHandler)
httpHandler.ServeHTTP(httpResponse, httpRequest)

assert.True(errorEncoderCalled)

}

func testWRPHandlerResponseWriterError(t *testing.T) {
var (
assert = assert.New(t)
Expand Down Expand Up @@ -386,5 +439,6 @@ func TestWRPHandler(t *testing.T) {
t.Run("DecodeError", testWRPHandlerDecodeError)
t.Run("ResponseWriterError", testWRPHandlerResponseWriterError)
t.Run("Success", testWRPHandlerSuccess)
t.Run("TransactionUUIDError", testTransactionUUIDError)
})
}

0 comments on commit 166b282

Please sign in to comment.