Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assert.ErrorAs: log target type #1345

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2102,7 +2102,7 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
expectedText = target.Error()
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, false)

return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+
"expected: %q\n"+
Expand All @@ -2125,7 +2125,7 @@ func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
expectedText = target.Error()
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, false)

return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %q\n"+
Expand All @@ -2143,11 +2143,11 @@ func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{
return true
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, true)

return Fail(t, fmt.Sprintf("Should be in error chain:\n"+
"expected: %q\n"+
"in chain: %s", target, chain,
"expected: %s\n"+
"in chain: %s", reflect.ValueOf(target).Elem().Type(), chain,
), msgAndArgs...)
}

Expand All @@ -2161,24 +2161,46 @@ func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interfa
return true
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, true)

return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %q\n"+
"in chain: %s", target, chain,
"found: %s\n"+
"in chain: %s", reflect.ValueOf(target).Elem().Type(), chain,
), msgAndArgs...)
}

func buildErrorChainString(err error) string {
func unwrapAll(err error) (errs []error) {
errs = append(errs, err)
craig65535 marked this conversation as resolved.
Show resolved Hide resolved
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return
}
errs = append(errs, unwrapAll(err)...)
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
errs = append(errs, unwrapAll(err)...)
}
}
return
}

func buildErrorChainString(err error, withType bool) string {
if err == nil {
return ""
}

e := errors.Unwrap(err)
chain := fmt.Sprintf("%q", err.Error())
for e != nil {
chain += fmt.Sprintf("\n\t%q", e.Error())
e = errors.Unwrap(e)
var chain string
errs := unwrapAll(err)
craig65535 marked this conversation as resolved.
Show resolved Hide resolved
for i := range errs {
if i != 0 {
chain += "\n\t"
}
brackendawson marked this conversation as resolved.
Show resolved Hide resolved
chain += fmt.Sprintf("%q", errs[i].Error())
if withType {
chain += fmt.Sprintf(" (%T)", errs[i])
}
}
return chain
}
88 changes: 63 additions & 25 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3175,11 +3175,13 @@ func parseLabeledOutput(output string) []labeledContent {
}

type captureTestingT struct {
msg string
failed bool
msg string
}

func (ctt *captureTestingT) Errorf(format string, args ...interface{}) {
ctt.msg = fmt.Sprintf(format, args...)
ctt.failed = true
}

func (ctt *captureTestingT) checkResultAndErrMsg(t *testing.T, expectedRes, res bool, expectedErrMsg string) {
Expand All @@ -3188,6 +3190,10 @@ func (ctt *captureTestingT) checkResultAndErrMsg(t *testing.T, expectedRes, res
t.Errorf("Should return %t", expectedRes)
return
}
if res == ctt.failed {
t.Errorf("The test result (%t) should be reflected in the testing.T type (%t)", res, !ctt.failed)
return
}
contents := parseLabeledOutput(ctt.msg)
if res == true {
if contents != nil {
Expand Down Expand Up @@ -3348,50 +3354,82 @@ func TestNotErrorIs(t *testing.T) {

func TestErrorAs(t *testing.T) {
tests := []struct {
err error
result bool
err error
result bool
resultErrMsg string
}{
{fmt.Errorf("wrap: %w", &customError{}), true},
{io.EOF, false},
{nil, false},
{
err: fmt.Errorf("wrap: %w", &customError{}),
result: true,
},
{
err: io.EOF,
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: *assert.customError\n" +
"in chain: \"EOF\" (*errors.errorString)\n",
},
{
err: nil,
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: *assert.customError\n" +
"in chain: \n",
},
{
err: fmt.Errorf("abc: %w", errors.New("def")),
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: *assert.customError\n" +
"in chain: \"abc: def\" (*fmt.wrapError)\n" +
"\t\"def\" (*errors.errorString)\n",
},
}
for _, tt := range tests {
tt := tt
var target *customError
t.Run(fmt.Sprintf("ErrorAs(%#v,%#v)", tt.err, target), func(t *testing.T) {
mockT := new(testing.T)
mockT := new(captureTestingT)
res := ErrorAs(mockT, tt.err, &target)
if res != tt.result {
t.Errorf("ErrorAs(%#v,%#v) should return %t", tt.err, target, tt.result)
}
if res == mockT.Failed() {
t.Errorf("The test result (%t) should be reflected in the testing.T type (%t)", res, !mockT.Failed())
}
mockT.checkResultAndErrMsg(t, tt.result, res, tt.resultErrMsg)
})
}
}

func TestNotErrorAs(t *testing.T) {
tests := []struct {
err error
result bool
err error
result bool
resultErrMsg string
}{
{fmt.Errorf("wrap: %w", &customError{}), false},
{io.EOF, true},
{nil, true},
{
err: fmt.Errorf("wrap: %w", &customError{}),
result: false,
resultErrMsg: "" +
"Target error should not be in err chain:\n" +
"found: *assert.customError\n" +
"in chain: \"wrap: fail\" (*fmt.wrapError)\n" +
"\t\"fail\" (*assert.customError)\n",
},
{
err: io.EOF,
result: true,
},
{
err: nil,
result: true,
},
}
for _, tt := range tests {
tt := tt
var target *customError
t.Run(fmt.Sprintf("NotErrorAs(%#v,%#v)", tt.err, target), func(t *testing.T) {
mockT := new(testing.T)
mockT := new(captureTestingT)
res := NotErrorAs(mockT, tt.err, &target)
if res != tt.result {
t.Errorf("NotErrorAs(%#v,%#v) should not return %t", tt.err, target, tt.result)
}
if res == mockT.Failed() {
t.Errorf("The test result (%t) should be reflected in the testing.T type (%t)", res, !mockT.Failed())
}
mockT.checkResultAndErrMsg(t, tt.result, res, tt.resultErrMsg)
})
}
}
Loading