diff --git a/messages.go b/messages.go index 0a9c2f4..109fd32 100644 --- a/messages.go +++ b/messages.go @@ -194,7 +194,7 @@ func (msg *Message) From() string { } func (msg *Message) IsTransactionPart() bool { - return msg.Type.SupportsTransaction() && len(msg.TransactionUUID) > 0 + return msg.Type.RequiresTransaction() && len(msg.TransactionUUID) > 0 } func (msg *Message) TransactionKey() string { diff --git a/messages_test.go b/messages_test.go index a5f88b1..0401766 100644 --- a/messages_test.go +++ b/messages_test.go @@ -114,7 +114,7 @@ func testMessageRoutable(t *testing.T, original Message) { assert.Equal(original.Source, original.From()) assert.Equal(original.TransactionUUID, original.TransactionKey()) assert.Equal( - original.Type.SupportsTransaction() && len(original.TransactionUUID) > 0, + original.Type.RequiresTransaction() && len(original.TransactionUUID) > 0, original.IsTransactionPart(), ) diff --git a/messagetype.go b/messagetype.go index 36163ae..5feaf04 100644 --- a/messagetype.go +++ b/messagetype.go @@ -44,10 +44,10 @@ const ( lastMessageType ) -// SupportsTransaction tests if messages of this type are allowed to participate in transactions. +// RequiresTransaction tests if messages of this type are allowed to participate in transactions. // If this method returns false, the TransactionUUID field should be ignored (but passed through -// where applicable). -func (mt MessageType) SupportsTransaction() bool { +// where applicable). If this method returns true, TransactionUUID must be included in request. +func (mt MessageType) RequiresTransaction() bool { switch mt { case SimpleRequestResponseMessageType, CreateMessageType, RetrieveMessageType, UpdateMessageType, DeleteMessageType: return true diff --git a/messagetype_test.go b/messagetype_test.go index 068ed45..4090395 100644 --- a/messagetype_test.go +++ b/messagetype_test.go @@ -79,7 +79,7 @@ func TestMessageTypeSupportsTransaction(t *testing.T) { ) for messageType, expected := range expectedSupportsTransaction { - assert.Equal(expected, messageType.SupportsTransaction()) + assert.Equal(expected, messageType.RequiresTransaction()) } } diff --git a/wrphttp/handler.go b/wrphttp/handler.go index 6248a78..33b726d 100644 --- a/wrphttp/handler.go +++ b/wrphttp/handler.go @@ -18,6 +18,7 @@ package wrphttp import ( + "fmt" "net/http" gokithttp "github.com/go-kit/kit/transport/http" @@ -128,6 +129,15 @@ func (wh *wrpHandler) ServeHTTP(httpResponse http.ResponseWriter, httpRequest *h return } + if entity.Message.Type.RequiresTransaction() && entity.Message.TransactionUUID == "" { + wrappedErr := httpError{ + err: fmt.Errorf("%s", string(entity.Bytes)), + code: http.StatusBadRequest, + } + wh.errorEncoder(ctx, wrappedErr, httpResponse) + return + } + for _, mf := range wh.before { ctx = mf(ctx, &entity.Message) } diff --git a/wrphttp/handler_test.go b/wrphttp/handler_test.go index 56fd9d6..5c5a772 100644 --- a/wrphttp/handler_test.go +++ b/wrphttp/handler_test.go @@ -19,6 +19,7 @@ package wrphttp import ( "context" + "encoding/json" "errors" "fmt" "net/http" @@ -270,6 +271,58 @@ func testWRPHandlerDecodeError(t *testing.T) { wrpHandler.AssertExpectations(t) } +func testTransactionUUIDError(t *testing.T) { + var ( + assert = assert.New(t) + require = require.New(t) + + msg = wrp.Message{ + Type: wrp.SimpleRequestResponseMessageType, + } + msgBytes, _ = json.Marshal(msg) + entity = &Entity{ + Message: msg, + Bytes: msgBytes, + } + + expectedError = httpError{ + err: fmt.Errorf("%s", string(entity.Bytes)), + code: http.StatusBadRequest, + } + decoder = func(_ context.Context, _ *http.Request) (*Entity, error) { + return entity, nil + } + + errorEncoderCalled = false + httpResponse = httptest.NewRecorder() + ) + + httpRequest := httptest.NewRequest("POST", "/", nil) + + errorEncoder := func(_ context.Context, actualErr error, _ http.ResponseWriter) { + errorEncoderCalled = true + assert.Equal(expectedError, actualErr) + + var actualErrorHTTP httpError + if assert.ErrorAs(actualErr, &actualErrorHTTP, + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + actualErr, actualErrorHTTP)) { + assert.Equal(expectedError.code, actualErrorHTTP.code) + } + } + + wrpHandler := new(MockHandler) + httpHandler := NewHTTPHandler(wrpHandler, + WithDecoder(decoder), + WithErrorEncoder(errorEncoder)) + + require.NotNil(httpHandler) + httpHandler.ServeHTTP(httpResponse, httpRequest) + + assert.True(errorEncoderCalled) + +} + func testWRPHandlerResponseWriterError(t *testing.T) { var ( assert = assert.New(t) @@ -386,5 +439,6 @@ func TestWRPHandler(t *testing.T) { t.Run("DecodeError", testWRPHandlerDecodeError) t.Run("ResponseWriterError", testWRPHandlerResponseWriterError) t.Run("Success", testWRPHandlerSuccess) + t.Run("TransactionUUIDError", testTransactionUUIDError) }) }