Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-2603 [master] - Revised error handling using Go 1.13 error APIs #1474

Merged
merged 1 commit into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions benchmark/harness_case.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package benchmark

import (
"context"
"errors"
"fmt"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -95,12 +96,12 @@ benchRepeat:
res.Duration = c.elapsed
c.cumulativeRuntime += res.Duration

switch res.Error {
case context.DeadlineExceeded:
switch {
case errors.Is(res.Error, context.DeadlineExceeded):
break benchRepeat
case context.Canceled:
case errors.Is(res.Error, context.Canceled):
break benchRepeat
case nil:
case res.Error == nil:
out.Trials++
c.elapsed = 0
out.Raw = append(out.Raw, res)
Expand Down
8 changes: 4 additions & 4 deletions bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func newDefaultStructCodec() *StructCodec {
if err != nil {
// This function is called from the codec registration path, so errors can't be propagated. If there's an error
// constructing the StructCodec, we panic to avoid losing it.
panic(fmt.Errorf("error creating default StructCodec: %v", err))
panic(fmt.Errorf("error creating default StructCodec: %w", err))
}
return codec
}
Expand Down Expand Up @@ -178,7 +178,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe

for {
key, elemVr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
if errors.Is(err, bsonrw.ErrEOD) {
break
} else if err != nil {
return err
Expand Down Expand Up @@ -1379,7 +1379,7 @@ func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr bsonrw.Value
keyType := val.Type().Key()
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
if errors.Is(err, bsonrw.ErrEOD) {
break
}
if err != nil {
Expand Down Expand Up @@ -1675,7 +1675,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
idx := 0
for {
vr, err := ar.ReadValue()
if err == bsonrw.ErrEOA {
if errors.Is(err, bsonrw.ErrEOA) {
break
}
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions bson/bsoncodec/default_value_decoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2370,8 +2370,8 @@ func TestDefaultValueDecoders(t *testing.T) {
return
}
if rc.val == cansettest { // We're doing an IsValid and CanSet test
wanterr, ok := rc.err.(ValueDecoderError)
if !ok {
var wanterr ValueDecoderError
if !errors.As(rc.err, &wanterr) {
t.Fatalf("Error must be a DecodeValueError, but got a %T", rc.err)
}

Expand Down Expand Up @@ -3685,8 +3685,8 @@ func TestDefaultValueDecoders(t *testing.T) {
val := reflect.New(reflect.TypeOf(outer{})).Elem()
err := defaultTestStructCodec.DecodeValue(dc, vr, val)

decodeErr, ok := err.(*DecodeError)
assert.True(t, ok, "expected DecodeError, got %v of type %T", err, err)
var decodeErr *DecodeError
assert.True(t, errors.As(err, &decodeErr), "expected DecodeError, got %v of type %T", err, err)
expectedKeys := []string{"foo", "bar"}
assert.Equal(t, expectedKeys, decodeErr.Keys(), "expected keys slice %v, got %v", expectedKeys,
decodeErr.Keys())
Expand Down
6 changes: 3 additions & 3 deletions bson/bsoncodec/default_value_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum
}

currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key))
if lookupErr != nil && lookupErr != errInvalidValue {
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}

Expand Down Expand Up @@ -418,7 +418,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.Val

for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}

Expand Down Expand Up @@ -487,7 +487,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.Val

for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}

Expand Down
2 changes: 1 addition & 1 deletion bson/bsoncodec/map_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value,
if mc.EncodeKeysWithStringer {
parsed, err := strconv.ParseFloat(key, 64)
if err != nil {
return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err)
return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err)
}
keyVal = reflect.ValueOf(parsed)
break
Expand Down
4 changes: 2 additions & 2 deletions bson/bsoncodec/struct_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val
}

func newDecodeError(key string, original error) error {
de, ok := original.(*DecodeError)
if !ok {
var de *DecodeError
if !errors.As(original, &de) {
return &DecodeError{
keys: []string{key},
wrapped: original,
Expand Down
3 changes: 2 additions & 1 deletion bson/decoder_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bson_test

import (
"bytes"
"errors"
"fmt"
"io"

Expand Down Expand Up @@ -200,7 +201,7 @@ func ExampleDecoder_multipleExtendedJSONDocuments() {
for {
var res Coordinate
err = decoder.Decode(&res)
if err == io.EOF {
if errors.Is(err, io.EOF) {
break
}
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion bson/encoder_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bson_test

import (
"bytes"
"errors"
"fmt"
"io"

Expand Down Expand Up @@ -162,7 +163,7 @@ func ExampleEncoder_multipleBSONDocuments() {
// Extended JSON by converting them to bson.Raw.
for {
doc, err := bson.ReadDocument(buf)
if err == io.EOF {
if errors.Is(err, io.EOF) {
return
}
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions bson/primitive/objectid.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func processUniqueBytes() [5]byte {
var b [5]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err))
}

return b
Expand All @@ -193,7 +193,7 @@ func readRandomUint32() uint32 {
var b [4]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err))
}

return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
Expand Down
9 changes: 5 additions & 4 deletions cmd/testatlas/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package main

import (
"context"
"errors"
"flag"
"fmt"
"time"
Expand Down Expand Up @@ -52,7 +53,7 @@ func main() {
func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
client, err := mongo.Connect(ctx, clientOpts)
if err != nil {
return fmt.Errorf("Connect error: %v", err)
return fmt.Errorf("Connect error: %w", err)
}

defer func() {
Expand All @@ -63,12 +64,12 @@ func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
cmd := bson.D{{handshake.LegacyHello, 1}}
err = db.RunCommand(ctx, cmd).Err()
if err != nil {
return fmt.Errorf("legacy hello error: %v", err)
return fmt.Errorf("legacy hello error: %w", err)
}

coll := db.Collection("test")
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && err != mongo.ErrNoDocuments {
return fmt.Errorf("FindOne error: %v", err)
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
return fmt.Errorf("FindOne error: %w", err)
}
return nil
}
3 changes: 2 additions & 1 deletion cmd/testaws/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package main

import (
"context"
"errors"
"fmt"
"os"

Expand All @@ -33,7 +34,7 @@ func main() {

db := client.Database("aws")
coll := db.Collection("test")
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && err != mongo.ErrNoDocuments {
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
panic(fmt.Sprintf("FindOne error: %v", err))
}
}
4 changes: 3 additions & 1 deletion internal/aws/awserr/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package awserr

import (
"errors"
"fmt"
)

Expand Down Expand Up @@ -106,7 +107,8 @@ func (b baseError) OrigErr() error {
case 1:
return b.errs[0]
default:
if err, ok := b.errs[0].(Error); ok {
var err Error
if errors.As(b.errs[0], &err) {
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
}
return NewBatchError("BatchedErrors",
Expand Down
3 changes: 2 additions & 1 deletion internal/csfle/csfle.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package csfle

import (
"errors"
"fmt"

"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
Expand All @@ -23,7 +24,7 @@ func GetEncryptedStateCollectionName(efBSON bsoncore.Document, dataCollectionNam
fieldName := stateCollection + "Collection"
val, err := efBSON.LookupErr(fieldName)
if err != nil {
if err != bsoncore.ErrElementNotFound {
if !errors.Is(err, bsoncore.ErrElementNotFound) {
return "", err
}
// Return default name.
Expand Down
12 changes: 6 additions & 6 deletions mongo/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool {
ctx = context.Background()
}
doc, err := c.batch.Next()
switch err {
case nil:
switch {
case err == nil:
// Consume the next document in the current batch.
c.batchLength--
c.Current = bson.Raw(doc)
return true
case io.EOF: // Need to do a getMore
case errors.Is(err, io.EOF): // Need to do a getMore
default:
c.err = err
return false
Expand Down Expand Up @@ -204,12 +204,12 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool {
c.batch = c.bc.Batch()
c.batchLength = c.batch.DocumentCount()
doc, err = c.batch.Next()
switch err {
case nil:
switch {
case err == nil:
c.batchLength--
c.Current = bson.Raw(doc)
return true
case io.EOF: // Empty batch so we continue
case errors.Is(err, io.EOF): // Empty batch so we continue
default:
c.err = err
return false
Expand Down
21 changes: 5 additions & 16 deletions mongo/integration/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package integration
import (
"context"
"errors"
"fmt"
"io"
"net"
"testing"
Expand Down Expand Up @@ -45,18 +46,6 @@ func (n netErr) Temporary() bool {

var _ net.Error = (*netErr)(nil)

type wrappedError struct {
err error
}

func (we wrappedError) Error() string {
return we.err.Error()
}

func (we wrappedError) Unwrap() error {
return we.err
}

func TestErrors(t *testing.T) {
mt := mtest.New(t, noClientOpts)

Expand Down Expand Up @@ -478,7 +467,7 @@ func TestErrors(t *testing.T) {
},
false,
},
{"wrapped error", wrappedError{mongo.CommandError{11000, "", nil, "blah", nil, nil}}, true},
{"wrapped error", fmt.Errorf("%w", mongo.CommandError{11000, "", nil, "blah", nil, nil}), true},
{"other error type", errors.New("foo"), false},
}
for _, tc := range testCases {
Expand All @@ -499,7 +488,7 @@ func TestErrors(t *testing.T) {
}{
{"ServerError true", mongo.CommandError{100, "", []string{networkLabel}, "blah", nil, nil}, true},
{"ServerError false", mongo.CommandError{100, "", []string{otherLabel}, "blah", nil, nil}, false},
{"wrapped error", wrappedError{mongo.CommandError{100, "", []string{networkLabel}, "blah", nil, nil}}, true},
{"wrapped error", fmt.Errorf("%w", mongo.CommandError{100, "", []string{networkLabel}, "blah", nil, nil}), true},
{"other error type", errors.New("foo"), false},
}
for _, tc := range testCases {
Expand Down Expand Up @@ -533,8 +522,8 @@ func TestErrors(t *testing.T) {
{"net error true", mongo.CommandError{
100, "", []string{"other"}, "blah", netErr{true}, nil}, true},
{"net error false", netErr{false}, false},
{"wrapped error", wrappedError{mongo.CommandError{
100, "", []string{"other"}, "blah", context.DeadlineExceeded, nil}}, true},
{"wrapped error", fmt.Errorf("%w", mongo.CommandError{
100, "", []string{"other"}, "blah", context.DeadlineExceeded, nil}), true},
{"other error", errors.New("foo"), false},
}
for _, tc := range testCases {
Expand Down
4 changes: 2 additions & 2 deletions mongo/integration/mtest/global_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func ServerVersion() string {
func SetFailPoint(fp FailPoint, client *mongo.Client) error {
admin := client.Database("admin")
if err := admin.RunCommand(context.Background(), fp).Err(); err != nil {
return fmt.Errorf("error creating fail point: %v", err)
return fmt.Errorf("error creating fail point: %w", err)
}
return nil
}
Expand All @@ -89,7 +89,7 @@ func SetFailPoint(fp FailPoint, client *mongo.Client) error {
func SetRawFailPoint(fp bson.Raw, client *mongo.Client) error {
admin := client.Database("admin")
if err := admin.RunCommand(context.Background(), fp).Err(); err != nil {
return fmt.Errorf("error creating fail point: %v", err)
return fmt.Errorf("error creating fail point: %w", err)
}
return nil
}
Loading