Skip to content

Commit

Permalink
Merge pull request #96 from go-spectest/refactor-20231015
Browse files Browse the repository at this point in the history
Refactor mock
  • Loading branch information
nao1215 authored Oct 15, 2023
2 parents a3407e3 + 9590606 commit 0cbb597
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 108 deletions.
3 changes: 1 addition & 2 deletions assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ func (a DefaultVerifier) Equal(t TestingT, expected, actual interface{}, msgAndA
expected, actual, err), msgAndArgs...)
}

// For non-error values, continue with the existing comparison logic
if !objectsAreEqual(expected, actual) {
diff := diff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return a.Fail(t, fmt.Sprintf("Not equal: \n"+
"expected: %s\n"+
"actual : %s%s", expected, actual, diff), msgAndArgs...)
}

return true
}

Expand Down Expand Up @@ -121,7 +121,6 @@ func (a DefaultVerifier) NoError(t TestingT, err error, msgAndArgs ...interface{
if err != nil {
return a.Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...)
}

return true
}

Expand Down
5 changes: 0 additions & 5 deletions debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ func (d *debug) enable() {
d.enabled = true
}

// disable will disable debug logging
func (d *debug) disable() {
d.enabled = false
}

// isEnable returns true if debug logging is enabled
func (d *debug) isEnable() bool {
return d.enabled
Expand Down
67 changes: 67 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package spectest

import (
"errors"
"fmt"
"sort"
"strings"
)

var (
// ErrTimeout is an error that indicates a timeout.
ErrTimeout = errors.New("deadline exceeded")
)

// errorOrNil returns nil if the statement is true, otherwise it returns an error with the given message.
func errorOrNil(statement bool, errorMessage func() string) error {
if statement {
return nil
}
return errors.New(errorMessage())
}

// unmatchedMockError is used to store errors when a request does not match any mocks.
// It implements the error interface.
type unmatchedMockError struct {
// errors is a map of mock number to errors
errors map[int][]error
}

// newUnmatchedMockError creates a new unmatchedMockError
func newUnmatchedMockError() *unmatchedMockError {
return &unmatchedMockError{
errors: map[int][]error{},
}
}

// append adds errors to the unmatchedMockError
func (u *unmatchedMockError) append(mockNumber int, errors ...error) *unmatchedMockError {
u.errors[mockNumber] = append(u.errors[mockNumber], errors...)
return u
}

// Error implementation of in-built error human readable string function
func (u *unmatchedMockError) Error() string {
var strBuilder strings.Builder
strBuilder.WriteString("received request did not match any mocks\n\n")
for _, mockNumber := range u.orderedMockKeys() {
strBuilder.WriteString(fmt.Sprintf("Mock %d mismatches:\n", mockNumber))
for _, err := range u.errors[mockNumber] {
strBuilder.WriteString("• ")
strBuilder.WriteString(err.Error())
strBuilder.WriteString("\n")
}
strBuilder.WriteString("\n")
}
return strBuilder.String()
}

// orderedMockKeys returns the keys of the errors map in order.
func (u *unmatchedMockError) orderedMockKeys() []int {
mockKeys := make([]int, 0, len(u.errors))
for mockKey := range u.errors {
mockKeys = append(mockKeys, mockKey)
}
sort.Ints(mockKeys)
return mockKeys
}
142 changes: 53 additions & 89 deletions mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"path/filepath"
"reflect"
"regexp"
"sort"
"strings"
"sync"
"time"
Expand All @@ -22,52 +21,6 @@ import (
difflib "github.com/go-spectest/diff"
)

// unmatchedMockError is used to store errors when a request does not match any mocks.
// It implements the error interface.
type unmatchedMockError struct {
// errors is a map of mock number to errors
errors map[int][]error
}

// newUnmatchedMockError creates a new unmatchedMockError
func newUnmatchedMockError() *unmatchedMockError {
return &unmatchedMockError{
errors: map[int][]error{},
}
}

// addErrors adds errors to the unmatchedMockError
func (u *unmatchedMockError) addErrors(mockNumber int, errors ...error) *unmatchedMockError {
u.errors[mockNumber] = append(u.errors[mockNumber], errors...)
return u
}

// Error implementation of in-built error human readable string function
func (u *unmatchedMockError) Error() string {
var strBuilder strings.Builder
strBuilder.WriteString("received request did not match any mocks\n\n")
for _, mockNumber := range u.orderedMockKeys() {
strBuilder.WriteString(fmt.Sprintf("Mock %d mismatches:\n", mockNumber))
for _, err := range u.errors[mockNumber] {
strBuilder.WriteString("• ")
strBuilder.WriteString(err.Error())
strBuilder.WriteString("\n")
}
strBuilder.WriteString("\n")
}
return strBuilder.String()
}

// orderedMockKeys returns the keys of the errors map in order.
func (u *unmatchedMockError) orderedMockKeys() []int {
var mockKeys []int
for mockKey := range u.errors {
mockKeys = append(mockKeys, mockKey)
}
sort.Ints(mockKeys)
return mockKeys
}

// Transport wraps components used to observe and manipulate the real request and response objects
type Transport struct {
// httpClient is the http client used when networking is enabled.
Expand Down Expand Up @@ -115,7 +68,7 @@ func newTransport(

// RoundTrip implementation intended to match a given expected mock request
// or throw an error with a list of reasons why no match was found.
func (r *Transport) RoundTrip(req *http.Request) (mockResponse *http.Response, matchErrors error) {
func (r *Transport) RoundTrip(req *http.Request) (mockResponse *http.Response, err error) {
defer func() {
r.debug.mock(mockResponse, req)
}()
Expand All @@ -128,26 +81,24 @@ func (r *Transport) RoundTrip(req *http.Request) (mockResponse *http.Response, m
}()
}

matchedResponse, matchErrors := matches(req, r.mocks)
if matchErrors == nil {
res := buildResponseFromMock(matchedResponse)
res.Request = req

if matchedResponse.timeout {
return nil, timeoutError{}
matchedResponse, err := matches(req, r.mocks)
if err != nil {
if r.debug.isEnable() {
fmt.Printf("failed to match mocks. Errors: %s\n", err)
}
return nil, err
}

if r.mockResponseDelayEnabled && matchedResponse.fixedDelayMillis > 0 {
time.Sleep(time.Duration(matchedResponse.fixedDelayMillis) * time.Millisecond)
}
res := buildResponseFromMock(matchedResponse)
res.Request = req

return res, nil
if matchedResponse.timeout {
return nil, ErrTimeout
}

if r.debug.isEnable() {
fmt.Printf("failed to match mocks. Errors: %s\n", matchErrors)
if r.mockResponseDelayEnabled && matchedResponse.fixedDelayMillis > 0 {
time.Sleep(time.Duration(matchedResponse.fixedDelayMillis) * time.Millisecond)
}
return nil, matchErrors
return res, nil
}

// Hijack replace the transport implementation of the interaction under test in order to observe, mock and inject expectations
Expand All @@ -168,15 +119,15 @@ func (r *Transport) Reset() {
http.DefaultTransport = r.nativeTransport
}

// buildResponseFromMock builds a http.Response from a MockResponse
func buildResponseFromMock(mockResponse *MockResponse) *http.Response {
if mockResponse == nil {
return nil
}

// if the content type isn't set and the body contains json, set content type as json
contentTypeHeader := mockResponse.headers["Content-Type"]
var contentType string

// if the content type isn't set and the body contains json, set content type as json
if len(mockResponse.body) > 0 {
if len(contentTypeHeader) == 0 {
if json.Valid([]byte(mockResponse.body)) {
Expand All @@ -197,24 +148,22 @@ func buildResponseFromMock(mockResponse *MockResponse) *http.Response {
ProtoMinor: 1,
ContentLength: int64(len(mockResponse.body)),
}

for _, cookie := range mockResponse.cookies {
if v := cookie.ToHTTPCookie().String(); v != "" {
res.Header.Add("Set-Cookie", v)
}
}

if contentType != "" {
res.Header.Set("Content-Type", contentType)
}

return res
}

// Mock represents the entire interaction for a mock to be used for testing
type Mock struct {
m *sync.Mutex
isUsed bool
m *sync.Mutex
// state is mock runnig state
state *state
// request is used to configure the request of the mock
request *MockRequest
// resopnse is used to configure the response of the mock
Expand All @@ -238,8 +187,8 @@ func (m *Mock) Matches(req *http.Request) []error {
return errs
}

// copy copy Mock.
func (m *Mock) copy() *Mock {
// deepCopy deepCopy Mock.
func (m *Mock) deepCopy() *Mock {
newMock := *m

newMock.m = &sync.Mutex{}
Expand All @@ -250,6 +199,9 @@ func (m *Mock) copy() *Mock {
res := *m.response
newMock.response = &res

state := *m.state
newMock.state = &state

return &newMock
}

Expand Down Expand Up @@ -368,6 +320,7 @@ func NewMock() *Mock {
mock := &Mock{
debugStandalone: newDebug(),
m: &sync.Mutex{},
state: newState(),
execCount: newExecCount(1),
}
mock.request = newMockRequest(mock)
Expand Down Expand Up @@ -511,23 +464,24 @@ func (m *Mock) parseURL(u string) {
m.request.url = parsed
}

// matches checks whether the given request matches any of the given mocks
func matches(req *http.Request, mocks []*Mock) (*MockResponse, error) {
mockError := newUnmatchedMockError()
for mockNumber, mock := range mocks {
mock.m.Lock() // lock is for isUsed when matches is called concurrently by RoundTripper
if mock.isUsed {
if mock.state.isRunning() {
mock.m.Unlock()
continue
}

errs := mock.Matches(req)
if len(errs) == 0 {
mock.isUsed = true
mock.state.Start()
mock.m.Unlock()
return mock.response, nil
}

mockError = mockError.addErrors(mockNumber+1, errs...)
mockError = mockError.append(mockNumber+1, errs...)
mock.m.Unlock()
}

Expand Down Expand Up @@ -1242,20 +1196,6 @@ func bodyRegexpMatcher(req *http.Request, spec *MockRequest) error {
return fmt.Errorf("received body did not match expected mock body\n%s", diff(expression, bodyStr))
}

// errorOrNil returns nil if the statement is true, otherwise it returns an error with the given message.
func errorOrNil(statement bool, errorMessage func() string) error {
if statement {
return nil
}
return errors.New(errorMessage())
}

type timeoutError struct{}

func (timeoutError) Error() string { return "deadline exceeded" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }

var spewConfig = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
Expand Down Expand Up @@ -1349,3 +1289,27 @@ func (e *execCount) updateExpectCount(expect uint) {
func (e *execCount) isComplete() bool {
return e.actual == e.expect
}

// state is used to track the state of a mock. It's very simple state machine
type state struct {
running bool
}

func newState() *state {
return &state{}
}

// Start sets the state to running.
func (s *state) Start() {
s.running = true
}

// Stop sets the state to not running.
func (s *state) Stop() {
s.running = false
}

// isRunning returns true if the state is running.
func (s *state) isRunning() bool {
return s.running
}
Loading

0 comments on commit 0cbb597

Please sign in to comment.