diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 97bf1299a5b..dcd0071fa41 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2011,7 +2011,7 @@ axes: - id: "windows-64-go-1-20" display_name: "Windows 64-bit" run_on: - - windows-vsCurrent-latest-small + - windows-vsCurrent-small variables: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" @@ -2035,7 +2035,7 @@ axes: - id: "windows-64-go-1-20" display_name: "Windows 64-bit" run_on: - - windows-vsCurrent-latest-small + - windows-vsCurrent-small variables: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" @@ -2067,7 +2067,7 @@ axes: - id: "windows-64-vsCurrent-latest-small-go-1-20" display_name: "Windows 64-bit" run_on: - - windows-vsCurrent-latest-small + - windows-vsCurrent-small variables: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" @@ -2388,23 +2388,48 @@ buildvariants: tasks: - name: "test-docker-runner" - - matrix_name: "tests-36-with-zlib-support" + - matrix_name: "tests-rhel-36-with-zlib-support" tags: ["pullrequest"] - matrix_spec: { version: ["3.6"], os-ssl-32: ["windows-64-go-1-20", "rhel87-64-go-1-20"] } + matrix_spec: { version: ["3.6"], os-ssl-32: ["rhel87-64-go-1-20"] } display_name: "${version} ${os-ssl-32}" tasks: - name: ".test !.enterprise-auth !.snappy !.zstd" - - matrix_name: "tests-40-with-zlib-support" + - matrix_name: "tests-windows-36-with-zlib-support" + matrix_spec: { version: ["3.6"], os-ssl-32: ["windows-64-go-1-20"] } + display_name: "${version} ${os-ssl-32}" + tasks: + - name: ".test !.enterprise-auth !.snappy !.zstd" + + - matrix_name: "tests-rhel-40-with-zlib-support" tags: ["pullrequest"] - matrix_spec: { version: ["4.0"], os-ssl-40: ["windows-64-go-1-20", "rhel87-64-go-1-20"] } + matrix_spec: { version: ["4.0"], os-ssl-40: ["rhel87-64-go-1-20"] } display_name: "${version} ${os-ssl-40}" tasks: - name: ".test !.enterprise-auth !.snappy !.zstd" - - matrix_name: "tests-42-plus-zlib-zstd-support" + - matrix_name: "tests-windows-40-with-zlib-support" + matrix_spec: { version: ["4.0"], os-ssl-40: ["windows-64-go-1-20"] } + display_name: "${version} ${os-ssl-40}" + tasks: + - name: ".test !.enterprise-auth !.snappy !.zstd" + + - matrix_name: "tests-rhel-42-plus-zlib-zstd-support" + tags: ["pullrequest"] + matrix_spec: { version: ["4.2", "4.4", "5.0", "6.0", "7.0", "8.0"], os-ssl-40: ["rhel87-64-go-1-20"] } + display_name: "${version} ${os-ssl-40}" + tasks: + - name: ".test !.enterprise-auth !.snappy" + + - matrix_name: "tests-windows-42-plus-zlib-zstd-support" + matrix_spec: { version: ["4.2", "4.4", "5.0", "6.0", "7.0"], os-ssl-40: ["windows-64-go-1-20"] } + display_name: "${version} ${os-ssl-40}" + tasks: + - name: ".test !.enterprise-auth !.snappy" + + - matrix_name: "tests-windows-80-zlib-zstd-support" tags: ["pullrequest"] - matrix_spec: { version: ["4.2", "4.4", "5.0", "6.0", "7.0", "8.0"], os-ssl-40: ["windows-64-go-1-20", "rhel87-64-go-1-20"] } + matrix_spec: { version: ["8.0"], os-ssl-40: ["windows-64-go-1-20"] } display_name: "${version} ${os-ssl-40}" tasks: - name: ".test !.enterprise-auth !.snappy" diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 044e1743df5..dd93c729a57 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -15,27 +15,13 @@ on: jobs: analyze: - name: Analyze (${{ matrix.language }}) - runs-on: ${{ 
(matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} - timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} + name: Analyze (go) + runs-on: 'ubuntu-latest' + timeout-minutes: 360 permissions: # required for all workflows security-events: write - # required to fetch internal or private CodeQL packs - packages: read - - # only required for workflows in private repositories - actions: read - contents: read - - strategy: - fail-fast: false - matrix: - include: - - language: go - build-mode: manual - steps: - name: Checkout repository uses: actions/checkout@v4 @@ -44,15 +30,14 @@ jobs: - name: Initialize CodeQL uses: github/codeql-action/init@v3 with: - languages: ${{ matrix.language }} - build-mode: ${{ matrix.build-mode }} + languages: go + build-mode: manual - - if: matrix.build-mode == 'manual' - shell: bash + - shell: bash run: | make build - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v3 with: - category: "/language:${{matrix.language}}" + category: "/language:go" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000000..f9a2ac10d42 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,88 @@ +name: Release + +on: + workflow_dispatch: + inputs: + version: + description: "The new version to set" + required: true + prev_version: + description: "The previous tagged version" + required: true + push_changes: + description: "Push changes?" + default: true + type: boolean + +defaults: + run: + shell: bash -eux {0} + +env: + # Changes per branch + SILK_ASSET_GROUP: mongodb-go-driver + EVERGREEN_PROJECT: mongo-go-driver + +jobs: + pre-publish: + environment: release + runs-on: ubuntu-latest + permissions: + id-token: write + contents: write + outputs: + prev_version: ${{ steps.pre-publish.outputs.prev_version }} + steps: + - uses: mongodb-labs/drivers-github-tools/secure-checkout@v2 + with: + app_id: ${{ vars.APP_ID }} + private_key: ${{ secrets.APP_PRIVATE_KEY }} + - uses: mongodb-labs/drivers-github-tools/setup@v2 + with: + aws_role_arn: ${{ secrets.AWS_ROLE_ARN }} + aws_region_name: ${{ vars.AWS_REGION_NAME }} + aws_secret_id: ${{ secrets.AWS_SECRET_ID }} + artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }} + - name: Pre Publish + id: pre-publish + uses: mongodb-labs/drivers-github-tools/golang/pre-publish@v2 + with: + version: ${{ inputs.version }} + push_changes: ${{ inputs.push_changes }} + + static-scan: + needs: [pre-publish] + permissions: + security-events: write + uses: ./.github/workflows/codeql.yml + with: + ref: ${{ github.ref }} + + publish: + needs: [pre-publish, static-scan] + runs-on: ubuntu-latest + environment: release + permissions: + id-token: write + contents: write + security-events: read + steps: + - uses: mongodb-labs/drivers-github-tools/secure-checkout@v2 + with: + app_id: ${{ vars.APP_ID }} + private_key: ${{ secrets.APP_PRIVATE_KEY }} + - uses: mongodb-labs/drivers-github-tools/setup@v2 + with: + aws_role_arn: ${{ secrets.AWS_ROLE_ARN }} + aws_region_name: ${{ vars.AWS_REGION_NAME }} + aws_secret_id: ${{ secrets.AWS_SECRET_ID }} + artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }} + - name: Publish + uses: mongodb-labs/drivers-github-tools/golang/publish@v2 + with: + version: ${{ inputs.version }} + silk_asset_group: ${{ env.SILK_ASSET_GROUP }} + evergreen_project: ${{ env.EVERGREEN_PROJECT }} + prev_version: ${{ inputs.prev_version }} + push_changes: ${{ inputs.push_changes }} + token: ${{ env.GH_TOKEN }} diff --git a/.github/workflows/test.yml 
b/.github/workflows/test.yml index ab05fafebe5..c4e5498fbf9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,5 +17,5 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - - uses: pre-commit/action@v3.0.0 + - uses: actions/setup-python@v5 + - uses: pre-commit/action@v3.0.1 diff --git a/bson/array_codec.go b/bson/array_codec.go index 5b07f4acd43..4a53d376bcc 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -12,24 +12,11 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -// ArrayCodec is the Codec used for bsoncore.Array values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ArrayCodec registered. -type ArrayCodec struct{} - -var defaultArrayCodec = NewArrayCodec() - -// NewArrayCodec returns an ArrayCodec. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ArrayCodec registered. -func NewArrayCodec() *ArrayCodec { - return &ArrayCodec{} -} +// arrayCodec is the Codec used for bsoncore.Array values. +type arrayCodec struct{} // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *ArrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreArray { return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } @@ -39,7 +26,7 @@ func (ac *ArrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.V } // DecodeValue is the ValueDecoder for bsoncore.Array values. -func (ac *ArrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreArray { return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } diff --git a/bson/bson_test.go b/bson/bson_test.go index 78fd4986c58..cb926838f0a 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -17,7 +17,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -297,6 +296,30 @@ func TestD(t *testing.T) { }) } +func TestDStringer(t *testing.T) { + got := D{{"a", 1}, {"b", 2}}.String() + want := `{"a":{"$numberInt":"1"},"b":{"$numberInt":"2"}}` + assert.Equal(t, want, got, "expected: %s, got: %s", want, got) +} + +func TestMStringer(t *testing.T) { + type msg struct { + A json.RawMessage `json:"a"` + B json.RawMessage `json:"b"` + } + + var res msg + got := M{"a": 1, "b": 2}.String() + err := json.Unmarshal([]byte(got), &res) + require.NoError(t, err, "Unmarshal error") + + want := msg{ + A: json.RawMessage(`{"$numberInt":"1"}`), + B: json.RawMessage(`{"$numberInt":"2"}`), + } + + assert.Equal(t, want, res, "returned string did not unmarshal to the expected document, returned string: %s", got) +} func TestD_MarshalJSON(t *testing.T) { t.Parallel() @@ -521,19 +544,18 @@ func TestMapCodec(t *testing.T) { strstr := stringerString("foo") mapObj := map[stringerString]int{strstr: 1} testCases := []struct { - name string - opts *bsonoptions.MapCodecOptions - key string + name string + mapCodec *mapCodec + key string }{ - {"default", 
bsonoptions.MapCodec(), "foo"}, - {"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"}, - {"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"}, + {"default", &mapCodec{}, "foo"}, + {"true", &mapCodec{encodeKeysWithStringer: true}, "bar"}, + {"false", &mapCodec{encodeKeysWithStringer: false}, "foo"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mapCodec := NewMapCodec(tc.opts) mapRegistry := NewRegistry() - mapRegistry.RegisterKindEncoder(reflect.Map, mapCodec) + mapRegistry.RegisterKindEncoder(reflect.Map, tc.mapCodec) buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoder(vw) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index ad1d4a8dedc..10d42647c80 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -165,18 +165,10 @@ type DecodeContext struct { // Deprecated: Use bson.Decoder.AllowTruncatingDoubles instead. Truncate bool - // Ancestor is the type of a containing document. This is mainly used to determine what type - // should be used when decoding an embedded document into an empty interface. For example, if - // Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface - // will be decoded into a bson.M. - // - // Deprecated: Use bson.Decoder.DefaultDocumentM or bson.Decoder.DefaultDocumentD instead. - Ancestor reflect.Type - // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is // set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an - // error. DocumentType overrides the Ancestor field. + // error. defaultDocumentType reflect.Type binaryAsSlice bool @@ -234,14 +226,6 @@ func (dc *DecodeContext) DefaultDocumentM() { dc.defaultDocumentType = reflect.TypeOf(M{}) } -// DefaultDocumentD causes the Decoder to always unmarshal documents into the D type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentD] instead. -func (dc *DecodeContext) DefaultDocumentD() { - dc.defaultDocumentType = reflect.TypeOf(D{}) -} - // ValueCodec is an interface for encoding and decoding a reflect.Value. // values. // @@ -329,13 +313,3 @@ func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader err := vd.DecodeValue(dc, vr, val) return val, err } - -// CodecZeroer is the interface implemented by Codecs that can also determine if -// a value of the type that would be encoded is zero. -// -// Deprecated: Defining custom rules for the zero/empty value will not be supported in Go Driver -// 2.0. Users who want to omit empty complex values should use a pointer field and set the value to -// nil instead. 
-type CodecZeroer interface { - IsTypeZero(interface{}) bool -} diff --git a/bson/bsoncodec_test.go b/bson/bsoncodec_test.go index d1dc21a953d..61c38933eed 100644 --- a/bson/bsoncodec_test.go +++ b/bson/bsoncodec_test.go @@ -13,7 +13,7 @@ import ( ) func ExampleValueEncoder() { - var _ ValueEncoderFunc = func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + var _ ValueEncoderFunc = func(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} } diff --git a/bson/bsonoptions/byte_slice_codec_options.go b/bson/bsonoptions/byte_slice_codec_options.go deleted file mode 100644 index 996bd17127a..00000000000 --- a/bson/bsonoptions/byte_slice_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// ByteSliceCodecOptions represents all possible options for byte slice encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type ByteSliceCodecOptions struct { - EncodeNilAsEmpty *bool // Specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. -} - -// ByteSliceCodec creates a new *ByteSliceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func ByteSliceCodec() *ByteSliceCodecOptions { - return &ByteSliceCodecOptions{} -} - -// SetEncodeNilAsEmpty specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilByteSliceAsEmpty] instead. -func (bs *ByteSliceCodecOptions) SetEncodeNilAsEmpty(b bool) *ByteSliceCodecOptions { - bs.EncodeNilAsEmpty = &b - return bs -} - -// MergeByteSliceCodecOptions combines the given *ByteSliceCodecOptions into a single *ByteSliceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeByteSliceCodecOptions(opts ...*ByteSliceCodecOptions) *ByteSliceCodecOptions { - bs := ByteSliceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeNilAsEmpty != nil { - bs.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - } - - return bs -} diff --git a/bson/bsonoptions/empty_interface_codec_options.go b/bson/bsonoptions/empty_interface_codec_options.go deleted file mode 100644 index f522c7e03fe..00000000000 --- a/bson/bsonoptions/empty_interface_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// EmptyInterfaceCodecOptions represents all possible options for interface{} encoding and decoding. 
-// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type EmptyInterfaceCodecOptions struct { - DecodeBinaryAsSlice *bool // Specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. -} - -// EmptyInterfaceCodec creates a new *EmptyInterfaceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func EmptyInterfaceCodec() *EmptyInterfaceCodecOptions { - return &EmptyInterfaceCodecOptions{} -} - -// SetDecodeBinaryAsSlice specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.BinaryAsSlice] instead. -func (e *EmptyInterfaceCodecOptions) SetDecodeBinaryAsSlice(b bool) *EmptyInterfaceCodecOptions { - e.DecodeBinaryAsSlice = &b - return e -} - -// MergeEmptyInterfaceCodecOptions combines the given *EmptyInterfaceCodecOptions into a single *EmptyInterfaceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeEmptyInterfaceCodecOptions(opts ...*EmptyInterfaceCodecOptions) *EmptyInterfaceCodecOptions { - e := EmptyInterfaceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeBinaryAsSlice != nil { - e.DecodeBinaryAsSlice = opt.DecodeBinaryAsSlice - } - } - - return e -} diff --git a/bson/bsonoptions/map_codec_options.go b/bson/bsonoptions/map_codec_options.go deleted file mode 100644 index a7a7c1d9804..00000000000 --- a/bson/bsonoptions/map_codec_options.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// MapCodecOptions represents all possible options for map encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type MapCodecOptions struct { - DecodeZerosMap *bool // Specifies if the map should be zeroed before decoding into it. Defaults to false. - EncodeNilAsEmpty *bool // Specifies if a nil map should encode as an empty document instead of null. Defaults to false. - // Specifies how keys should be handled. If false, the behavior matches encoding/json, where the encoding key type must - // either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key type must either be a - // string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with fmt.Sprint() and the - // encoding key type must be a string, an integer type, or a float. If true, the use of Stringer will override - // TextMarshaler/TextUnmarshaler. Defaults to false. - EncodeKeysWithStringer *bool -} - -// MapCodec creates a new *MapCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. 
-func MapCodec() *MapCodecOptions { - return &MapCodecOptions{} -} - -// SetDecodeZerosMap specifies if the map should be zeroed before decoding into it. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroMaps] instead. -func (t *MapCodecOptions) SetDecodeZerosMap(b bool) *MapCodecOptions { - t.DecodeZerosMap = &b - return t -} - -// SetEncodeNilAsEmpty specifies if a nil map should encode as an empty document instead of null. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilMapAsEmpty] instead. -func (t *MapCodecOptions) SetEncodeNilAsEmpty(b bool) *MapCodecOptions { - t.EncodeNilAsEmpty = &b - return t -} - -// SetEncodeKeysWithStringer specifies how keys should be handled. If false, the behavior matches encoding/json, where the -// encoding key type must either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key -// type must either be a string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with -// fmt.Sprint() and the encoding key type must be a string, an integer type, or a float. If true, the use of Stringer -// will override TextMarshaler/TextUnmarshaler. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.StringifyMapKeysWithFmt] instead. -func (t *MapCodecOptions) SetEncodeKeysWithStringer(b bool) *MapCodecOptions { - t.EncodeKeysWithStringer = &b - return t -} - -// MergeMapCodecOptions combines the given *MapCodecOptions into a single *MapCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeMapCodecOptions(opts ...*MapCodecOptions) *MapCodecOptions { - s := MapCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeZerosMap != nil { - s.DecodeZerosMap = opt.DecodeZerosMap - } - if opt.EncodeNilAsEmpty != nil { - s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - if opt.EncodeKeysWithStringer != nil { - s.EncodeKeysWithStringer = opt.EncodeKeysWithStringer - } - } - - return s -} diff --git a/bson/bsonoptions/slice_codec_options.go b/bson/bsonoptions/slice_codec_options.go deleted file mode 100644 index 3c1e4f35ba1..00000000000 --- a/bson/bsonoptions/slice_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// SliceCodecOptions represents all possible options for slice encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type SliceCodecOptions struct { - EncodeNilAsEmpty *bool // Specifies if a nil slice should encode as an empty array instead of null. Defaults to false. -} - -// SliceCodec creates a new *SliceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func SliceCodec() *SliceCodecOptions { - return &SliceCodecOptions{} -} - -// SetEncodeNilAsEmpty specifies if a nil slice should encode as an empty array instead of null. Defaults to false. 
-// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilSliceAsEmpty] instead. -func (s *SliceCodecOptions) SetEncodeNilAsEmpty(b bool) *SliceCodecOptions { - s.EncodeNilAsEmpty = &b - return s -} - -// MergeSliceCodecOptions combines the given *SliceCodecOptions into a single *SliceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeSliceCodecOptions(opts ...*SliceCodecOptions) *SliceCodecOptions { - s := SliceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeNilAsEmpty != nil { - s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - } - - return s -} diff --git a/bson/bsonoptions/string_codec_options.go b/bson/bsonoptions/string_codec_options.go deleted file mode 100644 index f8b76f996e4..00000000000 --- a/bson/bsonoptions/string_codec_options.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -var defaultDecodeOIDAsHex = true - -// StringCodecOptions represents all possible options for string encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type StringCodecOptions struct { - DecodeObjectIDAsHex *bool // Specifies if we should decode ObjectID as the hex value. Defaults to true. -} - -// StringCodec creates a new *StringCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func StringCodec() *StringCodecOptions { - return &StringCodecOptions{} -} - -// SetDecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. If false, a string made -// from the raw object ID bytes will be used. Defaults to true. -// -// Deprecated: Decoding object IDs as raw bytes will not be supported in Go Driver 2.0. -func (t *StringCodecOptions) SetDecodeObjectIDAsHex(b bool) *StringCodecOptions { - t.DecodeObjectIDAsHex = &b - return t -} - -// MergeStringCodecOptions combines the given *StringCodecOptions into a single *StringCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeStringCodecOptions(opts ...*StringCodecOptions) *StringCodecOptions { - s := &StringCodecOptions{&defaultDecodeOIDAsHex} - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeObjectIDAsHex != nil { - s.DecodeObjectIDAsHex = opt.DecodeObjectIDAsHex - } - } - - return s -} diff --git a/bson/bsonoptions/struct_codec_options.go b/bson/bsonoptions/struct_codec_options.go deleted file mode 100644 index 1cbfa32e8b4..00000000000 --- a/bson/bsonoptions/struct_codec_options.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. 
You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -var defaultOverwriteDuplicatedInlinedFields = true - -// StructCodecOptions represents all possible options for struct encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type StructCodecOptions struct { - DecodeZeroStruct *bool // Specifies if structs should be zeroed before decoding into them. Defaults to false. - DecodeDeepZeroInline *bool // Specifies if structs should be recursively zeroed when a inline value is decoded. Defaults to false. - EncodeOmitDefaultStruct *bool // Specifies if default structs should be considered empty by omitempty. Defaults to false. - AllowUnexportedFields *bool // Specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. - OverwriteDuplicatedInlinedFields *bool // Specifies if fields in inlined structs can be overwritten by higher level struct fields with the same key. Defaults to true. -} - -// StructCodec creates a new *StructCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func StructCodec() *StructCodecOptions { - return &StructCodecOptions{} -} - -// SetDecodeZeroStruct specifies if structs should be zeroed before decoding into them. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroStructs] instead. -func (t *StructCodecOptions) SetDecodeZeroStruct(b bool) *StructCodecOptions { - t.DecodeZeroStruct = &b - return t -} - -// SetDecodeDeepZeroInline specifies if structs should be zeroed before decoding into them. Defaults to false. -// -// Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0. -func (t *StructCodecOptions) SetDecodeDeepZeroInline(b bool) *StructCodecOptions { - t.DecodeDeepZeroInline = &b - return t -} - -// SetEncodeOmitDefaultStruct specifies if default structs should be considered empty by omitempty. A default struct has all -// its values set to their default value. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.OmitZeroStruct] instead. -func (t *StructCodecOptions) SetEncodeOmitDefaultStruct(b bool) *StructCodecOptions { - t.EncodeOmitDefaultStruct = &b - return t -} - -// SetOverwriteDuplicatedInlinedFields specifies if inlined struct fields can be overwritten by higher level struct fields with the -// same bson key. When true and decoding, values will be written to the outermost struct with a matching key, and when -// encoding, keys will have the value of the top-most matching field. When false, decoding and encoding will error if -// there are duplicate keys after the struct is inlined. Defaults to true. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.ErrorOnInlineDuplicates] instead. -func (t *StructCodecOptions) SetOverwriteDuplicatedInlinedFields(b bool) *StructCodecOptions { - t.OverwriteDuplicatedInlinedFields = &b - return t -} - -// SetAllowUnexportedFields specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. -// -// Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be -// supported in Go Driver 2.0. 
-func (t *StructCodecOptions) SetAllowUnexportedFields(b bool) *StructCodecOptions { - t.AllowUnexportedFields = &b - return t -} - -// MergeStructCodecOptions combines the given *StructCodecOptions into a single *StructCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeStructCodecOptions(opts ...*StructCodecOptions) *StructCodecOptions { - s := &StructCodecOptions{ - OverwriteDuplicatedInlinedFields: &defaultOverwriteDuplicatedInlinedFields, - } - for _, opt := range opts { - if opt == nil { - continue - } - - if opt.DecodeZeroStruct != nil { - s.DecodeZeroStruct = opt.DecodeZeroStruct - } - if opt.DecodeDeepZeroInline != nil { - s.DecodeDeepZeroInline = opt.DecodeDeepZeroInline - } - if opt.EncodeOmitDefaultStruct != nil { - s.EncodeOmitDefaultStruct = opt.EncodeOmitDefaultStruct - } - if opt.OverwriteDuplicatedInlinedFields != nil { - s.OverwriteDuplicatedInlinedFields = opt.OverwriteDuplicatedInlinedFields - } - if opt.AllowUnexportedFields != nil { - s.AllowUnexportedFields = opt.AllowUnexportedFields - } - } - - return s -} diff --git a/bson/bsonoptions/time_codec_options.go b/bson/bsonoptions/time_codec_options.go deleted file mode 100644 index 3f38433d226..00000000000 --- a/bson/bsonoptions/time_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// TimeCodecOptions represents all possible options for time.Time encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type TimeCodecOptions struct { - UseLocalTimeZone *bool // Specifies if we should decode into the local time zone. Defaults to false. -} - -// TimeCodec creates a new *TimeCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func TimeCodec() *TimeCodecOptions { - return &TimeCodecOptions{} -} - -// SetUseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseLocalTimeZone] instead. -func (t *TimeCodecOptions) SetUseLocalTimeZone(b bool) *TimeCodecOptions { - t.UseLocalTimeZone = &b - return t -} - -// MergeTimeCodecOptions combines the given *TimeCodecOptions into a single *TimeCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeTimeCodecOptions(opts ...*TimeCodecOptions) *TimeCodecOptions { - t := TimeCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.UseLocalTimeZone != nil { - t.UseLocalTimeZone = opt.UseLocalTimeZone - } - } - - return t -} diff --git a/bson/bsonoptions/uint_codec_options.go b/bson/bsonoptions/uint_codec_options.go deleted file mode 100644 index 5091e4d9633..00000000000 --- a/bson/bsonoptions/uint_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// UIntCodecOptions represents all possible options for uint encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type UIntCodecOptions struct { - EncodeToMinSize *bool // Specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. -} - -// UIntCodec creates a new *UIntCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func UIntCodec() *UIntCodecOptions { - return &UIntCodecOptions{} -} - -// SetEncodeToMinSize specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.IntMinSize] instead. -func (u *UIntCodecOptions) SetEncodeToMinSize(b bool) *UIntCodecOptions { - u.EncodeToMinSize = &b - return u -} - -// MergeUIntCodecOptions combines the given *UIntCodecOptions into a single *UIntCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeUIntCodecOptions(opts ...*UIntCodecOptions) *UIntCodecOptions { - u := UIntCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeToMinSize != nil { - u.EncodeToMinSize = opt.EncodeToMinSize - } - } - - return u -} diff --git a/bson/bsonrw_test.go b/bson/bsonrw_test.go index f37eb0142ed..297d1b6c0bb 100644 --- a/bson/bsonrw_test.go +++ b/bson/bsonrw_test.go @@ -12,8 +12,10 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -var _ ValueReader = (*valueReaderWriter)(nil) -var _ ValueWriter = (*valueReaderWriter)(nil) +var ( + _ ValueReader = &valueReaderWriter{} + _ ValueWriter = &valueReaderWriter{} +) // invoked is a type used to indicate what method was called last. type invoked byte diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index 586c006467e..bd44cf9a899 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -9,56 +9,32 @@ package bson import ( "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// ByteSliceCodec is the Codec used for []byte values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ByteSliceCodec registered. -type ByteSliceCodec struct { - // EncodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values +// byteSliceCodec is the Codec used for []byte values. +type byteSliceCodec struct { + // encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values // instead of BSON null. - // - // Deprecated: Use bson.Encoder.NilByteSliceAsEmpty instead. - EncodeNilAsEmpty bool + encodeNilAsEmpty bool } -var ( - defaultByteSliceCodec = NewByteSliceCodec() - - // Assert that defaultByteSliceCodec satisfies the typeDecoder interface, which allows it to be - // used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. - _ typeDecoder = defaultByteSliceCodec -) - -// NewByteSliceCodec returns a ByteSliceCodec with options opts. 
-// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ByteSliceCodec registered. -func NewByteSliceCodec(opts ...*bsonoptions.ByteSliceCodecOptions) *ByteSliceCodec { - byteSliceOpt := bsonoptions.MergeByteSliceCodecOptions(opts...) - codec := ByteSliceCodec{} - if byteSliceOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *byteSliceOpt.EncodeNilAsEmpty - } - return &codec -} +// Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be +// used by collection type decoders (e.g. map, slice, etc) to set individual values in a +// collection. +var _ typeDecoder = &byteSliceCodec{} // EncodeValue is the ValueEncoder for []byte. -func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tByteSlice { return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } - if val.IsNil() && !bsc.EncodeNilAsEmpty && !ec.nilByteSliceAsEmpty { + if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { return vw.WriteNull() } return vw.WriteBinary(val.Interface().([]byte)) } -func (bsc *ByteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tByteSlice { return emptyValue, ValueDecoderError{ Name: "ByteSliceDecodeValue", @@ -106,7 +82,7 @@ func (bsc *ByteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect } // DecodeValue is the ValueDecoder for []byte. -func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (bsc *byteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tByteSlice { return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index fba139ff075..012b2d825cd 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -16,7 +16,7 @@ type condAddrEncoder struct { elseEnc ValueEncoder } -var _ ValueEncoder = (*condAddrEncoder)(nil) +var _ ValueEncoder = &condAddrEncoder{} // newCondAddrEncoder returns an condAddrEncoder. func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder { @@ -41,7 +41,7 @@ type condAddrDecoder struct { elseDec ValueDecoder } -var _ ValueDecoder = (*condAddrDecoder)(nil) +var _ ValueDecoder = &condAddrDecoder{} // newCondAddrDecoder returns an CondAddrDecoder. func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder { diff --git a/bson/copier.go b/bson/copier.go index abdd7162e48..07ebc744b54 100644 --- a/bson/copier.go +++ b/bson/copier.go @@ -162,8 +162,8 @@ func copyDocumentToBytes(src ValueReader) ([]byte, error) { // Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be // supported in Go Driver 2.0. 
func appendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) { - if br, ok := src.(BytesReader); ok { - _, dst, err := br.ReadValueBytes(dst) + if br, ok := src.(bytesReader); ok { + _, dst, err := br.readValueBytes(dst) return dst, err } @@ -182,8 +182,8 @@ func appendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) { // Deprecated: Copying BSON arrays using the ValueWriter and ValueReader interfaces will not be // supported in Go Driver 2.0. func appendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { - if br, ok := src.(BytesReader); ok { - _, dst, err := br.ReadValueBytes(dst) + if br, ok := src.(bytesReader); ok { + _, dst, err := br.readValueBytes(dst) return dst, err } @@ -201,8 +201,8 @@ func appendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { // // Deprecated: Use [go.mongodb.org/mongo-driver/bson.UnmarshalValue] instead. func copyValueFromBytes(dst ValueWriter, t Type, src []byte) error { - if wvb, ok := dst.(BytesWriter); ok { - return wvb.WriteValueBytes(t, src) + if wvb, ok := dst.(bytesWriter); ok { + return wvb.writeValueBytes(t, src) } vr := vrPool.Get().(*valueReader) @@ -228,8 +228,8 @@ func CopyValueToBytes(src ValueReader) (Type, []byte, error) { // Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go // Driver 2.0. func appendValueBytes(dst []byte, src ValueReader) (Type, []byte, error) { - if br, ok := src.(BytesReader); ok { - return br.ReadValueBytes(dst) + if br, ok := src.(bytesReader); ok { + return br.readValueBytes(dst) } vw := vwPool.Get().(*valueWriter) diff --git a/bson/decoder.go b/bson/decoder.go index 6ea5ad97c13..1276987496d 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -31,10 +31,7 @@ type Decoder struct { dc DecodeContext vr ValueReader - // We persist defaultDocumentM and defaultDocumentD on the Decoder to prevent overwriting from - // (*Decoder).SetContext. defaultDocumentM bool - defaultDocumentD bool binaryAsSlice bool useJSONStructTags bool @@ -87,9 +84,6 @@ func (d *Decoder) Decode(val interface{}) error { if d.defaultDocumentM { d.dc.DefaultDocumentM() } - if d.defaultDocumentD { - d.dc.DefaultDocumentD() - } if d.binaryAsSlice { d.dc.BinaryAsSlice() } @@ -126,12 +120,6 @@ func (d *Decoder) DefaultDocumentM() { d.defaultDocumentM = true } -// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -func (d *Decoder) DefaultDocumentD() { - d.defaultDocumentD = true -} - // AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values // when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct // field. The truncation logic does not apply to BSON "decimal128" values. 
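A note on the decoder.go change above: with Decoder.DefaultDocumentD removed, the updated tests in this diff indicate that bson.D is now the built-in default for BSON documents decoded into interface{} values, while DefaultDocumentM remains the opt-in for bson.M. The sketch below is only an illustration of that behavior, not part of the patch; it assumes the package API exactly as it appears in this diff (NewDecoder, NewValueReader, DefaultDocumentM, Decode), uses the long-standing bson.Marshal and bson.Unmarshal helpers, and the sample document contents are invented.

package main

import (
	"fmt"

	"go.mongodb.org/mongo-driver/bson"
)

func main() {
	// Build raw BSON bytes for a small, made-up document.
	data, err := bson.Marshal(bson.D{{Key: "name", Value: "New York"}, {Key: "state", Value: "NY"}})
	if err != nil {
		panic(err)
	}

	// Default behavior after this change: documents decoded into interface{}
	// come back as bson.D, with no DefaultDocumentD call needed (or available).
	var asD interface{}
	if err := bson.Unmarshal(data, &asD); err != nil {
		panic(err)
	}
	fmt.Printf("%T\n", asD)

	// Opting in to bson.M is still done on the Decoder, as in the tests in this diff.
	dec := bson.NewDecoder(bson.NewValueReader(data))
	dec.DefaultDocumentM()

	var asM interface{}
	if err := dec.Decode(&asM); err != nil {
		panic(err)
	}
	fmt.Printf("%T\n", asM)
}

bson.D preserves key order while bson.M does not, which is the usual reason to keep D as the default and leave M as an explicit opt-in.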
diff --git a/bson/decoder_example_test.go b/bson/decoder_example_test.go index 3e17e989278..f87a107b0b2 100644 --- a/bson/decoder_example_test.go +++ b/bson/decoder_example_test.go @@ -8,6 +8,7 @@ package bson_test import ( "bytes" + "encoding/json" "errors" "fmt" "io" @@ -85,8 +86,12 @@ func ExampleDecoder_DefaultDocumentM() { panic(err) } - fmt.Printf("%+v\n", res) - // Output: {Name:New York Properties:map[elevation:10 population:8804190 state:NY]} + data, err = json.Marshal(res) + if err != nil { + panic(err) + } + fmt.Printf("%+v\n", string(data)) + // Output: {"Name":"New York","Properties":{"elevation":10,"population":8804190,"state":"NY"}} } func ExampleDecoder_UseJSONStructTags() { diff --git a/bson/decoder_test.go b/bson/decoder_test.go index dbef3e7fb00..cfb51d64336 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -478,13 +478,11 @@ func TestDecoderConfiguration(t *testing.T) { decodeInto: func() interface{} { return &D{} }, want: &D{{Key: "myBinary", Value: []byte{}}}, }, - // Test that DefaultDocumentD overrides the default "ancestor" logic and always decodes BSON - // documents into bson.D values, independent of the top-level Go value type. + // Test that the default decoder always decodes BSON documents into bson.D values, + // independent of the top-level Go value type. { - description: "DefaultDocumentD nested", - configure: func(dec *Decoder) { - dec.DefaultDocumentD() - }, + description: "DocumentD nested by default", + configure: func(dec *Decoder) {}, input: bsoncore.NewDocumentBuilder(). AppendDocument("myDocument", bsoncore.NewDocumentBuilder(). AppendString("myString", "test value"). @@ -495,8 +493,8 @@ func TestDecoderConfiguration(t *testing.T) { "myDocument": D{{Key: "myString", Value: "test value"}}, }, }, - // Test that DefaultDocumentM overrides the default "ancestor" logic and always decodes BSON - // documents into bson.M values, independent of the top-level Go value type. + // Test that DefaultDocumentM always decodes BSON documents into bson.M values, + // independent of the top-level Go value type. { description: "DefaultDocumentM nested", configure: func(dec *Decoder) { @@ -614,7 +612,7 @@ func TestDecoderConfiguration(t *testing.T) { } assert.Equal(t, want, got, "expected and actual decode results do not match") }) - t.Run("DefaultDocumentD top-level", func(t *testing.T) { + t.Run("Default decodes DocumentD for top-level", func(t *testing.T) { t.Parallel() input := bsoncore.NewDocumentBuilder(). 
@@ -625,8 +623,6 @@ func TestDecoderConfiguration(t *testing.T) { dec := NewDecoder(NewValueReader(input)) - dec.DefaultDocumentD() - var got interface{} err := dec.Decode(&got) require.NoError(t, err, "Decode error") diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 3256f92089d..8f95ea2485d 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -18,10 +18,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -var ( - defaultValueDecoders DefaultValueDecoders - errCannotTruncate = errors.New("float64 can only be truncated to a lower precision type when truncation is enabled") -) +var errCannotTruncate = errors.New("float64 can only be truncated to a lower precision type when truncation is enabled") type decodeBinaryError struct { subtype byte @@ -32,119 +29,92 @@ func (d decodeBinaryError) Error() string { return fmt.Sprintf("only binary values with subtype 0x00 or 0x02 can be decoded into %s, but got subtype %v", d.typeName, d.subtype) } -func newDefaultStructCodec() *StructCodec { - codec, err := NewStructCodec(DefaultStructTagParser) - if err != nil { - // This function is called from the codec registration path, so errors can't be propagated. If there's an error - // constructing the StructCodec, we panic to avoid losing it. - panic(fmt.Errorf("error creating default StructCodec: %w", err)) - } - return codec -} - -// DefaultValueDecoders is a namespace type for the default ValueDecoders used -// when creating a registry. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -type DefaultValueDecoders struct{} - -// RegisterDefaultDecoders will register the decoder methods attached to DefaultValueDecoders with +// registerDefaultDecoders will register the decoder methods attached to DefaultValueDecoders with // the provided RegistryBuilder. // // There is no support for decoding map[string]interface{} because there is no decoder for // interface{}, so users must either register this decoder themselves or use the // EmptyInterfaceDecoder available in the bson package. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { - if rb == nil { - panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) - } - - intDecoder := decodeAdapter{dvd.IntDecodeValue, dvd.intDecodeType} - floatDecoder := decodeAdapter{dvd.FloatDecodeValue, dvd.floatDecodeType} - - rb. - RegisterTypeDecoder(tD, ValueDecoderFunc(dvd.DDecodeValue)). - RegisterTypeDecoder(tBinary, decodeAdapter{dvd.BinaryDecodeValue, dvd.binaryDecodeType}). - RegisterTypeDecoder(tUndefined, decodeAdapter{dvd.UndefinedDecodeValue, dvd.undefinedDecodeType}). - RegisterTypeDecoder(tDateTime, decodeAdapter{dvd.DateTimeDecodeValue, dvd.dateTimeDecodeType}). - RegisterTypeDecoder(tNull, decodeAdapter{dvd.NullDecodeValue, dvd.nullDecodeType}). - RegisterTypeDecoder(tRegex, decodeAdapter{dvd.RegexDecodeValue, dvd.regexDecodeType}). - RegisterTypeDecoder(tDBPointer, decodeAdapter{dvd.DBPointerDecodeValue, dvd.dBPointerDecodeType}). - RegisterTypeDecoder(tTimestamp, decodeAdapter{dvd.TimestampDecodeValue, dvd.timestampDecodeType}). - RegisterTypeDecoder(tMinKey, decodeAdapter{dvd.MinKeyDecodeValue, dvd.minKeyDecodeType}). - RegisterTypeDecoder(tMaxKey, decodeAdapter{dvd.MaxKeyDecodeValue, dvd.maxKeyDecodeType}). 
- RegisterTypeDecoder(tJavaScript, decodeAdapter{dvd.JavaScriptDecodeValue, dvd.javaScriptDecodeType}). - RegisterTypeDecoder(tSymbol, decodeAdapter{dvd.SymbolDecodeValue, dvd.symbolDecodeType}). - RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec). - RegisterTypeDecoder(tTime, defaultTimeCodec). - RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec). - RegisterTypeDecoder(tCoreArray, defaultArrayCodec). - RegisterTypeDecoder(tOID, decodeAdapter{dvd.ObjectIDDecodeValue, dvd.objectIDDecodeType}). - RegisterTypeDecoder(tDecimal, decodeAdapter{dvd.Decimal128DecodeValue, dvd.decimal128DecodeType}). - RegisterTypeDecoder(tJSONNumber, decodeAdapter{dvd.JSONNumberDecodeValue, dvd.jsonNumberDecodeType}). - RegisterTypeDecoder(tURL, decodeAdapter{dvd.URLDecodeValue, dvd.urlDecodeType}). - RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(dvd.CoreDocumentDecodeValue)). - RegisterTypeDecoder(tCodeWithScope, decodeAdapter{dvd.CodeWithScopeDecodeValue, dvd.codeWithScopeDecodeType}). - RegisterDefaultDecoder(reflect.Bool, decodeAdapter{dvd.BooleanDecodeValue, dvd.booleanDecodeType}). - RegisterDefaultDecoder(reflect.Int, intDecoder). - RegisterDefaultDecoder(reflect.Int8, intDecoder). - RegisterDefaultDecoder(reflect.Int16, intDecoder). - RegisterDefaultDecoder(reflect.Int32, intDecoder). - RegisterDefaultDecoder(reflect.Int64, intDecoder). - RegisterDefaultDecoder(reflect.Uint, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint8, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint16, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint32, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint64, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Float32, floatDecoder). - RegisterDefaultDecoder(reflect.Float64, floatDecoder). - RegisterDefaultDecoder(reflect.Array, ValueDecoderFunc(dvd.ArrayDecodeValue)). - RegisterDefaultDecoder(reflect.Map, defaultMapCodec). - RegisterDefaultDecoder(reflect.Slice, defaultSliceCodec). - RegisterDefaultDecoder(reflect.String, defaultStringCodec). - RegisterDefaultDecoder(reflect.Struct, newDefaultStructCodec()). - RegisterDefaultDecoder(reflect.Ptr, NewPointerCodec()). - RegisterTypeMapEntry(TypeDouble, tFloat64). - RegisterTypeMapEntry(TypeString, tString). - RegisterTypeMapEntry(TypeArray, tA). - RegisterTypeMapEntry(TypeBinary, tBinary). - RegisterTypeMapEntry(TypeUndefined, tUndefined). - RegisterTypeMapEntry(TypeObjectID, tOID). - RegisterTypeMapEntry(TypeBoolean, tBool). - RegisterTypeMapEntry(TypeDateTime, tDateTime). - RegisterTypeMapEntry(TypeRegex, tRegex). - RegisterTypeMapEntry(TypeDBPointer, tDBPointer). - RegisterTypeMapEntry(TypeJavaScript, tJavaScript). - RegisterTypeMapEntry(TypeSymbol, tSymbol). - RegisterTypeMapEntry(TypeCodeWithScope, tCodeWithScope). - RegisterTypeMapEntry(TypeInt32, tInt32). - RegisterTypeMapEntry(TypeInt64, tInt64). - RegisterTypeMapEntry(TypeTimestamp, tTimestamp). - RegisterTypeMapEntry(TypeDecimal128, tDecimal). - RegisterTypeMapEntry(TypeMinKey, tMinKey). - RegisterTypeMapEntry(TypeMaxKey, tMaxKey). - RegisterTypeMapEntry(Type(0), tD). - RegisterTypeMapEntry(TypeEmbeddedDocument, tD). - RegisterHookDecoder(tValueUnmarshaler, ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue)). 
- RegisterHookDecoder(tUnmarshaler, ValueDecoderFunc(dvd.UnmarshalerDecodeValue)) +func registerDefaultDecoders(reg *Registry) { + intDecoder := decodeAdapter{intDecodeValue, intDecodeType} + floatDecoder := decodeAdapter{floatDecodeValue, floatDecodeType} + uintCodec := &uintCodec{} + + reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue)) + reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType}) + reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType}) + reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType}) + reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType}) + reg.RegisterTypeDecoder(tRegex, decodeAdapter{regexDecodeValue, regexDecodeType}) + reg.RegisterTypeDecoder(tDBPointer, decodeAdapter{dbPointerDecodeValue, dbPointerDecodeType}) + reg.RegisterTypeDecoder(tTimestamp, decodeAdapter{timestampDecodeValue, timestampDecodeType}) + reg.RegisterTypeDecoder(tMinKey, decodeAdapter{minKeyDecodeValue, minKeyDecodeType}) + reg.RegisterTypeDecoder(tMaxKey, decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType}) + reg.RegisterTypeDecoder(tJavaScript, decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType}) + reg.RegisterTypeDecoder(tSymbol, decodeAdapter{symbolDecodeValue, symbolDecodeType}) + reg.RegisterTypeDecoder(tByteSlice, &byteSliceCodec{}) + reg.RegisterTypeDecoder(tTime, &timeCodec{}) + reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{}) + reg.RegisterTypeDecoder(tCoreArray, &arrayCodec{}) + reg.RegisterTypeDecoder(tOID, decodeAdapter{objectIDDecodeValue, objectIDDecodeType}) + reg.RegisterTypeDecoder(tDecimal, decodeAdapter{decimal128DecodeValue, decimal128DecodeType}) + reg.RegisterTypeDecoder(tJSONNumber, decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType}) + reg.RegisterTypeDecoder(tURL, decodeAdapter{urlDecodeValue, urlDecodeType}) + reg.RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(coreDocumentDecodeValue)) + reg.RegisterTypeDecoder(tCodeWithScope, decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType}) + reg.RegisterKindDecoder(reflect.Bool, decodeAdapter{booleanDecodeValue, booleanDecodeType}) + reg.RegisterKindDecoder(reflect.Int, intDecoder) + reg.RegisterKindDecoder(reflect.Int8, intDecoder) + reg.RegisterKindDecoder(reflect.Int16, intDecoder) + reg.RegisterKindDecoder(reflect.Int32, intDecoder) + reg.RegisterKindDecoder(reflect.Int64, intDecoder) + reg.RegisterKindDecoder(reflect.Uint, uintCodec) + reg.RegisterKindDecoder(reflect.Uint8, uintCodec) + reg.RegisterKindDecoder(reflect.Uint16, uintCodec) + reg.RegisterKindDecoder(reflect.Uint32, uintCodec) + reg.RegisterKindDecoder(reflect.Uint64, uintCodec) + reg.RegisterKindDecoder(reflect.Float32, floatDecoder) + reg.RegisterKindDecoder(reflect.Float64, floatDecoder) + reg.RegisterKindDecoder(reflect.Array, ValueDecoderFunc(arrayDecodeValue)) + reg.RegisterKindDecoder(reflect.Map, &mapCodec{}) + reg.RegisterKindDecoder(reflect.Slice, &sliceCodec{}) + reg.RegisterKindDecoder(reflect.String, &stringCodec{}) + reg.RegisterKindDecoder(reflect.Struct, newStructCodec(nil)) + reg.RegisterKindDecoder(reflect.Ptr, &pointerCodec{}) + reg.RegisterTypeMapEntry(TypeDouble, tFloat64) + reg.RegisterTypeMapEntry(TypeString, tString) + reg.RegisterTypeMapEntry(TypeArray, tA) + reg.RegisterTypeMapEntry(TypeBinary, tBinary) + reg.RegisterTypeMapEntry(TypeUndefined, tUndefined) + reg.RegisterTypeMapEntry(TypeObjectID, tOID) + reg.RegisterTypeMapEntry(TypeBoolean, tBool) + 
reg.RegisterTypeMapEntry(TypeDateTime, tDateTime) + reg.RegisterTypeMapEntry(TypeRegex, tRegex) + reg.RegisterTypeMapEntry(TypeDBPointer, tDBPointer) + reg.RegisterTypeMapEntry(TypeJavaScript, tJavaScript) + reg.RegisterTypeMapEntry(TypeSymbol, tSymbol) + reg.RegisterTypeMapEntry(TypeCodeWithScope, tCodeWithScope) + reg.RegisterTypeMapEntry(TypeInt32, tInt32) + reg.RegisterTypeMapEntry(TypeInt64, tInt64) + reg.RegisterTypeMapEntry(TypeTimestamp, tTimestamp) + reg.RegisterTypeMapEntry(TypeDecimal128, tDecimal) + reg.RegisterTypeMapEntry(TypeMinKey, tMinKey) + reg.RegisterTypeMapEntry(TypeMaxKey, tMaxKey) + reg.RegisterTypeMapEntry(Type(0), tD) + reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tD) + reg.RegisterInterfaceDecoder(tValueUnmarshaler, ValueDecoderFunc(valueUnmarshalerDecodeValue)) + reg.RegisterInterfaceDecoder(tUnmarshaler, ValueDecoderFunc(unmarshalerDecodeValue)) } -// DDecodeValue is the ValueDecoderFunc for D instances. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dDecodeValue is the ValueDecoderFunc for D instances. +func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Type() != tD { return ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } switch vrType := vr.Type(); vrType { case Type(0), TypeEmbeddedDocument: - dc.Ancestor = tD + break case TypeNull: val.Set(reflect.Zero(val.Type())) return vr.ReadNull() @@ -192,7 +162,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, v return nil } -func (dvd DefaultValueDecoders) booleanDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func booleanDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.Bool { return emptyValue, ValueDecoderError{ Name: "BooleanDecodeValue", @@ -238,16 +208,13 @@ func (dvd DefaultValueDecoders) booleanDecodeType(_ DecodeContext, vr ValueReade return reflect.ValueOf(b), nil } -// BooleanDecodeValue is the ValueDecoderFunc for bool types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// booleanDecodeValue is the ValueDecoderFunc for bool types. 
+func booleanDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } - elem, err := dvd.booleanDecodeType(dctx, vr, val.Type()) + elem, err := booleanDecodeType(dctx, vr, val.Type()) if err != nil { return err } @@ -256,7 +223,7 @@ func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { var i64 int64 var err error switch vrType := vr.Type(); vrType { @@ -325,7 +292,7 @@ func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr ValueReader, t re case reflect.Int64: return reflect.ValueOf(i64), nil case reflect.Int: - if int64(int(i64)) != i64 { // Can we fit this inside of an int + if i64 > math.MaxInt { // Can we fit this inside of an int return emptyValue, fmt.Errorf("%d overflows int", i64) } @@ -339,11 +306,8 @@ func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr ValueReader, t re } } -// IntDecodeValue is the ValueDecoderFunc for int types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// intDecodeValue is the ValueDecoderFunc for int types. +func intDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ Name: "IntDecodeValue", @@ -352,7 +316,7 @@ func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, } } - elem, err := dvd.intDecodeType(dc, vr, val.Type()) + elem, err := intDecodeType(dc, vr, val.Type()) if err != nil { return err } @@ -361,7 +325,7 @@ func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, return nil } -func (dvd DefaultValueDecoders) floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { var f float64 var err error switch vrType := vr.Type(); vrType { @@ -420,11 +384,8 @@ func (dvd DefaultValueDecoders) floatDecodeType(dc DecodeContext, vr ValueReader } } -// FloatDecodeValue is the ValueDecoderFunc for float types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReader, val reflect.Value) error { +// floatDecodeValue is the ValueDecoderFunc for float types. 
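// A short sketch of how the kind decoders registered above are exercised: the
// shared intDecoder adapter serves every signed-integer kind, so a BSON int32
// element can populate an int64 struct field with no extra registration. The
// struct type and function name are illustrative; buildDefaultRegistry is the
// test helper defined later in this diff.
func exampleIntWidening() (int64, error) {
	// {"n": 42} stored as a BSON int32.
	doc := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "n", 42))

	var out struct {
		N int64
	}
	val := reflect.ValueOf(&out).Elem()
	dc := DecodeContext{Registry: buildDefaultRegistry()}
	err := newStructCodec(nil).DecodeValue(dc, NewValueReader(doc), val)
	return out.N, err // 42, nil
}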
+func floatDecodeValue(ec DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ Name: "FloatDecodeValue", @@ -433,7 +394,7 @@ func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReade } } - elem, err := dvd.floatDecodeType(ec, vr, val.Type()) + elem, err := floatDecodeType(ec, vr, val.Type()) if err != nil { return err } @@ -442,7 +403,7 @@ func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReade return nil } -func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJavaScript { return emptyValue, ValueDecoderError{ Name: "JavaScriptDecodeValue", @@ -470,16 +431,13 @@ func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr ValueReader return reflect.ValueOf(JavaScript(js)), nil } -// JavaScriptDecodeValue is the ValueDecoderFunc for the JavaScript type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// javaScriptDecodeValue is the ValueDecoderFunc for the JavaScript type. +func javaScriptDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJavaScript { return ValueDecoderError{Name: "JavaScriptDecodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } - elem, err := dvd.javaScriptDecodeType(dctx, vr, tJavaScript) + elem, err := javaScriptDecodeType(dctx, vr, tJavaScript) if err != nil { return err } @@ -488,7 +446,7 @@ func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr Val return nil } -func (DefaultValueDecoders) symbolDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func symbolDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tSymbol { return emptyValue, ValueDecoderError{ Name: "SymbolDecodeValue", @@ -528,16 +486,13 @@ func (DefaultValueDecoders) symbolDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(Symbol(symbol)), nil } -// SymbolDecodeValue is the ValueDecoderFunc for the Symbol type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// symbolDecodeValue is the ValueDecoderFunc for the Symbol type. 
+func symbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tSymbol { return ValueDecoderError{Name: "SymbolDecodeValue", Types: []reflect.Type{tSymbol}, Received: val} } - elem, err := dvd.symbolDecodeType(dctx, vr, tSymbol) + elem, err := symbolDecodeType(dctx, vr, tSymbol) if err != nil { return err } @@ -546,7 +501,7 @@ func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tBinary { return emptyValue, ValueDecoderError{ Name: "BinaryDecodeValue", @@ -575,16 +530,13 @@ func (DefaultValueDecoders) binaryDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(Binary{Subtype: subtype, Data: data}), nil } -// BinaryDecodeValue is the ValueDecoderFunc for Binary. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// binaryDecodeValue is the ValueDecoderFunc for Binary. +func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tBinary { return ValueDecoderError{Name: "BinaryDecodeValue", Types: []reflect.Type{tBinary}, Received: val} } - elem, err := dvd.binaryDecodeType(dc, vr, tBinary) + elem, err := binaryDecodeType(dc, vr, tBinary) if err != nil { return err } @@ -593,7 +545,7 @@ func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (DefaultValueDecoders) undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tUndefined { return emptyValue, ValueDecoderError{ Name: "UndefinedDecodeValue", @@ -618,16 +570,13 @@ func (DefaultValueDecoders) undefinedDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(Undefined{}), nil } -// UndefinedDecodeValue is the ValueDecoderFunc for Undefined. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// undefinedDecodeValue is the ValueDecoderFunc for Undefined. +func undefinedDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tUndefined { return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []reflect.Type{tUndefined}, Received: val} } - elem, err := dvd.undefinedDecodeType(dc, vr, tUndefined) + elem, err := undefinedDecodeType(dc, vr, tUndefined) if err != nil { return err } @@ -637,7 +586,7 @@ func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr ValueR } // Accept both 12-byte string and pretty-printed 24-byte hex string formats. 
-func (dvd DefaultValueDecoders) objectIDDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func objectIDDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tOID { return emptyValue, ValueDecoderError{ Name: "ObjectIDDecodeValue", @@ -682,16 +631,13 @@ func (dvd DefaultValueDecoders) objectIDDecodeType(_ DecodeContext, vr ValueRead return reflect.ValueOf(oid), nil } -// ObjectIDDecodeValue is the ValueDecoderFunc for ObjectID. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// objectIDDecodeValue is the ValueDecoderFunc for ObjectID. +func objectIDDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tOID { return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []reflect.Type{tOID}, Received: val} } - elem, err := dvd.objectIDDecodeType(dc, vr, tOID) + elem, err := objectIDDecodeType(dc, vr, tOID) if err != nil { return err } @@ -700,7 +646,7 @@ func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) dateTimeDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func dateTimeDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDateTime { return emptyValue, ValueDecoderError{ Name: "DateTimeDecodeValue", @@ -728,16 +674,13 @@ func (DefaultValueDecoders) dateTimeDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(DateTime(dt)), nil } -// DateTimeDecodeValue is the ValueDecoderFunc for DateTime. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dateTimeDecodeValue is the ValueDecoderFunc for DateTime. +func dateTimeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDateTime { return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []reflect.Type{tDateTime}, Received: val} } - elem, err := dvd.dateTimeDecodeType(dc, vr, tDateTime) + elem, err := dateTimeDecodeType(dc, vr, tDateTime) if err != nil { return err } @@ -746,7 +689,7 @@ func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) nullDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func nullDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tNull { return emptyValue, ValueDecoderError{ Name: "NullDecodeValue", @@ -771,16 +714,13 @@ func (DefaultValueDecoders) nullDecodeType(_ DecodeContext, vr ValueReader, t re return reflect.ValueOf(Null{}), nil } -// NullDecodeValue is the ValueDecoderFunc for Null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// nullDecodeValue is the ValueDecoderFunc for Null. 
+func nullDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tNull { return ValueDecoderError{Name: "NullDecodeValue", Types: []reflect.Type{tNull}, Received: val} } - elem, err := dvd.nullDecodeType(dc, vr, tNull) + elem, err := nullDecodeType(dc, vr, tNull) if err != nil { return err } @@ -789,7 +729,7 @@ func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr ValueReader return nil } -func (DefaultValueDecoders) regexDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func regexDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tRegex { return emptyValue, ValueDecoderError{ Name: "RegexDecodeValue", @@ -817,16 +757,13 @@ func (DefaultValueDecoders) regexDecodeType(_ DecodeContext, vr ValueReader, t r return reflect.ValueOf(Regex{Pattern: pattern, Options: options}), nil } -// RegexDecodeValue is the ValueDecoderFunc for Regex. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// regexDecodeValue is the ValueDecoderFunc for Regex. +func regexDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRegex { return ValueDecoderError{Name: "RegexDecodeValue", Types: []reflect.Type{tRegex}, Received: val} } - elem, err := dvd.regexDecodeType(dc, vr, tRegex) + elem, err := regexDecodeType(dc, vr, tRegex) if err != nil { return err } @@ -835,7 +772,7 @@ func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr ValueReade return nil } -func (DefaultValueDecoders) dBPointerDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func dbPointerDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDBPointer { return emptyValue, ValueDecoderError{ Name: "DBPointerDecodeValue", @@ -864,16 +801,13 @@ func (DefaultValueDecoders) dBPointerDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(DBPointer{DB: ns, Pointer: pointer}), nil } -// DBPointerDecodeValue is the ValueDecoderFunc for DBPointer. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dbPointerDecodeValue is the ValueDecoderFunc for DBPointer. 
+func dbPointerDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDBPointer { return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } - elem, err := dvd.dBPointerDecodeType(dc, vr, tDBPointer) + elem, err := dbPointerDecodeType(dc, vr, tDBPointer) if err != nil { return err } @@ -882,7 +816,7 @@ func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) timestampDecodeType(_ DecodeContext, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) { +func timestampDecodeType(_ DecodeContext, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) { if reflectType != tTimestamp { return emptyValue, ValueDecoderError{ Name: "TimestampDecodeValue", @@ -910,16 +844,13 @@ func (DefaultValueDecoders) timestampDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(Timestamp{T: t, I: incr}), nil } -// TimestampDecodeValue is the ValueDecoderFunc for Timestamp. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// timestampDecodeValue is the ValueDecoderFunc for Timestamp. +func timestampDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTimestamp { return ValueDecoderError{Name: "TimestampDecodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } - elem, err := dvd.timestampDecodeType(dc, vr, tTimestamp) + elem, err := timestampDecodeType(dc, vr, tTimestamp) if err != nil { return err } @@ -928,7 +859,7 @@ func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) minKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func minKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMinKey { return emptyValue, ValueDecoderError{ Name: "MinKeyDecodeValue", @@ -955,16 +886,13 @@ func (DefaultValueDecoders) minKeyDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(MinKey{}), nil } -// MinKeyDecodeValue is the ValueDecoderFunc for MinKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// minKeyDecodeValue is the ValueDecoderFunc for MinKey. 
+func minKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMinKey { return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []reflect.Type{tMinKey}, Received: val} } - elem, err := dvd.minKeyDecodeType(dc, vr, tMinKey) + elem, err := minKeyDecodeType(dc, vr, tMinKey) if err != nil { return err } @@ -973,7 +901,7 @@ func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (DefaultValueDecoders) maxKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func maxKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMaxKey { return emptyValue, ValueDecoderError{ Name: "MaxKeyDecodeValue", @@ -1000,16 +928,13 @@ func (DefaultValueDecoders) maxKeyDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(MaxKey{}), nil } -// MaxKeyDecodeValue is the ValueDecoderFunc for MaxKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// maxKeyDecodeValue is the ValueDecoderFunc for MaxKey. +func maxKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMaxKey { return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } - elem, err := dvd.maxKeyDecodeType(dc, vr, tMaxKey) + elem, err := maxKeyDecodeType(dc, vr, tMaxKey) if err != nil { return err } @@ -1018,7 +943,7 @@ func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (dvd DefaultValueDecoders) decimal128DecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func decimal128DecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDecimal { return emptyValue, ValueDecoderError{ Name: "Decimal128DecodeValue", @@ -1046,16 +971,13 @@ func (dvd DefaultValueDecoders) decimal128DecodeType(_ DecodeContext, vr ValueRe return reflect.ValueOf(d128), nil } -// Decimal128DecodeValue is the ValueDecoderFunc for Decimal128. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// decimal128DecodeValue is the ValueDecoderFunc for Decimal128. 
+func decimal128DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDecimal { return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []reflect.Type{tDecimal}, Received: val} } - elem, err := dvd.decimal128DecodeType(dctx, vr, tDecimal) + elem, err := decimal128DecodeType(dctx, vr, tDecimal) if err != nil { return err } @@ -1064,7 +986,7 @@ func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr Val return nil } -func (dvd DefaultValueDecoders) jsonNumberDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func jsonNumberDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJSONNumber { return emptyValue, ValueDecoderError{ Name: "JSONNumberDecodeValue", @@ -1108,16 +1030,13 @@ func (dvd DefaultValueDecoders) jsonNumberDecodeType(_ DecodeContext, vr ValueRe return reflect.ValueOf(jsonNum), nil } -// JSONNumberDecodeValue is the ValueDecoderFunc for json.Number. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// jsonNumberDecodeValue is the ValueDecoderFunc for json.Number. +func jsonNumberDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJSONNumber { return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } - elem, err := dvd.jsonNumberDecodeType(dc, vr, tJSONNumber) + elem, err := jsonNumberDecodeType(dc, vr, tJSONNumber) if err != nil { return err } @@ -1126,7 +1045,7 @@ func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr Value return nil } -func (dvd DefaultValueDecoders) urlDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func urlDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tURL { return emptyValue, ValueDecoderError{ Name: "URLDecodeValue", @@ -1160,16 +1079,13 @@ func (dvd DefaultValueDecoders) urlDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(urlPtr).Elem(), nil } -// URLDecodeValue is the ValueDecoderFunc for url.URL. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// urlDecodeValue is the ValueDecoderFunc for url.URL. +func urlDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tURL { return ValueDecoderError{Name: "URLDecodeValue", Types: []reflect.Type{tURL}, Received: val} } - elem, err := dvd.urlDecodeType(dc, vr, tURL) + elem, err := urlDecodeType(dc, vr, tURL) if err != nil { return err } @@ -1178,11 +1094,8 @@ func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr ValueReader, return nil } -// ArrayDecodeValue is the ValueDecoderFunc for array types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// arrayDecodeValue is the ValueDecoderFunc for array types. 
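// Sketch of the retained convenience decoders for json.Number and url.URL: with
// the defaults in place, a BSON double decodes into a json.Number field and a
// BSON string into a url.URL field. The field names, URL value, and function are
// illustrative; buildDefaultRegistry is the test helper from this diff.
func exampleConvenienceTypes() (json.Number, url.URL, error) {
	// {"n": 3.14, "u": "https://example.com/path"}
	elems := bsoncore.AppendDoubleElement(nil, "n", 3.14)
	elems = bsoncore.AppendStringElement(elems, "u", "https://example.com/path")
	doc := bsoncore.BuildDocument(nil, elems)

	var out struct {
		N json.Number
		U url.URL
	}
	val := reflect.ValueOf(&out).Elem()
	err := newStructCodec(nil).DecodeValue(DecodeContext{Registry: buildDefaultRegistry()}, NewValueReader(doc), val)
	return out.N, out.U, err // "3.14", the parsed URL, nil
}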
+func arrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueDecoderError{Name: "ArrayDecodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -1226,9 +1139,9 @@ func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReade var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) switch val.Type().Elem() { case tE: - elemsFunc = dvd.decodeD + elemsFunc = decodeD default: - elemsFunc = dvd.decodeDefault + elemsFunc = decodeDefault } elems, err := elemsFunc(dc, vr, val) @@ -1247,11 +1160,8 @@ func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReade return nil } -// ValueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// valueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. +func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tValueUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tValueUnmarshaler)) { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } @@ -1280,14 +1190,11 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr // NB: this error should be unreachable due to the above checks return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } - return m.UnmarshalBSONValue(t, src) + return m.UnmarshalBSONValue(byte(t), src) } -// UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// unmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. +func unmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tUnmarshaler)) { return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} } @@ -1331,11 +1238,8 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr Value return m.UnmarshalBSON(src) } -// CoreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (DefaultValueDecoders) CoreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// coreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. 
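// The ValueUnmarshaler hook now receives the BSON type as a raw byte (note the
// byte(t) conversion above), so implementers take (byte, []byte) and return an
// error. A minimal implementer under that reading; the capturedValue type is an
// illustrative assumption, not part of this change.
type capturedValue struct {
	t    byte
	data []byte
}

func (cv *capturedValue) UnmarshalBSONValue(t byte, data []byte) error {
	// Keep a copy of the wire type and raw bytes; a real implementation would
	// decode data according to t.
	cv.t = t
	cv.data = append(cv.data[:0], data...)
	return nil
}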
+func coreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreDocument { return ValueDecoderError{Name: "CoreDocumentDecodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -1351,7 +1255,7 @@ func (DefaultValueDecoders) CoreDocumentDecodeValue(_ DecodeContext, vr ValueRea return err } -func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { +func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { elems := make([]reflect.Value, 0) ar, err := vr.ReadArray() @@ -1421,31 +1325,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, return elems, nil } -func (dvd DefaultValueDecoders) readCodeWithScope(dc DecodeContext, vr ValueReader) (CodeWithScope, error) { - var cws CodeWithScope - - code, dr, err := vr.ReadCodeWithScope() - if err != nil { - return cws, err - } - - scope := reflect.New(tD).Elem() - elems, err := dvd.decodeElemsFromDocumentReader(dc, dr) - if err != nil { - return cws, err - } - - scope.Set(reflect.MakeSlice(tD, 0, len(elems))) - scope.Set(reflect.Append(scope, elems...)) - - cws = CodeWithScope{ - Code: JavaScript(code), - Scope: scope.Interface().(D), - } - return cws, nil -} - -func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tCodeWithScope { return emptyValue, ValueDecoderError{ Name: "CodeWithScopeDecodeValue", @@ -1458,7 +1338,24 @@ func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr Val var err error switch vrType := vr.Type(); vrType { case TypeCodeWithScope: - cws, err = dvd.readCodeWithScope(dc, vr) + code, dr, err := vr.ReadCodeWithScope() + if err != nil { + return emptyValue, err + } + + scope := reflect.New(tD).Elem() + elems, err := decodeElemsFromDocumentReader(dc, dr) + if err != nil { + return emptyValue, err + } + + scope.Set(reflect.MakeSlice(tD, 0, len(elems))) + scope.Set(reflect.Append(scope, elems...)) + + cws = CodeWithScope{ + Code: JavaScript(code), + Scope: scope.Interface().(D), + } case TypeNull: err = vr.ReadNull() case TypeUndefined: @@ -1473,16 +1370,13 @@ func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr Val return reflect.ValueOf(cws), nil } -// CodeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// codeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. 
+func codeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCodeWithScope { return ValueDecoderError{Name: "CodeWithScopeDecodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } - elem, err := dvd.codeWithScopeDecodeType(dc, vr, tCodeWithScope) + elem, err := codeWithScopeDecodeType(dc, vr, tCodeWithScope) if err != nil { return err } @@ -1491,7 +1385,7 @@ func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr Va return nil } -func (dvd DefaultValueDecoders) decodeD(dc DecodeContext, vr ValueReader, _ reflect.Value) ([]reflect.Value, error) { +func decodeD(dc DecodeContext, vr ValueReader, _ reflect.Value) ([]reflect.Value, error) { switch vr.Type() { case Type(0), TypeEmbeddedDocument: default: @@ -1503,10 +1397,10 @@ func (dvd DefaultValueDecoders) decodeD(dc DecodeContext, vr ValueReader, _ refl return nil, err } - return dvd.decodeElemsFromDocumentReader(dc, dr) + return decodeElemsFromDocumentReader(dc, dr) } -func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr DocumentReader) ([]reflect.Value, error) { +func decodeElemsFromDocumentReader(dc DecodeContext, dr DocumentReader) ([]reflect.Value, error) { decoder, err := dc.LookupDecoder(tEmpty) if err != nil { return nil, err diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 31148ab644b..b7fd95ad28f 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -22,12 +22,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -var ( - defaultTestStructCodec = newDefaultStructCodec() -) - func TestDefaultValueDecoders(t *testing.T) { - var dvd DefaultValueDecoders var wrong = func(string, string) string { return "wrong" } type mybool bool @@ -71,7 +66,7 @@ func TestDefaultValueDecoders(t *testing.T) { }{ { "BooleanDecodeValue", - ValueDecoderFunc(dvd.BooleanDecodeValue), + ValueDecoderFunc(booleanDecodeValue), []subtest{ { "wrong type", @@ -140,7 +135,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "IntDecodeValue", - ValueDecoderFunc(dvd.IntDecodeValue), + ValueDecoderFunc(intDecodeValue), []subtest{ { "wrong type", @@ -372,7 +367,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultUIntCodec.DecodeValue", - defaultUIntCodec, + &uintCodec{}, []subtest{ { "wrong type", @@ -608,7 +603,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "FloatDecodeValue", - ValueDecoderFunc(dvd.FloatDecodeValue), + ValueDecoderFunc(floatDecodeValue), []subtest{ { "wrong type", @@ -737,7 +732,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultTimeCodec.DecodeValue", - defaultTimeCodec, + &timeCodec{}, []subtest{ { "wrong type", @@ -791,7 +786,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultMapCodec.DecodeValue", - defaultMapCodec, + &mapCodec{}, []subtest{ { "wrong kind", @@ -820,7 +815,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", map[string]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + &DecodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, readDocument, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -869,7 +864,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ArrayDecodeValue", - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), []subtest{ { "wrong kind", @@ -906,7 +901,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", [1]string{}, - &DecodeContext{Registry: 
newTestRegistryBuilder().Build()}, + &DecodeContext{Registry: newTestRegistry()}, &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -963,7 +958,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultSliceCodec.DecodeValue", - defaultSliceCodec, + &sliceCodec{}, []subtest{ { "wrong kind", @@ -1000,7 +995,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", []string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + &DecodeContext{Registry: newTestRegistry()}, &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -1057,7 +1052,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ObjectIDDecodeValue", - ValueDecoderFunc(dvd.ObjectIDDecodeValue), + ValueDecoderFunc(objectIDDecodeValue), []subtest{ { "wrong type", @@ -1144,7 +1139,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "Decimal128DecodeValue", - ValueDecoderFunc(dvd.Decimal128DecodeValue), + ValueDecoderFunc(decimal128DecodeValue), []subtest{ { "wrong type", @@ -1206,7 +1201,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "JSONNumberDecodeValue", - ValueDecoderFunc(dvd.JSONNumberDecodeValue), + ValueDecoderFunc(jsonNumberDecodeValue), []subtest{ { "wrong type", @@ -1300,7 +1295,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "URLDecodeValue", - ValueDecoderFunc(dvd.URLDecodeValue), + ValueDecoderFunc(urlDecodeValue), []subtest{ { "wrong type", @@ -1374,7 +1369,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultByteSliceCodec.DecodeValue", - defaultByteSliceCodec, + &byteSliceCodec{}, []subtest{ { "wrong type", @@ -1442,7 +1437,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultStringCodec.DecodeValue", - defaultStringCodec, + &stringCodec{}, []subtest{ { "symbol", @@ -1472,7 +1467,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ValueUnmarshalerDecodeValue", - ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue), + ValueDecoderFunc(valueUnmarshalerDecodeValue), []subtest{ { "wrong type", @@ -1506,7 +1501,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "UnmarshalerDecodeValue", - ValueDecoderFunc(dvd.UnmarshalerDecodeValue), + ValueDecoderFunc(unmarshalerDecodeValue), []subtest{ { "wrong type", @@ -1551,7 +1546,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "PointerCodec.DecodeValue", - NewPointerCodec(), + &pointerCodec{}, []subtest{ { "not valid", nil, nil, nil, nothing, @@ -1585,7 +1580,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "BinaryDecodeValue", - ValueDecoderFunc(dvd.BinaryDecodeValue), + ValueDecoderFunc(binaryDecodeValue), []subtest{ { "wrong type", @@ -1645,7 +1640,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "UndefinedDecodeValue", - ValueDecoderFunc(dvd.UndefinedDecodeValue), + ValueDecoderFunc(undefinedDecodeValue), []subtest{ { "wrong type", @@ -1691,7 +1686,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "DateTimeDecodeValue", - ValueDecoderFunc(dvd.DateTimeDecodeValue), + ValueDecoderFunc(dateTimeDecodeValue), []subtest{ { "wrong type", @@ -1745,7 +1740,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "NullDecodeValue", - ValueDecoderFunc(dvd.NullDecodeValue), + ValueDecoderFunc(nullDecodeValue), []subtest{ { "wrong type", @@ -1783,7 +1778,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "RegexDecodeValue", - ValueDecoderFunc(dvd.RegexDecodeValue), + ValueDecoderFunc(regexDecodeValue), []subtest{ { "wrong type", @@ -1843,7 +1838,7 @@ func TestDefaultValueDecoders(t 
*testing.T) { }, { "DBPointerDecodeValue", - ValueDecoderFunc(dvd.DBPointerDecodeValue), + ValueDecoderFunc(dbPointerDecodeValue), []subtest{ { "wrong type", @@ -1908,7 +1903,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "TimestampDecodeValue", - ValueDecoderFunc(dvd.TimestampDecodeValue), + ValueDecoderFunc(timestampDecodeValue), []subtest{ { "wrong type", @@ -1968,7 +1963,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "MinKeyDecodeValue", - ValueDecoderFunc(dvd.MinKeyDecodeValue), + ValueDecoderFunc(minKeyDecodeValue), []subtest{ { "wrong type", @@ -2022,7 +2017,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "MaxKeyDecodeValue", - ValueDecoderFunc(dvd.MaxKeyDecodeValue), + ValueDecoderFunc(maxKeyDecodeValue), []subtest{ { "wrong type", @@ -2076,7 +2071,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "JavaScriptDecodeValue", - ValueDecoderFunc(dvd.JavaScriptDecodeValue), + ValueDecoderFunc(javaScriptDecodeValue), []subtest{ { "wrong type", @@ -2130,7 +2125,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "SymbolDecodeValue", - ValueDecoderFunc(dvd.SymbolDecodeValue), + ValueDecoderFunc(symbolDecodeValue), []subtest{ { "wrong type", @@ -2184,7 +2179,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CoreDocumentDecodeValue", - ValueDecoderFunc(dvd.CoreDocumentDecodeValue), + ValueDecoderFunc(coreDocumentDecodeValue), []subtest{ { "wrong type", @@ -2222,7 +2217,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "StructCodec.DecodeValue", - defaultTestStructCodec, + newStructCodec(nil), []subtest{ { "Not struct", @@ -2252,7 +2247,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CodeWithScopeDecodeValue", - ValueDecoderFunc(dvd.CodeWithScopeDecodeValue), + ValueDecoderFunc(codeWithScopeDecodeValue), []subtest{ { "wrong type", @@ -2313,7 +2308,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CoreArrayDecodeValue", - defaultArrayCodec, + &arrayCodec{}, []subtest{ { "wrong type", @@ -2439,7 +2434,7 @@ func TestDefaultValueDecoders(t *testing.T) { Scope: D{{"bar", nil}}, } val := reflect.New(tCodeWithScope).Elem() - err = dvd.CodeWithScopeDecodeValue(dc, vr, val) + err = codeWithScopeDecodeValue(dc, vr, val) noerr(t, err) got := val.Interface().(CodeWithScope) @@ -2454,7 +2449,7 @@ func TestDefaultValueDecoders(t *testing.T) { want := errors.New("ubsonv error") valUnmarshaler := &testValueUnmarshaler{err: want} - got := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, reflect.ValueOf(valUnmarshaler)) + got := valueUnmarshalerDecodeValue(dc, llvrw, reflect.ValueOf(valUnmarshaler)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -2466,7 +2461,7 @@ func TestDefaultValueDecoders(t *testing.T) { val := reflect.ValueOf(testValueUnmarshaler{}) want := ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} - got := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, val) + got := valueUnmarshalerDecodeValue(dc, llvrw, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -2491,7 +2486,7 @@ func TestDefaultValueDecoders(t *testing.T) { want := fmt.Errorf("more elements returned in array than can fit inside %T, got 2 elements", val) dc := DecodeContext{Registry: buildDefaultRegistry()} - got := dvd.ArrayDecodeValue(dc, vr, reflect.ValueOf(val)) + got := arrayDecodeValue(dc, vr, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. 
got %v; want %v", got, want) } @@ -2896,10 +2891,10 @@ func TestDefaultValueDecoders(t *testing.T) { AS: nil, AT: nil, AU: CodeWithScope{Code: "var hello = 'world';", Scope: D{{"pi", 3.14159}}}, - AV: M{"foo": M{"bar": "baz"}}, + AV: M{"foo": D{{"bar", "baz"}}}, AW: D{{"foo", D{{"bar", "baz"}}}}, - AX: map[string]interface{}{"foo": map[string]interface{}{"bar": "baz"}}, - AY: []E{{"foo", []E{{"bar", "baz"}}}}, + AX: map[string]interface{}{"foo": D{{"bar", "baz"}}}, + AY: []E{{"foo", D{{"bar", "baz"}}}}, AZ: D{{"foo", D{{"bar", "baz"}}}}, }, buildDocument(func(doc []byte) []byte { @@ -3311,9 +3306,9 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + dc := DecodeContext{Registry: newTestRegistry()} want := ErrNoTypeMapEntry{Type: tc.bsontype} - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) + got := (&emptyInterfaceCodec{}).DecodeValue(dc, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3324,13 +3319,13 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() + reg := newTestRegistry() + reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), + Registry: reg, } want := ErrNoDecoder{Type: reflect.TypeOf(tc.val)} - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) + got := (&emptyInterfaceCodec{}).DecodeValue(dc, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3342,13 +3337,13 @@ func TestDefaultValueDecoders(t *testing.T) { } want := errors.New("DecodeValue failure error") llc := &llCodec{t: t, err: want} + reg := newTestRegistry() + reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc) + reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), + Registry: reg, } - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, reflect.New(tEmpty).Elem()) + got := (&emptyInterfaceCodec{}).DecodeValue(dc, llvr, reflect.New(tEmpty).Elem()) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3357,14 +3352,14 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("Success", func(t *testing.T) { want := tc.val llc := &llCodec{t: t, decodeval: tc.val} + reg := newTestRegistry() + reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc) + reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), + Registry: reg, } got := reflect.New(tEmpty).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, got) + err := (&emptyInterfaceCodec{}).DecodeValue(dc, llvr, got) noerr(t, err) if !cmp.Equal(got.Interface(), want, cmp.Comparer(compareDecimal128)) { t.Errorf("Did not receive expected value. 
got %v; want %v", got.Interface(), want) @@ -3377,7 +3372,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("non-interface{}", func(t *testing.T) { val := uint64(1234567890) want := ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(val)} - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := (&emptyInterfaceCodec{}).DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3386,7 +3381,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("nil *interface{}", func(t *testing.T) { var val interface{} want := ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(val)} - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := (&emptyInterfaceCodec{}).DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3396,7 +3391,7 @@ func TestDefaultValueDecoders(t *testing.T) { llvr := &valueReaderWriter{BSONType: TypeDouble} want := ErrNoTypeMapEntry{Type: TypeDouble} val := reflect.New(tEmpty).Elem() - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistryBuilder().Build()}, llvr, val) + got := (&emptyInterfaceCodec{}).DecodeValue(DecodeContext{Registry: newTestRegistry()}, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3407,7 +3402,7 @@ func TestDefaultValueDecoders(t *testing.T) { want := D{{"pi", 3.14159}} var got interface{} val := reflect.ValueOf(&got).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: buildDefaultRegistry()}, vr, val) + err := (&emptyInterfaceCodec{}).DecodeValue(DecodeContext{Registry: buildDefaultRegistry()}, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Errorf("Did not get correct result. 
got %v; want %v", got, want) @@ -3415,17 +3410,27 @@ func TestDefaultValueDecoders(t *testing.T) { }) t.Run("custom type map entry", func(t *testing.T) { // registering a custom type map entry for both Type(0) anad TypeEmbeddedDocument should cause - // both top-level and embedded documents to decode to registered type when unmarshalling to interface{} - - topLevelRb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(topLevelRb) - defaultValueDecoders.RegisterDefaultDecoders(topLevelRb) - topLevelRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + // the top-level to decode to registered type when unmarshalling to interface{} - embeddedRb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(embeddedRb) - defaultValueDecoders.RegisterDefaultDecoders(embeddedRb) - embeddedRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + topLevelReg := &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + } + registerDefaultEncoders(topLevelReg) + registerDefaultDecoders(topLevelReg) + topLevelReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + + embeddedReg := &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + } + registerDefaultEncoders(embeddedReg) + registerDefaultDecoders(embeddedReg) + embeddedReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) // create doc {"nested": {"foo": 1}} innerDoc := bsoncore.BuildDocument( @@ -3437,39 +3442,41 @@ func TestDefaultValueDecoders(t *testing.T) { bsoncore.AppendDocumentElement(nil, "nested", innerDoc), ) want := M{ - "nested": M{ - "foo": int32(1), - }, + "nested": D{{"foo", int32(1)}}, } testCases := []struct { name string registry *Registry }{ - {"top level", topLevelRb.Build()}, - {"embedded", embeddedRb.Build()}, + {"top level", topLevelReg}, + {"embedded", embeddedReg}, } for _, tc := range testCases { var got interface{} vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: tc.registry}, vr, val) + err := (&emptyInterfaceCodec{}).DecodeValue(DecodeContext{Registry: tc.registry}, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) } } }) - t.Run("ancestor info is used over custom type map entry", func(t *testing.T) { - // If a type map entry is registered for TypeEmbeddedDocument, the decoder should use ancestor - // information if available instead of the registered entry. - - rb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(rb) - defaultValueDecoders.RegisterDefaultDecoders(rb) - rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) - reg := rb.Build() + t.Run("custom type map entry is used if there is no type information", func(t *testing.T) { + // If a type map entry is registered for TypeEmbeddedDocument, the decoder should use it when + // type information is not available. 
+ + reg := &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + } + registerDefaultEncoders(reg) + registerDefaultDecoders(reg) + reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) // build document {"nested": {"foo": 10}} inner := bsoncore.BuildDocument( @@ -3481,15 +3488,15 @@ func TestDefaultValueDecoders(t *testing.T) { bsoncore.AppendDocumentElement(nil, "nested", inner), ) want := D{ - {"nested", D{ - {"foo", int32(10)}, + {"nested", M{ + "foo": int32(10), }}, } var got D vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := defaultSliceCodec.DecodeValue(DecodeContext{Registry: reg}, vr, val) + err := (&sliceCodec{}).DecodeValue(DecodeContext{Registry: reg}, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) @@ -3502,8 +3509,8 @@ func TestDefaultValueDecoders(t *testing.T) { emptyInterfaceErrorDecode := func(DecodeContext, ValueReader, reflect.Value) error { return decodeValueError } - emptyInterfaceErrorRegistry := newTestRegistryBuilder(). - RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)).Build() + emptyInterfaceErrorRegistry := newTestRegistry() + emptyInterfaceErrorRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) // Set up a document {foo: 10} and an error that would happen if the value were decoded into interface{} // using the registry defined above. @@ -3555,11 +3562,14 @@ func TestDefaultValueDecoders(t *testing.T) { outerDoc := buildDocument(bsoncore.AppendDocumentElement(nil, "first", inner1Doc)) // Use a registry that has all default decoders with the custom interface{} decoder that always errors. - nestedRegistryBuilder := newTestRegistryBuilder() - defaultValueDecoders.RegisterDefaultDecoders(nestedRegistryBuilder) - nestedRegistry := nestedRegistryBuilder. - RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)). 
- Build() + nestedRegistry := &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + } + registerDefaultDecoders(nestedRegistry) + nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) nestedErr := &DecodeError{ keys: []string{"fourth", "1", "third", "randomKey", "second", "first"}, wrapped: decodeValueError, @@ -3579,7 +3589,7 @@ func TestDefaultValueDecoders(t *testing.T) { D{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - defaultSliceCodec, + &sliceCodec{}, docEmptyInterfaceErr, }, { @@ -3588,7 +3598,7 @@ func TestDefaultValueDecoders(t *testing.T) { []string{}, &valueReaderWriter{BSONType: TypeArray}, nil, - defaultSliceCodec, + &sliceCodec{}, &DecodeError{ keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type"), @@ -3602,7 +3612,7 @@ func TestDefaultValueDecoders(t *testing.T) { [1]E{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), docEmptyInterfaceErr, }, { @@ -3613,7 +3623,7 @@ func TestDefaultValueDecoders(t *testing.T) { [1]string{}, &valueReaderWriter{BSONType: TypeArray}, nil, - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), &DecodeError{ keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type"), @@ -3625,7 +3635,7 @@ func TestDefaultValueDecoders(t *testing.T) { map[string]interface{}{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - defaultMapCodec, + &mapCodec{}, docEmptyInterfaceErr, }, { @@ -3634,7 +3644,7 @@ func TestDefaultValueDecoders(t *testing.T) { emptyInterfaceStruct{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - defaultTestStructCodec, + newStructCodec(nil), emptyInterfaceStructErr, }, { @@ -3644,8 +3654,8 @@ func TestDefaultValueDecoders(t *testing.T) { "struct - no decoder found", stringStruct{}, NewValueReader(docBytes), - newTestRegistryBuilder().Build(), - defaultTestStructCodec, + newTestRegistry(), + newStructCodec(nil), stringStructErr, }, { @@ -3653,7 +3663,7 @@ func TestDefaultValueDecoders(t *testing.T) { outer{}, NewValueReader(outerDoc), nestedRegistry, - defaultTestStructCodec, + newStructCodec(nil), nestedErr, }, } @@ -3683,7 +3693,7 @@ func TestDefaultValueDecoders(t *testing.T) { dc := DecodeContext{Registry: buildDefaultRegistry()} vr := NewValueReader(outerBytes) val := reflect.New(reflect.TypeOf(outer{})).Elem() - err := defaultTestStructCodec.DecodeValue(dc, vr, val) + err := newStructCodec(nil).DecodeValue(dc, vr, val) var decodeErr *DecodeError assert.True(t, errors.As(err, &decodeErr), "expected DecodeError, got %v of type %T", err, err) @@ -3709,14 +3719,19 @@ func TestDefaultValueDecoders(t *testing.T) { bsoncore.BuildArrayElement(nil, "boolArray", trueValue), ) - rb := newTestRegistryBuilder() - defaultValueDecoders.RegisterDefaultDecoders(rb) - reg := rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))).Build() + reg := &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + } + registerDefaultDecoders(reg) + reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) dc := DecodeContext{Registry: reg} vr := NewValueReader(docBytes) val := reflect.New(tD).Elem() - err := defaultValueDecoders.DDecodeValue(dc, vr, val) + err := dDecodeValue(dc, vr, val) 
assert.Nil(t, err, "DDecodeValue error: %v", err) want := D{ @@ -3735,7 +3750,7 @@ func TestDefaultValueDecoders(t *testing.T) { dc := DecodeContext{Registry: buildDefaultRegistry()} vr := NewValueReader(docBytes) val := reflect.New(reflect.TypeOf(myMap{})).Elem() - err := defaultMapCodec.DecodeValue(dc, vr, val) + err := (&mapCodec{}).DecodeValue(dc, vr, val) assert.Nil(t, err, "DecodeValue error: %v", err) want := myMap{ @@ -3778,8 +3793,13 @@ func buildDocument(elems []byte) []byte { } func buildDefaultRegistry() *Registry { - rb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(rb) - defaultValueDecoders.RegisterDefaultDecoders(rb) - return rb.Build() + reg := &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + } + registerDefaultEncoders(reg) + registerDefaultDecoders(reg) + return reg } diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index f2773c36e54..c87855551f4 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -9,18 +9,14 @@ package bson import ( "encoding/json" "errors" - "fmt" "math" "net/url" "reflect" "sync" - "time" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -var defaultValueEncoders DefaultValueEncoders - var bvwPool = NewValueWriterPool() var errInvalidValue = errors.New("cannot encode invalid element") @@ -53,73 +49,58 @@ func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error { return nil } -// DefaultValueEncoders is a namespace type for the default ValueEncoders used -// when creating a registry. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -type DefaultValueEncoders struct{} - -// RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with +// registerDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with // the provided RegistryBuilder. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) { - if rb == nil { - panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) - } - rb. - RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec). - RegisterTypeEncoder(tTime, defaultTimeCodec). - RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec). - RegisterTypeEncoder(tCoreArray, defaultArrayCodec). - RegisterTypeEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)). - RegisterTypeEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)). - RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)). - RegisterTypeEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)). - RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(dve.JavaScriptEncodeValue)). - RegisterTypeEncoder(tSymbol, ValueEncoderFunc(dve.SymbolEncodeValue)). - RegisterTypeEncoder(tBinary, ValueEncoderFunc(dve.BinaryEncodeValue)). - RegisterTypeEncoder(tUndefined, ValueEncoderFunc(dve.UndefinedEncodeValue)). - RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dve.DateTimeEncodeValue)). - RegisterTypeEncoder(tNull, ValueEncoderFunc(dve.NullEncodeValue)). - RegisterTypeEncoder(tRegex, ValueEncoderFunc(dve.RegexEncodeValue)). - RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dve.DBPointerEncodeValue)). 
- RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(dve.TimestampEncodeValue)). - RegisterTypeEncoder(tMinKey, ValueEncoderFunc(dve.MinKeyEncodeValue)). - RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(dve.MaxKeyEncodeValue)). - RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(dve.CoreDocumentEncodeValue)). - RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(dve.CodeWithScopeEncodeValue)). - RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)). - RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int8, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int16, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int32, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int64, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Uint, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint8, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint16, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint32, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint64, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)). - RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)). - RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)). - RegisterDefaultEncoder(reflect.Map, defaultMapCodec). - RegisterDefaultEncoder(reflect.Slice, defaultSliceCodec). - RegisterDefaultEncoder(reflect.String, defaultStringCodec). - RegisterDefaultEncoder(reflect.Struct, newDefaultStructCodec()). - RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec()). - RegisterHookEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)). - RegisterHookEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)). - RegisterHookEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue)) -} - -// BooleanEncodeValue is the ValueEncoderFunc for bool types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. 
-func (dve DefaultValueEncoders) BooleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func registerDefaultEncoders(reg *Registry) { + mapEncoder := &mapCodec{} + uintCodec := &uintCodec{} + + reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) + reg.RegisterTypeEncoder(tTime, &timeCodec{}) + reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) + reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{}) + reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) + reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) + reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) + reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue)) + reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) + reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) + reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) + reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) + reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) + reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) + reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue)) + reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue)) + reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue)) + reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue)) + reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue)) + reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) + reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) + reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) + reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Int16, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Int32, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Int64, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Uint, uintCodec) + reg.RegisterKindEncoder(reflect.Uint8, uintCodec) + reg.RegisterKindEncoder(reflect.Uint16, uintCodec) + reg.RegisterKindEncoder(reflect.Uint32, uintCodec) + reg.RegisterKindEncoder(reflect.Uint64, uintCodec) + reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue)) + reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue)) + reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue)) + reg.RegisterKindEncoder(reflect.Map, mapEncoder) + reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) + reg.RegisterKindEncoder(reflect.String, &stringCodec{}) + reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEncoder)) + reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{}) + reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue)) + reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue)) +} + +// booleanEncodeValue is the ValueEncoderFunc for bool types. 
+func booleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Bool { return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } @@ -130,11 +111,8 @@ func fitsIn32Bits(i int64) bool { return math.MinInt32 <= i && i <= math.MaxInt32 } -// IntEncodeValue is the ValueEncoderFunc for int types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// intEncodeValue is the ValueEncoderFunc for int types. +func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32: return vw.WriteInt32(int32(val.Int())) @@ -159,36 +137,8 @@ func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw ValueWriter, } } -// UintEncodeValue is the ValueEncoderFunc for uint types. -// -// Deprecated: UintEncodeValue is not registered by default. Use UintCodec.EncodeValue instead. -func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Uint8, reflect.Uint16: - return vw.WriteInt32(int32(val.Uint())) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - u64 := val.Uint() - if ec.MinSize && u64 <= math.MaxInt32 { - return vw.WriteInt32(int32(u64)) - } - if u64 > math.MaxInt64 { - return fmt.Errorf("%d overflows int64", u64) - } - return vw.WriteInt64(int64(u64)) - } - - return ValueEncoderError{ - Name: "UintEncodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } -} - -// FloatEncodeValue is the ValueEncoderFunc for float types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) FloatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// floatEncodeValue is the ValueEncoderFunc for float types. +func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: return vw.WriteDouble(val.Float()) @@ -197,48 +147,24 @@ func (dve DefaultValueEncoders) FloatEncodeValue(_ EncodeContext, vw ValueWriter return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} } -// StringEncodeValue is the ValueEncoderFunc for string types. -// -// Deprecated: StringEncodeValue is not registered by default. Use StringCodec.EncodeValue instead. -func (dve DefaultValueEncoders) StringEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if val.Kind() != reflect.String { - return ValueEncoderError{ - Name: "StringEncodeValue", - Kinds: []reflect.Kind{reflect.String}, - Received: val, - } - } - - return vw.WriteString(val.String()) -} - -// ObjectIDEncodeValue is the ValueEncoderFunc for ObjectID. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ObjectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. 
+func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} } return vw.WriteObjectID(val.Interface().(ObjectID)) } -// Decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) Decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. +func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDecimal { return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} } return vw.WriteDecimal128(val.Interface().(Decimal128)) } -// JSONNumberEncodeValue is the ValueEncoderFunc for json.Number. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. +func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJSONNumber { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } @@ -246,7 +172,7 @@ func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw Value // Attempt int first, then float64 if i64, err := jsnum.Int64(); err == nil { - return dve.IntEncodeValue(ec, vw, reflect.ValueOf(i64)) + return intEncodeValue(ec, vw, reflect.ValueOf(i64)) } f64, err := jsnum.Float64() @@ -254,14 +180,11 @@ func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw Value return err } - return dve.FloatEncodeValue(ec, vw, reflect.ValueOf(f64)) + return floatEncodeValue(ec, vw, reflect.ValueOf(f64)) } -// URLEncodeValue is the ValueEncoderFunc for url.URL. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) URLEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// urlEncodeValue is the ValueEncoderFunc for url.URL. +func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tURL { return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} } @@ -269,108 +192,8 @@ func (dve DefaultValueEncoders) URLEncodeValue(_ EncodeContext, vw ValueWriter, return vw.WriteString(u.String()) } -// TimeEncodeValue is the ValueEncoderFunc for time.TIme. -// -// Deprecated: TimeEncodeValue is not registered by default. Use TimeCodec.EncodeValue instead. -func (dve DefaultValueEncoders) TimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTime { - return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} - } - tt := val.Interface().(time.Time) - dt := NewDateTimeFromTime(tt) - return vw.WriteDateTime(int64(dt)) -} - -// ByteSliceEncodeValue is the ValueEncoderFunc for []byte. -// -// Deprecated: ByteSliceEncodeValue is not registered by default. 
Use ByteSliceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) ByteSliceEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tByteSlice { - return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - if val.IsNil() { - return vw.WriteNull() - } - return vw.WriteBinary(val.Interface().([]byte)) -} - -// MapEncodeValue is the ValueEncoderFunc for map[string]* types. -// -// Deprecated: MapEncodeValue is not registered by default. Use MapCodec.EncodeValue instead. -func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { - return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} - } - - if val.IsNil() { - // If we have a nill map but we can't WriteNull, that means we're probably trying to encode - // to a TopLevel document. We can't currently tell if this is what actually happened, but if - // there's a deeper underlying problem, the error will also be returned from WriteDocument, - // so just continue. The operations on a map reflection value are valid, so we can call - // MapKeys within mapEncodeValue without a problem. - err := vw.WriteNull() - if err == nil { - return nil - } - } - - dw, err := vw.WriteDocument() - if err != nil { - return err - } - - return dve.mapEncodeValue(ec, dw, val, nil) -} - -// mapEncodeValue handles encoding of the values of a map. The collisionFn returns -// true if the provided key exists, this is mainly used for inline maps in the -// struct codec. -func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { - - elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) - if err != nil && elemType.Kind() != reflect.Interface { - return err - } - - keys := val.MapKeys() - for _, key := range keys { - if collisionFn != nil && collisionFn(key.String()) { - return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) - } - - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key)) - if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { - return lookupErr - } - - vw, err := dw.WriteDocumentElement(key.String()) - if err != nil { - return err - } - - if errors.Is(lookupErr, errInvalidValue) { - err = vw.WriteNull() - if err != nil { - return err - } - continue - } - - err = currEncoder.EncodeValue(ec, vw, currVal) - if err != nil { - return err - } - } - - return dw.WriteDocumentEnd() -} - -// ArrayEncodeValue is the ValueEncoderFunc for array types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// arrayEncodeValue is the ValueEncoderFunc for array types. 
+func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -414,76 +237,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw ValueWrite } for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) - if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { - return lookupErr - } - - vw, err := aw.WriteArrayElement() - if err != nil { - return err - } - - if errors.Is(lookupErr, errInvalidValue) { - err = vw.WriteNull() - if err != nil { - return err - } - continue - } - - err = currEncoder.EncodeValue(ec, vw, currVal) - if err != nil { - return err - } - } - return aw.WriteArrayEnd() -} - -// SliceEncodeValue is the ValueEncoderFunc for slice types. -// -// Deprecated: SliceEncodeValue is not registered by default. Use SliceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Kind() != reflect.Slice { - return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} - } - - if val.IsNil() { - return vw.WriteNull() - } - - // If we have a []E we want to treat it as a document instead of as an array. - if val.Type().ConvertibleTo(tD) { - d := val.Convert(tD).Interface().(D) - - dw, err := vw.WriteDocument() - if err != nil { - return err - } - - for _, e := range d { - err = encodeElement(ec, dw, e) - if err != nil { - return err - } - } - - return dw.WriteDocumentEnd() - } - - aw, err := vw.WriteArray() - if err != nil { - return err - } - - elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) - if err != nil && elemType.Kind() != reflect.Interface { - return err - } - - for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) + currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -509,7 +263,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw ValueWrite return aw.WriteArrayEnd() } -func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { +func lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { if origEncoder != nil || (currVal.Kind() != reflect.Interface) { return origEncoder, currVal, nil } @@ -522,30 +276,8 @@ func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncod return currEncoder, currVal, err } -// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}. -// -// Deprecated: EmptyInterfaceEncodeValue is not registered by default. Use EmptyInterfaceCodec.EncodeValue instead. 
-func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tEmpty { - return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} - } - - if val.IsNil() { - return vw.WriteNull() - } - encoder, err := ec.LookupEncoder(val.Elem().Type()) - if err != nil { - return err - } - - return encoder.EncodeValue(ec, vw, val.Elem()) -} - -// ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. +func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement ValueMarshaler switch { case !val.IsValid(): @@ -569,14 +301,11 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw Va if err != nil { return err } - return copyValueFromBytes(vw, t, data) + return copyValueFromBytes(vw, Type(t), data) } -// MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. +func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Marshaler switch { case !val.IsValid(): @@ -603,58 +332,8 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw ValueWr return copyValueFromBytes(vw, TypeEmbeddedDocument, data) } -// ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. 
-func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - // Either val or a pointer to val must implement Proxy - switch { - case !val.IsValid(): - return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} - case val.Type().Implements(tProxy): - // If Proxy is implemented on a concrete type, make sure that val isn't a nil pointer - if isImplementationNil(val, tProxy) { - return vw.WriteNull() - } - case reflect.PtrTo(val.Type()).Implements(tProxy) && val.CanAddr(): - val = val.Addr() - default: - return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} - } - - m, ok := val.Interface().(Proxy) - if !ok { - return vw.WriteNull() - } - v, err := m.ProxyBSON() - if err != nil { - return err - } - if v == nil { - encoder, err := ec.LookupEncoder(nil) - if err != nil { - return err - } - return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil)) - } - vv := reflect.ValueOf(v) - switch vv.Kind() { - case reflect.Ptr, reflect.Interface: - vv = vv.Elem() - } - encoder, err := ec.LookupEncoder(vv.Type()) - if err != nil { - return err - } - return encoder.EncodeValue(ec, vw, vv) -} - -// JavaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) JavaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. +func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJavaScript { return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } @@ -662,11 +341,8 @@ func (DefaultValueEncoders) JavaScriptEncodeValue(_ EncodeContext, vw ValueWrite return vw.WriteJavascript(val.String()) } -// SymbolEncodeValue is the ValueEncoderFunc for the Symbol type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) SymbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// symbolEncodeValue is the ValueEncoderFunc for the Symbol type. +func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tSymbol { return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} } @@ -674,11 +350,8 @@ func (DefaultValueEncoders) SymbolEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteSymbol(val.String()) } -// BinaryEncodeValue is the ValueEncoderFunc for Binary. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) BinaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// binaryEncodeValue is the ValueEncoderFunc for Binary. 
+func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tBinary { return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} } @@ -687,11 +360,8 @@ func (DefaultValueEncoders) BinaryEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// UndefinedEncodeValue is the ValueEncoderFunc for Undefined. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) UndefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// undefinedEncodeValue is the ValueEncoderFunc for Undefined. +func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} } @@ -699,11 +369,8 @@ func (DefaultValueEncoders) UndefinedEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteUndefined() } -// DateTimeEncodeValue is the ValueEncoderFunc for DateTime. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) DateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// dateTimeEncodeValue is the ValueEncoderFunc for DateTime. +func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDateTime { return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} } @@ -711,11 +378,8 @@ func (DefaultValueEncoders) DateTimeEncodeValue(_ EncodeContext, vw ValueWriter, return vw.WriteDateTime(val.Int()) } -// NullEncodeValue is the ValueEncoderFunc for Null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) NullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// nullEncodeValue is the ValueEncoderFunc for Null. +func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tNull { return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} } @@ -723,11 +387,8 @@ func (DefaultValueEncoders) NullEncodeValue(_ EncodeContext, vw ValueWriter, val return vw.WriteNull() } -// RegexEncodeValue is the ValueEncoderFunc for Regex. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) RegexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// regexEncodeValue is the ValueEncoderFunc for Regex. +func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} } @@ -737,11 +398,8 @@ func (DefaultValueEncoders) RegexEncodeValue(_ EncodeContext, vw ValueWriter, va return vw.WriteRegex(regex.Pattern, regex.Options) } -// DBPointerEncodeValue is the ValueEncoderFunc for DBPointer. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. 
-func (DefaultValueEncoders) DBPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. +func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDBPointer { return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } @@ -751,11 +409,8 @@ func (DefaultValueEncoders) DBPointerEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteDBPointer(dbp.DB, dbp.Pointer) } -// TimestampEncodeValue is the ValueEncoderFunc for Timestamp. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) TimestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// timestampEncodeValue is the ValueEncoderFunc for Timestamp. +func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTimestamp { return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } @@ -765,11 +420,8 @@ func (DefaultValueEncoders) TimestampEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteTimestamp(ts.T, ts.I) } -// MinKeyEncodeValue is the ValueEncoderFunc for MinKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) MinKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// minKeyEncodeValue is the ValueEncoderFunc for MinKey. +func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMinKey { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} } @@ -777,11 +429,8 @@ func (DefaultValueEncoders) MinKeyEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteMinKey() } -// MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) MaxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. +func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMaxKey { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } @@ -789,11 +438,8 @@ func (DefaultValueEncoders) MaxKeyEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteMaxKey() } -// CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) CoreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. 
+func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreDocument { return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -803,11 +449,8 @@ func (DefaultValueEncoders) CoreDocumentEncodeValue(_ EncodeContext, vw ValueWri return copyDocumentFromBytes(vw, cdoc) } -// CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. +func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCodeWithScope { return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 481c6cb1a1d..3977d0b1419 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -35,7 +35,6 @@ func (ms myStruct) Foo() int { } func TestDefaultValueEncoders(t *testing.T) { - var dve DefaultValueEncoders var wrong = func(string, string) string { return "wrong" } type mybool bool @@ -58,11 +57,9 @@ func TestDefaultValueEncoders(t *testing.T) { d128 := NewDecimal128(12345, 67890) var nilValueMarshaler *testValueMarshaler var nilMarshaler *testMarshaler - var nilProxy *testProxy vmStruct := struct{ V testValueMarshalPtr }{testValueMarshalPtr{t: TypeString, buf: []byte{0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00}}} mStruct := struct{ V testMarshalPtr }{testMarshalPtr{buf: bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))}} - pStruct := struct{ V testProxyPtr }{testProxyPtr{ret: int64(1234567890)}} type subtest struct { name string @@ -80,7 +77,7 @@ func TestDefaultValueEncoders(t *testing.T) { }{ { "BooleanEncodeValue", - ValueEncoderFunc(dve.BooleanEncodeValue), + ValueEncoderFunc(booleanEncodeValue), []subtest{ { "wrong type", @@ -96,7 +93,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "IntEncodeValue", - ValueEncoderFunc(dve.IntEncodeValue), + ValueEncoderFunc(intEncodeValue), []subtest{ { "wrong type", @@ -136,7 +133,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UintEncodeValue", - defaultUIntCodec, + &uintCodec{}, []subtest{ { "wrong type", @@ -177,7 +174,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "FloatEncodeValue", - ValueEncoderFunc(dve.FloatEncodeValue), + ValueEncoderFunc(floatEncodeValue), []subtest{ { "wrong type", @@ -199,7 +196,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimeEncodeValue", - defaultTimeCodec, + &timeCodec{}, []subtest{ { "wrong type", @@ -214,7 +211,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MapEncodeValue", - defaultMapCodec, + &mapCodec{}, []subtest{ { "wrong kind", @@ -235,7 +232,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", map[string]int{"foo": 1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeDocument, fmt.Errorf("no encoder found for int"), @@ -259,7 +256,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty map/success", map[string]interface{}{}, - &EncodeContext{Registry: 
newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeDocumentEnd, nil, @@ -294,7 +291,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ArrayEncodeValue", - ValueEncoderFunc(dve.ArrayEncodeValue), + ValueEncoderFunc(arrayEncodeValue), []subtest{ { "wrong kind", @@ -315,7 +312,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", [1]int{1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -372,7 +369,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "SliceEncodeValue", - defaultSliceCodec, + &sliceCodec{}, []subtest{ { "wrong kind", @@ -393,7 +390,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", []int{1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -433,7 +430,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty slice/success", []interface{}{}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeArrayEnd, nil, @@ -458,7 +455,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ObjectIDEncodeValue", - ValueEncoderFunc(dve.ObjectIDEncodeValue), + ValueEncoderFunc(objectIDEncodeValue), []subtest{ { "wrong type", @@ -477,7 +474,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "Decimal128EncodeValue", - ValueEncoderFunc(dve.Decimal128EncodeValue), + ValueEncoderFunc(decimal128EncodeValue), []subtest{ { "wrong type", @@ -492,7 +489,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JSONNumberEncodeValue", - ValueEncoderFunc(dve.JSONNumberEncodeValue), + ValueEncoderFunc(jsonNumberEncodeValue), []subtest{ { "wrong type", @@ -521,7 +518,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "URLEncodeValue", - ValueEncoderFunc(dve.URLEncodeValue), + ValueEncoderFunc(urlEncodeValue), []subtest{ { "wrong type", @@ -536,7 +533,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ByteSliceEncodeValue", - defaultByteSliceCodec, + &byteSliceCodec{}, []subtest{ { "wrong type", @@ -552,7 +549,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "EmptyInterfaceEncodeValue", - defaultEmptyInterfaceCodec, + &emptyInterfaceCodec{}, []subtest{ { "wrong type", @@ -566,7 +563,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ValueMarshalerEncodeValue", - ValueEncoderFunc(dve.ValueMarshalerEncodeValue), + ValueEncoderFunc(valueMarshalerEncodeValue), []subtest{ { "wrong type", @@ -644,7 +641,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MarshalerEncodeValue", - ValueEncoderFunc(dve.MarshalerEncodeValue), + ValueEncoderFunc(marshalerEncodeValue), []subtest{ { "wrong type", @@ -704,79 +701,9 @@ func TestDefaultValueEncoders(t *testing.T) { }, }, }, - { - "ProxyEncodeValue", - ValueEncoderFunc(dve.ProxyEncodeValue), - []subtest{ - { - "wrong type", - wrong, - nil, - nil, - nothing, - ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: reflect.ValueOf(wrong)}, - }, - { - "Proxy error", - testProxy{err: errors.New("proxy error")}, - nil, - nil, - nothing, - errors.New("proxy error"), - }, - { - "Lookup error", - testProxy{ret: nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, - nil, - nothing, - ErrNoEncoder{Type: nil}, - }, - { - "success struct implementation", - 
testProxy{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, - nil, - writeInt64, - nil, - }, - { - "success ptr to struct implementation", - &testProxy{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, - nil, - writeInt64, - nil, - }, - { - "success nil ptr to struct implementation", - nilProxy, - nil, - nil, - writeNull, - nil, - }, - { - "success ptr to ptr implementation", - &testProxyPtr{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, - nil, - writeInt64, - nil, - }, - { - "unaddressable ptr implementation", - testProxyPtr{ret: int64(1234567890)}, - nil, - nil, - nothing, - ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: reflect.ValueOf(testProxyPtr{})}, - }, - }, - }, { "PointerCodec.EncodeValue", - NewPointerCodec(), + &pointerCodec{}, []subtest{ { "nil", @@ -814,7 +741,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "pointer implementation addressable interface", - NewPointerCodec(), + &pointerCodec{}, []subtest{ { "ValueMarshaler", @@ -832,19 +759,11 @@ func TestDefaultValueEncoders(t *testing.T) { writeDocumentEnd, nil, }, - { - "Proxy", - &pStruct, - &EncodeContext{Registry: buildDefaultRegistry()}, - nil, - writeDocumentEnd, - nil, - }, }, }, { "JavaScriptEncodeValue", - ValueEncoderFunc(dve.JavaScriptEncodeValue), + ValueEncoderFunc(javaScriptEncodeValue), []subtest{ { "wrong type", @@ -859,7 +778,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "SymbolEncodeValue", - ValueEncoderFunc(dve.SymbolEncodeValue), + ValueEncoderFunc(symbolEncodeValue), []subtest{ { "wrong type", @@ -874,7 +793,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "BinaryEncodeValue", - ValueEncoderFunc(dve.BinaryEncodeValue), + ValueEncoderFunc(binaryEncodeValue), []subtest{ { "wrong type", @@ -889,7 +808,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UndefinedEncodeValue", - ValueEncoderFunc(dve.UndefinedEncodeValue), + ValueEncoderFunc(undefinedEncodeValue), []subtest{ { "wrong type", @@ -904,7 +823,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DateTimeEncodeValue", - ValueEncoderFunc(dve.DateTimeEncodeValue), + ValueEncoderFunc(dateTimeEncodeValue), []subtest{ { "wrong type", @@ -919,7 +838,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "NullEncodeValue", - ValueEncoderFunc(dve.NullEncodeValue), + ValueEncoderFunc(nullEncodeValue), []subtest{ { "wrong type", @@ -934,7 +853,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "RegexEncodeValue", - ValueEncoderFunc(dve.RegexEncodeValue), + ValueEncoderFunc(regexEncodeValue), []subtest{ { "wrong type", @@ -949,7 +868,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DBPointerEncodeValue", - ValueEncoderFunc(dve.DBPointerEncodeValue), + ValueEncoderFunc(dbPointerEncodeValue), []subtest{ { "wrong type", @@ -971,7 +890,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimestampEncodeValue", - ValueEncoderFunc(dve.TimestampEncodeValue), + ValueEncoderFunc(timestampEncodeValue), []subtest{ { "wrong type", @@ -986,7 +905,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MinKeyEncodeValue", - ValueEncoderFunc(dve.MinKeyEncodeValue), + ValueEncoderFunc(minKeyEncodeValue), []subtest{ { "wrong type", @@ -1001,7 +920,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MaxKeyEncodeValue", - ValueEncoderFunc(dve.MaxKeyEncodeValue), + ValueEncoderFunc(maxKeyEncodeValue), []subtest{ { "wrong type", @@ -1016,7 +935,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { 
"CoreDocumentEncodeValue", - ValueEncoderFunc(dve.CoreDocumentEncodeValue), + ValueEncoderFunc(coreDocumentEncodeValue), []subtest{ { "wrong type", @@ -1074,7 +993,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "StructEncodeValue", - defaultTestStructCodec, + newStructCodec(&mapCodec{}), []subtest{ { "interface value", @@ -1096,7 +1015,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CodeWithScopeEncodeValue", - ValueEncoderFunc(dve.CodeWithScopeEncodeValue), + ValueEncoderFunc(codeWithScopeEncodeValue), []subtest{ { "wrong type", @@ -1131,7 +1050,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreArrayEncodeValue", - defaultArrayCodec, + &arrayCodec{}, []subtest{ { "wrong type", @@ -1555,10 +1474,8 @@ func TestDefaultValueEncoders(t *testing.T) { AC Decimal128 AD *time.Time AE testValueMarshaler - AF Proxy - AG testProxy - AH map[string]interface{} - AI CodeWithScope + AF map[string]interface{} + AG CodeWithScope }{ A: true, B: 123, @@ -1584,10 +1501,8 @@ func TestDefaultValueEncoders(t *testing.T) { AC: decimal128, AD: &now, AE: testValueMarshaler{t: TypeString, buf: bsoncore.AppendString(nil, "hello, world")}, - AF: testProxy{ret: struct{ Hello string }{Hello: "world!"}}, - AG: testProxy{ret: struct{ Pi float64 }{Pi: 3.14159}}, - AH: nil, - AI: CodeWithScope{Code: "var hello = 'world';", Scope: D{{"pi", 3.14159}}}, + AF: nil, + AG: CodeWithScope{Code: "var hello = 'world';", Scope: D{{"pi", 3.14159}}}, }, buildDocument(func(doc []byte) []byte { doc = bsoncore.AppendBooleanElement(doc, "a", true) @@ -1612,10 +1527,8 @@ func TestDefaultValueEncoders(t *testing.T) { doc = bsoncore.AppendDecimal128Element(doc, "ac", decimal128.h, decimal128.l) doc = bsoncore.AppendDateTimeElement(doc, "ad", now.UnixNano()/int64(time.Millisecond)) doc = bsoncore.AppendStringElement(doc, "ae", "hello, world") - doc = bsoncore.AppendDocumentElement(doc, "af", buildDocument(bsoncore.AppendStringElement(nil, "hello", "world!"))) - doc = bsoncore.AppendDocumentElement(doc, "ag", buildDocument(bsoncore.AppendDoubleElement(nil, "pi", 3.14159))) - doc = bsoncore.AppendNullElement(doc, "ah") - doc = bsoncore.AppendCodeWithScopeElement(doc, "ai", + doc = bsoncore.AppendNullElement(doc, "af") + doc = bsoncore.AppendCodeWithScopeElement(doc, "ag", "var hello = 'world';", buildDocument(bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), ) return doc @@ -1650,8 +1563,6 @@ func TestDefaultValueEncoders(t *testing.T) { AC []Decimal128 AD []*time.Time AE []testValueMarshaler - AF []Proxy - AG []testProxy }{ A: []bool{true}, B: []int32{123}, @@ -1685,14 +1596,6 @@ func TestDefaultValueEncoders(t *testing.T) { {t: TypeString, buf: bsoncore.AppendString(nil, "hello")}, {t: TypeString, buf: bsoncore.AppendString(nil, "world")}, }, - AF: []Proxy{ - testProxy{ret: struct{ Hello string }{Hello: "world!"}}, - testProxy{ret: struct{ Foo string }{Foo: "bar"}}, - }, - AG: []testProxy{ - {ret: struct{ One int64 }{One: 1234567890}}, - {ret: struct{ Pi float64 }{Pi: 3.14159}}, - }, }, buildDocument(func(doc []byte) []byte { doc = appendArrayElement(doc, "a", bsoncore.AppendBooleanElement(nil, "0", true)) @@ -1742,22 +1645,6 @@ func TestDefaultValueEncoders(t *testing.T) { doc = appendArrayElement(doc, "ae", bsoncore.AppendStringElement(bsoncore.AppendStringElement(nil, "0", "hello"), "1", "world"), ) - doc = appendArrayElement(doc, "af", - bsoncore.AppendDocumentElement( - bsoncore.AppendDocumentElement(nil, "0", - bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "hello", "world!")), - ), "1", - 
bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "foo", "bar")), - ), - ) - doc = appendArrayElement(doc, "ag", - bsoncore.AppendDocumentElement( - bsoncore.AppendDocumentElement(nil, "0", - bsoncore.BuildDocument(nil, bsoncore.AppendInt64Element(nil, "one", 1234567890)), - ), "1", - bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), - ), - ) return doc }(nil)), nil, @@ -1832,7 +1719,7 @@ func TestDefaultValueEncoders(t *testing.T) { t.Run("EmptyInterfaceEncodeValue/nil", func(t *testing.T) { val := reflect.New(tEmpty).Elem() llvrw := new(valueReaderWriter) - err := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: newTestRegistryBuilder().Build()}, llvrw, val) + err := (&emptyInterfaceCodec{}).EncodeValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) noerr(t, err) if llvrw.invoked != writeNull { t.Errorf("Incorrect method called. got %v; want %v", llvrw.invoked, writeNull) @@ -1843,7 +1730,7 @@ func TestDefaultValueEncoders(t *testing.T) { val := reflect.New(tEmpty).Elem() val.Set(reflect.ValueOf(int64(1234567890))) llvrw := new(valueReaderWriter) - got := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: newTestRegistryBuilder().Build()}, llvrw, val) + got := (&emptyInterfaceCodec{}).EncodeValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) want := ErrNoEncoder{Type: tInt64} if !assert.CompareErrors(got, want) { t.Errorf("Did not receive expected error. got %v; want %v", got, want) @@ -1857,8 +1744,8 @@ type testValueMarshalPtr struct { err error } -func (tvm *testValueMarshalPtr) MarshalBSONValue() (Type, []byte, error) { - return tvm.t, tvm.buf, tvm.err +func (tvm *testValueMarshalPtr) MarshalBSONValue() (byte, []byte, error) { + return byte(tvm.t), tvm.buf, tvm.err } type testMarshalPtr struct { @@ -1869,17 +1756,3 @@ type testMarshalPtr struct { func (tvm *testMarshalPtr) MarshalBSON() ([]byte, error) { return tvm.buf, tvm.err } - -type testProxy struct { - ret interface{} - err error -} - -func (tp testProxy) ProxyBSON() (interface{}, error) { return tp.ret, tp.err } - -type testProxyPtr struct { - ret interface{} - err error -} - -func (tp *testProxyPtr) ProxyBSON() (interface{}, error) { return tp.ret, tp.err } diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index e0af34c9423..80d44d8c664 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -8,47 +8,22 @@ package bson import ( "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// EmptyInterfaceCodec is the Codec used for interface{} values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// EmptyInterfaceCodec registered. -type EmptyInterfaceCodec struct { - // DecodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the +// emptyInterfaceCodec is the Codec used for interface{} values. +type emptyInterfaceCodec struct { + // decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the // "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary. - // - // Deprecated: Use bson.Decoder.BinaryAsSlice instead. - DecodeBinaryAsSlice bool + decodeBinaryAsSlice bool } -var ( - defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec() - - // Assert that defaultEmptyInterfaceCodec satisfies the typeDecoder interface, which allows it - // to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. 
- _ typeDecoder = defaultEmptyInterfaceCodec -) - -// NewEmptyInterfaceCodec returns a EmptyInterfaceCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// EmptyInterfaceCodec registered. -func NewEmptyInterfaceCodec(opts ...*bsonoptions.EmptyInterfaceCodecOptions) *EmptyInterfaceCodec { - interfaceOpt := bsonoptions.MergeEmptyInterfaceCodecOptions(opts...) - - codec := EmptyInterfaceCodec{} - if interfaceOpt.DecodeBinaryAsSlice != nil { - codec.DecodeBinaryAsSlice = *interfaceOpt.DecodeBinaryAsSlice - } - return &codec -} +// Assert that emptyInterfaceCodec satisfies the typeDecoder interface, which allows it +// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a +// collection. +var _ typeDecoder = &emptyInterfaceCodec{} // EncodeValue is the ValueEncoderFunc for interface{}. -func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (eic *emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} } @@ -64,7 +39,7 @@ func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val return encoder.EncodeValue(ec, vw, val.Elem()) } -func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { +func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument if isDocument { if dc.defaultDocumentType != nil { @@ -72,12 +47,6 @@ func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val // that type. return dc.defaultDocumentType, nil } - if dc.Ancestor != nil { - // Using ancestor information rather than looking up the type map entry forces consistent decoding. - // If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry - // has been registered. - return dc.Ancestor, nil - } } rtype, err := dc.LookupTypeMapEntry(valueType) @@ -100,12 +69,14 @@ func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val if err == nil { return rtype, nil } + // fallback to bson.D + return tD, nil } return nil, err } -func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (eic *emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tEmpty { return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)} } @@ -130,7 +101,7 @@ func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re return emptyValue, err } - if (eic.DecodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary { + if (eic.decodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary { binElem := elem.Interface().(Binary) if binElem.Subtype == TypeBinaryGeneric || binElem.Subtype == TypeBinaryBinaryOld { elem = reflect.ValueOf(binElem.Data) @@ -141,7 +112,7 @@ func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re } // DecodeValue is the ValueDecoderFunc for interface{}. 
-func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (eic *emptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tEmpty { return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} } diff --git a/bson/extjson_wrappers.go b/bson/extjson_wrappers.go index 69ac78ccb5b..e98b6428bc7 100644 --- a/bson/extjson_wrappers.go +++ b/bson/extjson_wrappers.go @@ -92,9 +92,9 @@ func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) { return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t) } - i, err := strconv.ParseInt(val.v.(string), 16, 64) + i, err := strconv.ParseUint(val.v.(string), 16, 8) if err != nil { - return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string)) + return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err) } subType = byte(i) diff --git a/bson/map_codec.go b/bson/map_codec.go index fddcc5c8b71..4f33484b2f6 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -12,34 +12,21 @@ import ( "fmt" "reflect" "strconv" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -var defaultMapCodec = NewMapCodec() - -// MapCodec is the Codec used for map values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// MapCodec registered. -type MapCodec struct { +// mapCodec is the Codec used for map values. +type mapCodec struct { // DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination // value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: Use bson.Decoder.ZeroMaps instead. - DecodeZerosMap bool + decodeZerosMap bool // EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of // BSON null. - // - // Deprecated: Use bson.Encoder.NilMapAsEmpty instead. - EncodeNilAsEmpty bool + encodeNilAsEmpty bool // EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name // strings using fmt.Sprintf() instead of the default string conversion logic. - // - // Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt instead. - EncodeKeysWithStringer bool + encodeKeysWithStringer bool } // KeyMarshaler is the interface implemented by an object that can marshal itself into a string key. @@ -58,33 +45,13 @@ type KeyUnmarshaler interface { UnmarshalKey(key string) error } -// NewMapCodec returns a MapCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// MapCodec registered. -func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec { - mapOpt := bsonoptions.MergeMapCodecOptions(opts...) - - codec := MapCodec{} - if mapOpt.DecodeZerosMap != nil { - codec.DecodeZerosMap = *mapOpt.DecodeZerosMap - } - if mapOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty - } - if mapOpt.EncodeKeysWithStringer != nil { - codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer - } - return &codec -} - // EncodeValue is the ValueEncoder for map[*]* types. 
-func (mc *MapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (mc *mapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Map { return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } - if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty { + if val.IsNil() && !mc.encodeNilAsEmpty && !ec.nilMapAsEmpty { // If we have a nil map but we can't WriteNull, that means we're probably trying to encode // to a TopLevel document. We can't currently tell if this is what actually happened, but if // there's a deeper underlying problem, the error will also be returned from WriteDocument, @@ -101,13 +68,17 @@ func (mc *MapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Va return err } - return mc.mapEncodeValue(ec, dw, val, nil) + err = mc.encodeMapElements(ec, dw, val, nil) + if err != nil { + return err + } + return dw.WriteDocumentEnd() } -// mapEncodeValue handles encoding of the values of a map. The collisionFn returns +// encodeMapElements handles encoding of the values of a map. The collisionFn returns // true if the provided key exists, this is mainly used for inline maps in the // struct codec. -func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { +func (mc *mapCodec) encodeMapElements(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { elemType := val.Type().Elem() encoder, err := ec.LookupEncoder(elemType) @@ -126,7 +97,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val refl return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) } - currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key)) + currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.MapIndex(key)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -150,11 +121,11 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val refl } } - return dw.WriteDocumentEnd() + return nil } // DecodeValue is the ValueDecoder for map[string/decimal]* types. 
-func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) { return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } @@ -180,7 +151,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va val.Set(reflect.MakeMap(val.Type())) } - if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) { + if val.Len() > 0 && (mc.decodeZerosMap || dc.zeroMaps) { clearMap(val) } @@ -190,10 +161,6 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va return err } - if eType == tEmpty { - dc.Ancestor = val.Type() - } - keyType := val.Type().Key() for { @@ -227,8 +194,8 @@ func clearMap(m reflect.Value) { } } -func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) { - if mc.EncodeKeysWithStringer || encodeKeysWithStringer { +func (mc *mapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) { + if mc.encodeKeysWithStringer || encodeKeysWithStringer { return fmt.Sprint(val), nil } @@ -273,12 +240,12 @@ func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (s var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem() var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() -func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) { +func (mc *mapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) { keyVal := reflect.ValueOf(key) var err error switch { // First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler - case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType): + case !mc.encodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType): keyVal = reflect.New(keyType) v := keyVal.Interface().(KeyUnmarshaler) err = v.UnmarshalKey(key) @@ -308,7 +275,7 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, } keyVal = reflect.ValueOf(n).Convert(keyType) case reflect.Float32, reflect.Float64: - if mc.EncodeKeysWithStringer { + if mc.encodeKeysWithStringer { parsed, err := strconv.ParseFloat(key, 64) if err != nil { return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err) diff --git a/bson/marshal.go b/bson/marshal.go index 573de163984..d6e687ddaea 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -34,7 +34,7 @@ type Marshaler interface { // create custom BSON marshaling behavior for an entire BSON document, implement // the Marshaler interface instead. type ValueMarshaler interface { - MarshalBSONValue() (Type, []byte, error) + MarshalBSONValue() (typ byte, data []byte, err error) } // Pool of buffers for marshalling BSON. 
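Because MarshalBSONValue now reports the BSON type as a plain byte instead of a Type, existing ValueMarshaler/ValueUnmarshaler implementations need a small signature change; the test updates that follow make exactly this adjustment. A sketch of a migrated pair, where the label wrapper type is hypothetical:

    package example

    import (
        "fmt"

        "go.mongodb.org/mongo-driver/bson"
        "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    )

    // label marshals itself as a BSON string.
    type label string

    // MarshalBSONValue returns the BSON type as a byte, per the new signature.
    func (l label) MarshalBSONValue() (byte, []byte, error) {
        return byte(bson.TypeString), bsoncore.AppendString(nil, string(l)), nil
    }

    // UnmarshalBSONValue receives the BSON type as a byte as well.
    func (l *label) UnmarshalBSONValue(t byte, data []byte) error {
        if bson.Type(t) != bson.TypeString {
            return fmt.Errorf("label: unexpected BSON type %v", bson.Type(t))
        }
        s, _, ok := bsoncore.ReadString(data)
        if !ok {
            return fmt.Errorf("label: invalid BSON string value")
        }
        *l = label(s)
        return nil
    }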
diff --git a/bson/marshal_value_cases_test.go b/bson/marshal_value_cases_test.go index 29356ece670..289aa3543f9 100644 --- a/bson/marshal_value_cases_test.go +++ b/bson/marshal_value_cases_test.go @@ -37,13 +37,13 @@ type marshalValueMarshaler struct { var _ ValueMarshaler = marshalValueMarshaler{} -func (mvi marshalValueMarshaler) MarshalBSONValue() (Type, []byte, error) { - return TypeInt32, bsoncore.AppendInt32(nil, int32(mvi.Foo)), nil +func (mvi marshalValueMarshaler) MarshalBSONValue() (byte, []byte, error) { + return byte(TypeInt32), bsoncore.AppendInt32(nil, int32(mvi.Foo)), nil } var _ ValueUnmarshaler = &marshalValueMarshaler{} -func (mvi *marshalValueMarshaler) UnmarshalBSONValue(_ Type, b []byte) error { +func (mvi *marshalValueMarshaler) UnmarshalBSONValue(_ byte, b []byte) error { v, _, _ := bsoncore.ReadInt32(b) mvi.Foo = int(v) return nil diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 6651509983c..abdf8d5cfca 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -471,7 +471,7 @@ func (t *prefixPtr) GetBSON() (interface{}, error) { func (t *prefixPtr) SetBSON(raw bson.RawValue) error { var s string if raw.Type == 0x0A { - return ErrSetZero + return bson.ErrMgoSetZero } rval := reflect.ValueOf(&s).Elem() decoder, err := Registry.LookupDecoder(rval.Type()) @@ -498,7 +498,7 @@ func (t prefixVal) GetBSON() (interface{}, error) { func (t *prefixVal) SetBSON(raw bson.RawValue) error { var s string if raw.Type == 0x0A { - return ErrSetZero + return bson.ErrMgoSetZero } rval := reflect.ValueOf(&s).Elem() decoder, err := Registry.LookupDecoder(rval.Type()) @@ -1019,14 +1019,8 @@ func TestUnmarshalSetterErrors(t *testing.T) { assert.Equal(t, "1", m["abc"].Received, "expected m[\"abc\"].Received to be: %v, got: %v", "1", m["abc"].Received) } -func TestDMap(t *testing.T) { - d := bson.D{{"a", 1}, {"b", 2}} - want := bson.M{"a": 1, "b": 2} - assert.True(t, reflect.DeepEqual(want, d.Map()), "expected: %v, got: %v", want, d.Map()) -} - func TestUnmarshalSetterErrSetZero(t *testing.T) { - setterResult["foo"] = ErrSetZero + setterResult["foo"] = bson.ErrMgoSetZero defer delete(setterResult, "field") buf := new(bytes.Buffer) @@ -1528,7 +1522,6 @@ var twoWayCrossItems = []crossTypeItem{ {&inlineMap{A: 1, M: nil}, map[string]interface{}{"a": 1}}, {&inlineMapInt{A: 1, M: map[string]int{"b": 2}}, map[string]int{"a": 1, "b": 2}}, {&inlineMapInt{A: 1, M: nil}, map[string]int{"a": 1}}, - {&inlineMapMyM{A: 1, M: MyM{"b": MyM{"c": 3}}}, map[string]interface{}{"a": 1, "b": map[string]interface{}{"c": 3}}}, {&inlineUnexported{M: map[string]interface{}{"b": 1}, unexported: unexported{A: 2}}, map[string]interface{}{"b": 1, "a": 2}}, // []byte <=> Binary @@ -1557,18 +1550,6 @@ var twoWayCrossItems = []crossTypeItem{ {&struct{ V time.Time }{time.Unix(-62135596799, 1e6).UTC()}, map[string]interface{}{"v": time.Unix(-62135596799, 1e6).UTC()}}, - // bson.D <=> []DocElem - {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}}, - {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &MyD{{"a", MyD{{"b", 1}, {"c", 2}}}}}, - {&struct{ V MyD }{MyD{{"a", 1}}}, &bson.D{{"v", bson.D{{"a", 1}}}}}, - - // bson.M <=> map - {&bson.M{"a": bson.M{"b": 1, "c": 2}}, MyM{"a": MyM{"b": 1, "c": 2}}}, - {&bson.M{"a": bson.M{"b": 1, "c": 2}}, map[string]interface{}{"a": map[string]interface{}{"b": 1, "c": 2}}}, - - // bson.M <=> map[MyString] - {&bson.M{"a": bson.M{"b": 1, "c": 2}}, map[MyString]interface{}{"a": map[MyString]interface{}{"b": 1, "c": 
2}}}, - // json.Number <=> int64, float64 {&struct{ N json.Number }{"5"}, map[string]interface{}{"n": int64(5)}}, {&struct{ N json.Number }{"5.05"}, map[string]interface{}{"n": 5.05}}, @@ -1591,6 +1572,25 @@ var oneWayCrossItems = []crossTypeItem{ {&struct { V struct{ v time.Time } `bson:",omitempty"` }{}, map[string]interface{}{}}, + + {&inlineMapMyM{A: 1, M: MyM{"b": MyM{"c": 3}}}, map[string]interface{}{"a": 1, "b": bson.M{"c": 3}}}, + {map[string]interface{}{"a": 1, "b": map[string]interface{}{"c": 3}}, &inlineMapMyM{A: 1, M: MyM{"b": bson.M{"c": 3}}}}, + + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &bson.D{{"a", bson.M{"b": 1, "c": 2}}}}, + + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &MyD{{"a", bson.M{"b": 1, "c": 2}}}}, + {&MyD{{"a", MyD{{"b", 1}, {"c", 2}}}}, &bson.D{{"a", bson.M{"b": 1, "c": 2}}}}, + + {&struct{ V MyD }{MyD{{"a", 1}}}, &bson.D{{"v", bson.M{"a": 1}}}}, + {&bson.D{{"v", bson.D{{"a", 1}}}}, &struct{ V MyD }{MyD{{"a", 1}}}}, + + {&bson.M{"a": bson.M{"b": 1, "c": 2}}, MyM{"a": bson.M{"b": 1, "c": 2}}}, + {MyM{"a": MyM{"b": 1, "c": 2}}, &bson.M{"a": bson.M{"b": 1, "c": 2}}}, + + {map[string]interface{}{"a": map[string]interface{}{"b": 1, "c": 2}}, &bson.M{"a": bson.M{"b": 1, "c": 2}}}, + + {&bson.M{"a": bson.M{"b": 1, "c": 2}}, map[MyString]interface{}{"a": bson.M{"b": 1, "c": 2}}}, + {map[MyString]interface{}{"a": map[MyString]interface{}{"b": 1, "c": 2}}, &bson.M{"a": bson.M{"b": 1, "c": 2}}}, } func testCrossPair(t *testing.T, dump interface{}, load interface{}) { diff --git a/bson/mgocompat/registry.go b/bson/mgocompat/registry.go index 7024ab9fdc0..7ffb90b22e4 100644 --- a/bson/mgocompat/registry.go +++ b/bson/mgocompat/registry.go @@ -7,106 +7,12 @@ package mgocompat import ( - "errors" - "reflect" - "time" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsonoptions" -) - -var ( - // ErrSetZero may be returned from a SetBSON method to have the value set to its respective zero value. - ErrSetZero = errors.New("set to zero") - - tInt = reflect.TypeOf(int(0)) - tTime = reflect.TypeOf(time.Time{}) - tM = reflect.TypeOf(bson.M{}) - tInterfaceSlice = reflect.TypeOf([]interface{}{}) - tByteSlice = reflect.TypeOf([]byte{}) - tEmpty = reflect.TypeOf((*interface{})(nil)).Elem() - tGetter = reflect.TypeOf((*Getter)(nil)).Elem() - tSetter = reflect.TypeOf((*Setter)(nil)).Elem() ) // Registry is the mgo compatible bson.Registry. It contains the default and // primitive codecs with mgo compatible options. -var Registry = NewRegistryBuilder().Build() +var Registry = bson.NewMgoRegistry() // RespectNilValuesRegistry is the bson.Registry compatible with mgo withSetRespectNilValues set to true. -var RespectNilValuesRegistry = NewRespectNilValuesRegistryBuilder().Build() - -// NewRegistryBuilder creates a new bson.RegistryBuilder configured with the default encoders and -// decoders from the bson.DefaultValueEncoders and bson.DefaultValueDecoders types and the -// PrimitiveCodecs type in this package. -func NewRegistryBuilder() *bson.RegistryBuilder { - rb := bson.NewRegistryBuilder() - bson.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) - bson.DefaultValueDecoders{}.RegisterDefaultDecoders(rb) - bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) - - structcodec, _ := bson.NewStructCodec(bson.DefaultStructTagParser, - bsonoptions.StructCodec(). - SetDecodeZeroStruct(true). - SetEncodeOmitDefaultStruct(true). - SetOverwriteDuplicatedInlinedFields(false). 
- SetAllowUnexportedFields(true)) - emptyInterCodec := bson.NewEmptyInterfaceCodec( - bsonoptions.EmptyInterfaceCodec(). - SetDecodeBinaryAsSlice(true)) - mapCodec := bson.NewMapCodec( - bsonoptions.MapCodec(). - SetDecodeZerosMap(true). - SetEncodeNilAsEmpty(true). - SetEncodeKeysWithStringer(true)) - uintcodec := bson.NewUIntCodec(bsonoptions.UIntCodec().SetEncodeToMinSize(true)) - - rb.RegisterTypeDecoder(tEmpty, emptyInterCodec). - RegisterDefaultDecoder(reflect.String, bson.NewStringCodec(bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false))). - RegisterDefaultDecoder(reflect.Struct, structcodec). - RegisterDefaultDecoder(reflect.Map, mapCodec). - RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(true))). - RegisterDefaultEncoder(reflect.Struct, structcodec). - RegisterDefaultEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(true))). - RegisterDefaultEncoder(reflect.Map, mapCodec). - RegisterDefaultEncoder(reflect.Uint, uintcodec). - RegisterDefaultEncoder(reflect.Uint8, uintcodec). - RegisterDefaultEncoder(reflect.Uint16, uintcodec). - RegisterDefaultEncoder(reflect.Uint32, uintcodec). - RegisterDefaultEncoder(reflect.Uint64, uintcodec). - RegisterTypeMapEntry(bson.TypeInt32, tInt). - RegisterTypeMapEntry(bson.TypeDateTime, tTime). - RegisterTypeMapEntry(bson.TypeArray, tInterfaceSlice). - RegisterTypeMapEntry(bson.Type(0), tM). - RegisterTypeMapEntry(bson.TypeEmbeddedDocument, tM). - RegisterHookEncoder(tGetter, bson.ValueEncoderFunc(GetterEncodeValue)). - RegisterHookDecoder(tSetter, bson.ValueDecoderFunc(SetterDecodeValue)) - - return rb -} - -// NewRespectNilValuesRegistryBuilder creates a new bson.RegistryBuilder configured to behave like mgo/bson -// with RespectNilValues set to true. -func NewRespectNilValuesRegistryBuilder() *bson.RegistryBuilder { - rb := NewRegistryBuilder() - - structcodec, _ := bson.NewStructCodec(bson.DefaultStructTagParser, - bsonoptions.StructCodec(). - SetDecodeZeroStruct(true). - SetEncodeOmitDefaultStruct(true). - SetOverwriteDuplicatedInlinedFields(false). - SetAllowUnexportedFields(true)) - mapCodec := bson.NewMapCodec( - bsonoptions.MapCodec(). - SetDecodeZerosMap(true). - SetEncodeNilAsEmpty(false)) - - rb.RegisterDefaultDecoder(reflect.Struct, structcodec). - RegisterDefaultDecoder(reflect.Map, mapCodec). - RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(false))). - RegisterDefaultEncoder(reflect.Struct, structcodec). - RegisterDefaultEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(false))). - RegisterDefaultEncoder(reflect.Map, mapCodec) - - return rb -} +var RespectNilValuesRegistry = bson.NewRespectNilValuesMgoRegistry() diff --git a/bson/mgocompat/setter_getter.go b/bson/mgocompat/setter_getter.go deleted file mode 100644 index fc620fbba85..00000000000 --- a/bson/mgocompat/setter_getter.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. 
You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package mgocompat - -import ( - "errors" - "reflect" - - "go.mongodb.org/mongo-driver/bson" -) - -// Setter interface: a value implementing the bson.Setter interface will receive the BSON -// value via the SetBSON method during unmarshaling, and the object -// itself will not be changed as usual. -// -// If setting the value works, the method should return nil or alternatively -// mgocompat.ErrSetZero to set the respective field to its zero value (nil for -// pointer types). If SetBSON returns a non-nil error, the unmarshalling -// procedure will stop and error out with the provided value. -// -// This interface is generally useful in pointer receivers, since the method -// will want to change the receiver. A type field that implements the Setter -// interface doesn't have to be a pointer, though. -// -// For example: -// -// type MyString string -// -// func (s *MyString) SetBSON(raw bson.RawValue) error { -// return raw.Unmarshal(s) -// } -type Setter interface { - SetBSON(raw bson.RawValue) error -} - -// Getter interface: a value implementing the bson.Getter interface will have its GetBSON -// method called when the given value has to be marshalled, and the result -// of this method will be marshaled in place of the actual object. -// -// If GetBSON returns return a non-nil error, the marshalling procedure -// will stop and error out with the provided value. -type Getter interface { - GetBSON() (interface{}, error) -} - -// SetterDecodeValue is the ValueDecoderFunc for Setter types. -func SetterDecodeValue(_ bson.DecodeContext, vr bson.ValueReader, val reflect.Value) error { - if !val.IsValid() || (!val.Type().Implements(tSetter) && !reflect.PtrTo(val.Type()).Implements(tSetter)) { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} - } - - if val.Kind() == reflect.Ptr && val.IsNil() { - if !val.CanSet() { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} - } - val.Set(reflect.New(val.Type().Elem())) - } - - if !val.Type().Implements(tSetter) { - if !val.CanAddr() { - return bson.ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tSetter}, Received: val} - } - val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. - } - - t, src, err := bson.CopyValueToBytes(vr) - if err != nil { - return err - } - - m, ok := val.Interface().(Setter) - if !ok { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} - } - if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil { - if !errors.Is(err, ErrSetZero) { - return err - } - val.Set(reflect.Zero(val.Type())) - } - return nil -} - -// GetterEncodeValue is the ValueEncoderFunc for Getter types. 
-func GetterEncodeValue(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.Value) error { - // Either val or a pointer to val must implement Getter - switch { - case !val.IsValid(): - return bson.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} - case val.Type().Implements(tGetter): - // If Getter is implemented on a concrete type, make sure that val isn't a nil pointer - if isImplementationNil(val, tGetter) { - return vw.WriteNull() - } - case reflect.PtrTo(val.Type()).Implements(tGetter) && val.CanAddr(): - val = val.Addr() - default: - return bson.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} - } - - m, ok := val.Interface().(Getter) - if !ok { - return vw.WriteNull() - } - x, err := m.GetBSON() - if err != nil { - return err - } - if x == nil { - return vw.WriteNull() - } - vv := reflect.ValueOf(x) - encoder, err := ec.Registry.LookupEncoder(vv.Type()) - if err != nil { - return err - } - return encoder.EncodeValue(ec, vw, vv) -} - -// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type -func isImplementationNil(val reflect.Value, inter reflect.Type) bool { - vt := val.Type() - for vt.Kind() == reflect.Ptr { - vt = vt.Elem() - } - return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil() -} diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go new file mode 100644 index 00000000000..15c67c43dc6 --- /dev/null +++ b/bson/mgoregistry.go @@ -0,0 +1,184 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "errors" + "reflect" +) + +var ( + // ErrMgoSetZero may be returned from a SetBSON method to have the value set to its respective zero value. + ErrMgoSetZero = errors.New("set to zero") + + tInt = reflect.TypeOf(int(0)) + tM = reflect.TypeOf(M{}) + tInterfaceSlice = reflect.TypeOf([]interface{}{}) + tGetter = reflect.TypeOf((*getter)(nil)).Elem() + tSetter = reflect.TypeOf((*setter)(nil)).Elem() +) + +// NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. 
+func NewMgoRegistry() *Registry { + mapCodec := &mapCodec{ + decodeZerosMap: true, + encodeNilAsEmpty: true, + encodeKeysWithStringer: true, + } + structCodec := &structCodec{ + inlineMapEncoder: mapCodec, + decodeZeroStruct: true, + encodeOmitDefaultStruct: true, + allowUnexportedFields: true, + } + uintCodec := &uintCodec{encodeToMinSize: true} + + reg := NewRegistry() + reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}) + reg.RegisterKindDecoder(reflect.Struct, structCodec) + reg.RegisterKindDecoder(reflect.Map, mapCodec) + reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) + reg.RegisterKindEncoder(reflect.Struct, structCodec) + reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true}) + reg.RegisterKindEncoder(reflect.Map, mapCodec) + reg.RegisterKindEncoder(reflect.Uint, uintCodec) + reg.RegisterKindEncoder(reflect.Uint8, uintCodec) + reg.RegisterKindEncoder(reflect.Uint16, uintCodec) + reg.RegisterKindEncoder(reflect.Uint32, uintCodec) + reg.RegisterKindEncoder(reflect.Uint64, uintCodec) + reg.RegisterTypeMapEntry(TypeInt32, tInt) + reg.RegisterTypeMapEntry(TypeDateTime, tTime) + reg.RegisterTypeMapEntry(TypeArray, tInterfaceSlice) + reg.RegisterTypeMapEntry(Type(0), tM) + reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tM) + reg.RegisterInterfaceEncoder(tGetter, ValueEncoderFunc(getterEncodeValue)) + reg.RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(setterDecodeValue)) + return reg +} + +// NewRespectNilValuesMgoRegistry creates a new bson.Registry configured to behave like mgo/bson +// with RespectNilValues set to true. +func NewRespectNilValuesMgoRegistry() *Registry { + mapCodec := &mapCodec{ + decodeZerosMap: true, + } + + reg := NewMgoRegistry() + reg.RegisterKindDecoder(reflect.Map, mapCodec) + reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) + reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) + reg.RegisterKindEncoder(reflect.Map, mapCodec) + return reg +} + +// setter interface: a value implementing the bson.Setter interface will receive the BSON +// value via the SetBSON method during unmarshaling, and the object +// itself will not be changed as usual. +// +// If setting the value works, the method should return nil or alternatively +// ErrMgoSetZero to set the respective field to its zero value (nil for +// pointer types). If SetBSON returns a non-nil error, the unmarshalling +// procedure will stop and error out with the provided value. +// +// This interface is generally useful in pointer receivers, since the method +// will want to change the receiver. A type field that implements the Setter +// interface doesn't have to be a pointer, though. +// +// For example: +// +// type MyString string +// +// func (s *MyString) SetBSON(raw bson.RawValue) error { +// return raw.Unmarshal(s) +// } +type setter interface { + SetBSON(raw RawValue) error +} + +// getter interface: a value implementing the bson.Getter interface will have its GetBSON +// method called when the given value has to be marshalled, and the result +// of this method will be marshaled in place of the actual object. +// +// If GetBSON returns return a non-nil error, the marshalling procedure +// will stop and error out with the provided value. +type getter interface { + GetBSON() (interface{}, error) +} + +// setterDecodeValue is the ValueDecoderFunc for Setter types. 
+func setterDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { + if !val.IsValid() || (!val.Type().Implements(tSetter) && !reflect.PtrTo(val.Type()).Implements(tSetter)) { + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + } + + if val.Kind() == reflect.Ptr && val.IsNil() { + if !val.CanSet() { + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + } + val.Set(reflect.New(val.Type().Elem())) + } + + if !val.Type().Implements(tSetter) { + if !val.CanAddr() { + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + } + val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. + } + + t, src, err := CopyValueToBytes(vr) + if err != nil { + return err + } + + m, ok := val.Interface().(setter) + if !ok { + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + } + if err := m.SetBSON(RawValue{Type: t, Value: src}); err != nil { + if !errors.Is(err, ErrMgoSetZero) { + return err + } + val.Set(reflect.Zero(val.Type())) + } + return nil +} + +// getterEncodeValue is the ValueEncoderFunc for Getter types. +func getterEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + // Either val or a pointer to val must implement Getter + switch { + case !val.IsValid(): + return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} + case val.Type().Implements(tGetter): + // If Getter is implemented on a concrete type, make sure that val isn't a nil pointer + if isImplementationNil(val, tGetter) { + return vw.WriteNull() + } + case reflect.PtrTo(val.Type()).Implements(tGetter) && val.CanAddr(): + val = val.Addr() + default: + return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} + } + + m, ok := val.Interface().(getter) + if !ok { + return vw.WriteNull() + } + x, err := m.GetBSON() + if err != nil { + return err + } + if x == nil { + return vw.WriteNull() + } + vv := reflect.ValueOf(x) + encoder, err := ec.Registry.LookupEncoder(vv.Type()) + if err != nil { + return err + } + return encoder.EncodeValue(ec, vw, vv) +} diff --git a/bson/objectid.go b/bson/objectid.go index 6d03c1310ab..ccb0b783382 100644 --- a/bson/objectid.go +++ b/bson/objectid.go @@ -35,7 +35,7 @@ var objectIDCounter = readRandomUint32() var processUnique = processUniqueBytes() var _ encoding.TextMarshaler = ObjectID{} -var _ encoding.TextUnmarshaler = (*ObjectID)(nil) +var _ encoding.TextUnmarshaler = &ObjectID{} // NewObjectID generates a new ObjectID. func NewObjectID() ObjectID { diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index 5946b9cc9fb..2839efed839 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -10,29 +10,20 @@ import ( "reflect" ) -var _ ValueEncoder = &PointerCodec{} -var _ ValueDecoder = &PointerCodec{} +var ( + _ ValueEncoder = &pointerCodec{} + _ ValueDecoder = &pointerCodec{} +) -// PointerCodec is the Codec used for pointers. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// PointerCodec registered. -type PointerCodec struct { +// pointerCodec is the Codec used for pointers. +type pointerCodec struct { ecache typeEncoderCache dcache typeDecoderCache } -// NewPointerCodec returns a PointerCodec that has been initialized. 
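The getter/setter hooks relocated into this file keep the mgo semantics: GetBSON supplies a replacement value at encode time, and SetBSON receives the raw BSON value at decode time and may return ErrMgoSetZero to leave the field at its zero value. A small sketch of a type using both; the flag type and its "on"/"off" encoding are hypothetical.

    package example

    import "go.mongodb.org/mongo-driver/bson"

    // flag stores a bool but is marshaled as the strings "on"/"off" via GetBSON.
    type flag bool

    // GetBSON returns the value marshaled in place of the flag itself.
    func (f flag) GetBSON() (interface{}, error) {
        if f {
            return "on", nil
        }
        return "off", nil
    }

    // SetBSON receives the raw value during unmarshaling. Returning
    // bson.ErrMgoSetZero asks the decoder to zero the field instead.
    func (f *flag) SetBSON(raw bson.RawValue) error {
        if raw.Type == bson.TypeNull {
            return bson.ErrMgoSetZero
        }
        s, ok := raw.StringValueOK()
        if !ok {
            return bson.ErrMgoSetZero
        }
        *f = s == "on"
        return nil
    }

These hooks only take effect with a registry built by NewMgoRegistry or NewRespectNilValuesMgoRegistry, which register getterEncodeValue and setterDecodeValue for the getter/setter interface types.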
-// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// PointerCodec registered. -func NewPointerCodec() *PointerCodec { - return &PointerCodec{} -} - // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil // or looking up an encoder for the type of value the pointer points to. -func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (pc *pointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Ptr { if !val.IsValid() { return vw.WriteNull() @@ -62,7 +53,7 @@ func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflec // DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and // using that to decode. If the BSON value is Null, this method will set the pointer to nil. -func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (pc *pointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Ptr { return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } diff --git a/bson/primitive.go b/bson/primitive.go index 281d2335534..aac77fc3abd 100644 --- a/bson/primitive.go +++ b/bson/primitive.go @@ -205,16 +205,12 @@ type MaxKey struct{} // bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} type D []E -// Map creates a map from the elements of the D. -// -// Deprecated: Converting directly from a D to an M will not be supported in Go Driver 2.0. Instead, -// users should marshal the D to BSON using bson.Marshal and unmarshal it to M using bson.Unmarshal. -func (d D) Map() M { - m := make(M, len(d)) - for _, e := range d { - m[e.Key] = e.Value +func (d D) String() string { + b, err := MarshalExtJSON(d, true, false) + if err != nil { + return "" } - return m + return string(b) } // MarshalJSON encodes D into JSON. @@ -281,6 +277,14 @@ type E struct { // bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} type M map[string]interface{} +func (m M) String() string { + b, err := MarshalExtJSON(m, true, false) + if err != nil { + return "" + } + return string(b) +} + // An A is an ordered representation of a BSON array. // // Example usage: diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index 262645ce4c9..334549465e3 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -7,7 +7,6 @@ package bson import ( - "errors" "fmt" "reflect" ) @@ -15,38 +14,20 @@ import ( var tRawValue = reflect.TypeOf(RawValue{}) var tRaw = reflect.TypeOf(Raw(nil)) -// PrimitiveCodecs is a namespace for all of the default Codecs for the primitive types -// defined in this package. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -type PrimitiveCodecs struct{} - -// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs +// registerPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs // with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. 
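The new D.String and M.String methods above make bson.D and bson.M print as canonical Extended JSON, which is convenient in logs and test output. A small usage sketch:

    package example

    import (
        "fmt"

        "go.mongodb.org/mongo-driver/bson"
    )

    func printDocs() {
        d := bson.D{{"name", "go-driver"}, {"count", int32(2)}}
        // fmt uses the new Stringer implementation, so the document prints as
        // canonical Extended JSON; the int32 value appears wrapped as
        // {"$numberInt": "2"} rather than a bare 2.
        fmt.Println(d)

        m := bson.M{"ok": true}
        fmt.Println(m) // also canonical Extended JSON
    }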
-func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *RegistryBuilder) { - if rb == nil { - panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil")) - } - - rb. - RegisterTypeEncoder(tRawValue, ValueEncoderFunc(pc.RawValueEncodeValue)). - RegisterTypeEncoder(tRaw, ValueEncoderFunc(pc.RawEncodeValue)). - RegisterTypeDecoder(tRawValue, ValueDecoderFunc(pc.RawValueDecodeValue)). - RegisterTypeDecoder(tRaw, ValueDecoderFunc(pc.RawDecodeValue)) +func registerPrimitiveCodecs(reg *Registry) { + reg.RegisterTypeEncoder(tRawValue, ValueEncoderFunc(rawValueEncodeValue)) + reg.RegisterTypeEncoder(tRaw, ValueEncoderFunc(rawEncodeValue)) + reg.RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue)) + reg.RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue)) } -// RawValueEncodeValue is the ValueEncoderFunc for RawValue. +// rawValueEncodeValue is the ValueEncoderFunc for RawValue. // // If the RawValue's Type is "invalid" and the RawValue's Value is not empty or // nil, then this method will return an error. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive -// encoders and decoders registered. -func (PrimitiveCodecs) RawValueEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func rawValueEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { return ValueEncoderError{ Name: "RawValueEncodeValue", @@ -64,11 +45,8 @@ func (PrimitiveCodecs) RawValueEncodeValue(_ EncodeContext, vw ValueWriter, val return copyValueFromBytes(vw, rawvalue.Type, rawvalue.Value) } -// RawValueDecodeValue is the ValueDecoderFunc for RawValue. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// rawValueDecodeValue is the ValueDecoderFunc for RawValue. +func rawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRawValue { return ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val} } @@ -82,11 +60,8 @@ func (PrimitiveCodecs) RawValueDecodeValue(_ DecodeContext, vr ValueReader, val return nil } -// RawEncodeValue is the ValueEncoderFunc for Reader. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// rawEncodeValue is the ValueEncoderFunc for Reader. +func rawEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRaw { return ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val} } @@ -96,11 +71,8 @@ func (PrimitiveCodecs) RawEncodeValue(_ EncodeContext, vw ValueWriter, val refle return copyDocumentFromBytes(vw, rdr) } -// RawDecodeValue is the ValueDecoderFunc for Reader. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// rawDecodeValue is the ValueDecoderFunc for Reader. 
+func rawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRaw { return ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val} } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index be3aeab9786..e9a74035512 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -32,8 +32,6 @@ func bytesFromDoc(doc interface{}) []byte { func TestPrimitiveValueEncoders(t *testing.T) { t.Parallel() - var pc PrimitiveCodecs - var wrong = func(string, string) string { return "wrong" } type subtest struct { @@ -52,7 +50,7 @@ func TestPrimitiveValueEncoders(t *testing.T) { }{ { "RawValueEncodeValue", - ValueEncoderFunc(pc.RawValueEncodeValue), + ValueEncoderFunc(rawValueEncodeValue), []subtest{ { "wrong type", @@ -100,7 +98,7 @@ func TestPrimitiveValueEncoders(t *testing.T) { }, { "RawEncodeValue", - ValueEncoderFunc(pc.RawEncodeValue), + ValueEncoderFunc(rawEncodeValue), []subtest{ { "wrong type", @@ -478,8 +476,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { } func TestPrimitiveValueDecoders(t *testing.T) { - var pc PrimitiveCodecs - var wrong = func(string, string) string { return "wrong" } const cansetreflectiontest = "cansetreflectiontest" @@ -500,7 +496,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { }{ { "RawValueDecodeValue", - ValueDecoderFunc(pc.RawValueDecodeValue), + ValueDecoderFunc(rawValueDecodeValue), []subtest{ { "wrong type", @@ -544,7 +540,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { }, { "RawDecodeValue", - ValueDecoderFunc(pc.RawDecodeValue), + ValueDecoderFunc(rawDecodeValue), []subtest{ { "wrong type", @@ -1066,8 +1062,8 @@ type testValueMarshaler struct { err error } -func (tvm testValueMarshaler) MarshalBSONValue() (Type, []byte, error) { - return tvm.t, tvm.buf, tvm.err +func (tvm testValueMarshaler) MarshalBSONValue() (byte, []byte, error) { + return byte(tvm.t), tvm.buf, tvm.err } type testValueUnmarshaler struct { @@ -1076,8 +1072,8 @@ type testValueUnmarshaler struct { err error } -func (tvu *testValueUnmarshaler) UnmarshalBSONValue(t Type, val []byte) error { - tvu.t, tvu.val = t, val +func (tvu *testValueUnmarshaler) UnmarshalBSONValue(t byte, val []byte) error { + tvu.t, tvu.val = Type(t), val return tvu.err } func (tvu testValueUnmarshaler) Equal(tvu2 testValueUnmarshaler) bool { diff --git a/bson/proxy.go b/bson/proxy.go deleted file mode 100644 index 1ccca6c2d15..00000000000 --- a/bson/proxy.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bson - -// Proxy is an interface implemented by types that cannot themselves be directly encoded. Types -// that implement this interface with have ProxyBSON called during the encoding process and that -// value will be encoded in place for the implementer. 
-type Proxy interface { - ProxyBSON() (interface{}, error) -} diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index f02fe8f3268..67444faa612 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -25,7 +25,7 @@ func TestRawValue(t *testing.T) { t.Run("Uses registry attached to value", func(t *testing.T) { t.Parallel() - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() val := RawValue{Type: TypeString, Value: bsoncore.AppendString(nil, "foobar"), r: reg} var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -63,7 +63,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -114,7 +114,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + dc := DecodeContext{Registry: newTestRegistry()} var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} diff --git a/bson/reader.go b/bson/reader.go index 6d7cac48b0b..587982b338b 100644 --- a/bson/reader.go +++ b/bson/reader.go @@ -48,13 +48,11 @@ type ValueReader interface { ReadUndefined() error } -// BytesReader is a generic interface used to read BSON bytes from a -// ValueReader. This imterface is meant to be a superset of ValueReader, so that +// bytesReader is a generic interface used to read BSON bytes from a +// ValueReader. This interface is meant to be a superset of ValueReader, so that // types that implement ValueReader may also implement this interface. // // The bytes of the value will be appended to dst. -// -// Deprecated: BytesReader will not be supported in Go Driver 2.0. -type BytesReader interface { - ReadValueBytes(dst []byte) (Type, []byte, error) +type bytesReader interface { + readValueBytes(dst []byte) (Type, []byte, error) } diff --git a/bson/registry.go b/bson/registry.go index 74b99e93ab4..bcd3133445f 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -63,187 +63,6 @@ func (entme ErrNoTypeMapEntry) Error() string { return "no type map entry found for " + entme.Type.String() } -// ErrNotInterface is returned when the provided type is not an interface. -// -// Deprecated: ErrNotInterface will not be supported in Go Driver 2.0. -var ErrNotInterface = errors.New("The provided type is not an interface") - -// A RegistryBuilder is used to build a Registry. This type is not goroutine -// safe. -// -// Deprecated: Use Registry instead. -type RegistryBuilder struct { - registry *Registry -} - -// NewRegistryBuilder creates a new empty RegistryBuilder. -// -// Deprecated: Use NewRegistry instead. -func NewRegistryBuilder() *RegistryBuilder { - rb := &RegistryBuilder{ - registry: &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - }, - } - DefaultValueEncoders{}.RegisterDefaultEncoders(rb) - DefaultValueDecoders{}.RegisterDefaultDecoders(rb) - PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) - return rb -} - -// RegisterCodec will register the provided ValueCodec for the provided type. -// -// Deprecated: Use Registry.RegisterTypeEncoder and Registry.RegisterTypeDecoder instead. 
-func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder { - rb.RegisterTypeEncoder(t, codec) - rb.RegisterTypeDecoder(t, codec) - return rb -} - -// RegisterTypeEncoder will register the provided ValueEncoder for the provided type. -// -// The type will be used directly, so an encoder can be registered for a type and a different encoder can be registered -// for a pointer to that type. -// -// If the given type is an interface, the encoder will be called when marshaling a type that is that interface. It -// will not be called when marshaling a non-interface type that implements the interface. -// -// Deprecated: Use Registry.RegisterTypeEncoder instead. -func (rb *RegistryBuilder) RegisterTypeEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterTypeEncoder(t, enc) - return rb -} - -// RegisterHookEncoder will register an encoder for the provided interface type t. This encoder will be called when -// marshaling a type if the type implements t or a pointer to the type implements t. If the provided type is not -// an interface (i.e. t.Kind() != reflect.Interface), this method will panic. -// -// Deprecated: Use Registry.RegisterInterfaceEncoder instead. -func (rb *RegistryBuilder) RegisterHookEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterInterfaceEncoder(t, enc) - return rb -} - -// RegisterTypeDecoder will register the provided ValueDecoder for the provided type. -// -// The type will be used directly, so a decoder can be registered for a type and a different decoder can be registered -// for a pointer to that type. -// -// If the given type is an interface, the decoder will be called when unmarshaling into a type that is that interface. -// It will not be called when unmarshaling into a non-interface type that implements the interface. -// -// Deprecated: Use Registry.RegisterTypeDecoder instead. -func (rb *RegistryBuilder) RegisterTypeDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterTypeDecoder(t, dec) - return rb -} - -// RegisterHookDecoder will register an decoder for the provided interface type t. This decoder will be called when -// unmarshaling into a type if the type implements t or a pointer to the type implements t. If the provided type is not -// an interface (i.e. t.Kind() != reflect.Interface), this method will panic. -// -// Deprecated: Use Registry.RegisterInterfaceDecoder instead. -func (rb *RegistryBuilder) RegisterHookDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterInterfaceDecoder(t, dec) - return rb -} - -// RegisterEncoder registers the provided type and encoder pair. -// -// Deprecated: Use Registry.RegisterTypeEncoder or Registry.RegisterInterfaceEncoder instead. -func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - if t == tEmpty { - rb.registry.RegisterTypeEncoder(t, enc) - return rb - } - switch t.Kind() { - case reflect.Interface: - rb.registry.RegisterInterfaceEncoder(t, enc) - default: - rb.registry.RegisterTypeEncoder(t, enc) - } - return rb -} - -// RegisterDecoder registers the provided type and decoder pair. -// -// Deprecated: Use Registry.RegisterTypeDecoder or Registry.RegisterInterfaceDecoder instead. 
-func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - if t == nil { - rb.registry.RegisterTypeDecoder(t, dec) - return rb - } - if t == tEmpty { - rb.registry.RegisterTypeDecoder(t, dec) - return rb - } - switch t.Kind() { - case reflect.Interface: - rb.registry.RegisterInterfaceDecoder(t, dec) - default: - rb.registry.RegisterTypeDecoder(t, dec) - } - return rb -} - -// RegisterDefaultEncoder will register the provided ValueEncoder to the provided -// kind. -// -// Deprecated: Use Registry.RegisterKindEncoder instead. -func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterKindEncoder(kind, enc) - return rb -} - -// RegisterDefaultDecoder will register the provided ValueDecoder to the -// provided kind. -// -// Deprecated: Use Registry.RegisterKindDecoder instead. -func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterKindDecoder(kind, dec) - return rb -} - -// RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this -// mapping is decoding situations where an empty interface is used and a default type needs to be -// created and decoded into. -// -// By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON -// documents, a type map entry for TypeEmbeddedDocument should be registered. For example, to force BSON documents -// to decode to bson.Raw, use the following code: -// -// rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) -// -// Deprecated: Use Registry.RegisterTypeMapEntry instead. -func (rb *RegistryBuilder) RegisterTypeMapEntry(bt Type, rt reflect.Type) *RegistryBuilder { - rb.registry.RegisterTypeMapEntry(bt, rt) - return rb -} - -// Build creates a Registry from the current state of this RegistryBuilder. -// -// Deprecated: Use NewRegistry instead. -func (rb *RegistryBuilder) Build() *Registry { - r := &Registry{ - interfaceEncoders: append([]interfaceValueEncoder(nil), rb.registry.interfaceEncoders...), - interfaceDecoders: append([]interfaceValueDecoder(nil), rb.registry.interfaceDecoders...), - typeEncoders: rb.registry.typeEncoders.Clone(), - typeDecoders: rb.registry.typeDecoders.Clone(), - kindEncoders: rb.registry.kindEncoders.Clone(), - kindDecoders: rb.registry.kindDecoders.Clone(), - } - rb.registry.typeMap.Range(func(k, v interface{}) bool { - if k != nil && v != nil { - r.typeMap.Store(k, v) - } - return true - }) - return r -} - // A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the Registry type // documentation for examples of registering various custom encoders and decoders. A Registry can // have four main types of codecs: @@ -289,7 +108,16 @@ type Registry struct { // NewRegistry creates a new empty Registry. func NewRegistry() *Registry { - return NewRegistryBuilder().Build() + reg := &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + } + registerDefaultEncoders(reg) + registerDefaultDecoders(reg) + registerPrimitiveCodecs(reg) + return reg } // RegisterTypeEncoder registers the provided ValueEncoder for the provided type. 
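With RegistryBuilder removed, NewRegistry now returns a fully populated Registry (default encoders, default decoders, primitive codecs), and customizations are applied by calling the Register* methods on it directly. A migration sketch that reuses the type-map example from the removed builder docs:

    package example

    import (
        "reflect"

        "go.mongodb.org/mongo-driver/bson"
    )

    func customRegistry() *bson.Registry {
        reg := bson.NewRegistry()

        // Decode embedded documents into bson.Raw instead of the default bson.D
        // when unmarshaling into interface{} values.
        reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{}))

        return reg
    }

Kind-level overrides follow the same pattern: the old RegisterDefaultEncoder/RegisterDefaultDecoder chains become direct calls to Registry.RegisterKindEncoder and Registry.RegisterKindDecoder, as the test rewrites below show.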
diff --git a/bson/registry_test.go b/bson/registry_test.go index 2bc87364d35..184501933a1 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -15,15 +15,13 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" ) -// newTestRegistryBuilder creates a new empty Registry. -func newTestRegistryBuilder() *RegistryBuilder { - return &RegistryBuilder{ - registry: &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - }, +// newTestRegistry creates a new Registry. +func newTestRegistry() *Registry { + return &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), } } @@ -45,12 +43,11 @@ func TestRegistryBuilder(t *testing.T) { {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, } - rb := newTestRegistryBuilder() + reg := newTestRegistry() for _, ip := range ips { - rb.RegisterHookEncoder(ip.i, ip.ve) + reg.RegisterInterfaceEncoder(ip.i, ip.ve) } - reg := rb.Build() got := reg.interfaceEncoders if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { t.Errorf("the registered interfaces are not correct: got %#v, want %#v", got, want) @@ -58,11 +55,11 @@ func TestRegistryBuilder(t *testing.T) { }) t.Run("type", func(t *testing.T) { ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - rb := newTestRegistryBuilder(). - RegisterTypeEncoder(reflect.TypeOf(ft1), fc1). - RegisterTypeEncoder(reflect.TypeOf(ft2), fc2). - RegisterTypeEncoder(reflect.TypeOf(ft1), fc3). - RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) + reg := newTestRegistry() + reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) + reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) + reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) + reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) want := []struct { t reflect.Type c ValueEncoder @@ -72,7 +69,6 @@ func TestRegistryBuilder(t *testing.T) { {reflect.TypeOf(ft4), fc4}, } - reg := rb.Build() got := reg.typeEncoders for _, s := range want { wantT, wantC := s.t, s.c @@ -87,11 +83,11 @@ func TestRegistryBuilder(t *testing.T) { }) t.Run("kind", func(t *testing.T) { k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - rb := newTestRegistryBuilder(). - RegisterDefaultEncoder(k1, fc1). - RegisterDefaultEncoder(k2, fc2). - RegisterDefaultEncoder(k1, fc3). 
- RegisterDefaultEncoder(k4, fc4) + reg := newTestRegistry() + reg.RegisterKindEncoder(k1, fc1) + reg.RegisterKindEncoder(k2, fc2) + reg.RegisterKindEncoder(k1, fc3) + reg.RegisterKindEncoder(k4, fc4) want := []struct { k reflect.Kind c ValueEncoder @@ -101,7 +97,6 @@ func TestRegistryBuilder(t *testing.T) { {k4, fc4}, } - reg := rb.Build() got := reg.kindEncoders for _, s := range want { wantK, wantC := s.k, s.c @@ -118,16 +113,14 @@ func TestRegistryBuilder(t *testing.T) { t.Run("MapCodec", func(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + reg := newTestRegistry() - rb.RegisterDefaultEncoder(reflect.Map, codec) - reg := rb.Build() + reg.RegisterKindEncoder(reflect.Map, codec) if reg.kindEncoders.get(reflect.Map) != codec { t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) } - rb.RegisterDefaultEncoder(reflect.Map, codec2) - reg = rb.Build() + reg.RegisterKindEncoder(reflect.Map, codec2) if reg.kindEncoders.get(reflect.Map) != codec2 { t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) } @@ -135,16 +128,14 @@ func TestRegistryBuilder(t *testing.T) { t.Run("StructCodec", func(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + reg := newTestRegistry() - rb.RegisterDefaultEncoder(reflect.Struct, codec) - reg := rb.Build() + reg.RegisterKindEncoder(reflect.Struct, codec) if reg.kindEncoders.get(reflect.Struct) != codec { t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) } - rb.RegisterDefaultEncoder(reflect.Struct, codec2) - reg = rb.Build() + reg.RegisterKindEncoder(reflect.Struct, codec2) if reg.kindEncoders.get(reflect.Struct) != codec2 { t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) } @@ -152,16 +143,14 @@ func TestRegistryBuilder(t *testing.T) { t.Run("SliceCodec", func(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + reg := newTestRegistry() - rb.RegisterDefaultEncoder(reflect.Slice, codec) - reg := rb.Build() + reg.RegisterKindEncoder(reflect.Slice, codec) if reg.kindEncoders.get(reflect.Slice) != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) } - rb.RegisterDefaultEncoder(reflect.Slice, codec2) - reg = rb.Build() + reg.RegisterKindEncoder(reflect.Slice, codec2) if reg.kindEncoders.get(reflect.Slice) != codec2 { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) } @@ -169,16 +158,14 @@ func TestRegistryBuilder(t *testing.T) { t.Run("ArrayCodec", func(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + reg := newTestRegistry() - rb.RegisterDefaultEncoder(reflect.Array, codec) - reg := rb.Build() + reg.RegisterKindEncoder(reflect.Array, codec) if reg.kindEncoders.get(reflect.Array) != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) } - rb.RegisterDefaultEncoder(reflect.Array, codec2) - reg = rb.Build() + reg.RegisterKindEncoder(reflect.Array, codec2) if reg.kindEncoders.get(reflect.Array) != codec2 { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) } @@ -208,31 +195,30 @@ func TestRegistryBuilder(t *testing.T) 
{ ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil)) fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2} fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) - pc = NewPointerCodec() + pc = &pointerCodec{} ) - reg := newTestRegistryBuilder(). - RegisterTypeEncoder(ft1, fc1). - RegisterTypeEncoder(ft2, fc2). - RegisterTypeEncoder(ti1, fc1). - RegisterDefaultEncoder(reflect.Struct, fsc). - RegisterDefaultEncoder(reflect.Slice, fslcc). - RegisterDefaultEncoder(reflect.Array, fslcc). - RegisterDefaultEncoder(reflect.Map, fmc). - RegisterDefaultEncoder(reflect.Ptr, pc). - RegisterTypeDecoder(ft1, fc1). - RegisterTypeDecoder(ft2, fc2). - RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder - RegisterDefaultDecoder(reflect.Struct, fsc). - RegisterDefaultDecoder(reflect.Slice, fslcc). - RegisterDefaultDecoder(reflect.Array, fslcc). - RegisterDefaultDecoder(reflect.Map, fmc). - RegisterDefaultDecoder(reflect.Ptr, pc). - RegisterHookEncoder(ti2, fc2). - RegisterHookDecoder(ti2, fc2). - RegisterHookEncoder(ti3, fc3). - RegisterHookDecoder(ti3, fc3). - Build() + reg := newTestRegistry() + reg.RegisterTypeEncoder(ft1, fc1) + reg.RegisterTypeEncoder(ft2, fc2) + reg.RegisterTypeEncoder(ti1, fc1) + reg.RegisterKindEncoder(reflect.Struct, fsc) + reg.RegisterKindEncoder(reflect.Slice, fslcc) + reg.RegisterKindEncoder(reflect.Array, fslcc) + reg.RegisterKindEncoder(reflect.Map, fmc) + reg.RegisterKindEncoder(reflect.Ptr, pc) + reg.RegisterTypeDecoder(ft1, fc1) + reg.RegisterTypeDecoder(ft2, fc2) + reg.RegisterTypeDecoder(ti1, fc1) // values whose exact type is testInterface1 will use fc1 encoder + reg.RegisterKindDecoder(reflect.Struct, fsc) + reg.RegisterKindDecoder(reflect.Slice, fslcc) + reg.RegisterKindDecoder(reflect.Array, fslcc) + reg.RegisterKindDecoder(reflect.Map, fmc) + reg.RegisterKindDecoder(reflect.Ptr, pc) + reg.RegisterInterfaceEncoder(ti2, fc2) + reg.RegisterInterfaceDecoder(ti2, fc2) + reg.RegisterInterfaceEncoder(ti3, fc3) + reg.RegisterInterfaceDecoder(ti3, fc3) testCases := []struct { name string @@ -348,7 +334,7 @@ func TestRegistryBuilder(t *testing.T) { } allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) - comparepc := func(pc1, pc2 *PointerCodec) bool { return true } + comparepc := func(pc1, pc2 *pointerCodec) bool { return true } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Run("Encoder", func(t *testing.T) { @@ -409,10 +395,9 @@ func TestRegistryBuilder(t *testing.T) { }) }) t.Run("Type Map", func(t *testing.T) { - reg := newTestRegistryBuilder(). - RegisterTypeMapEntry(TypeString, reflect.TypeOf("")). - RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))). 
- Build() + reg := newTestRegistry() + reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) + reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) var got, want reflect.Type @@ -466,7 +451,7 @@ func TestRegistry(t *testing.T) { {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, } - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() for _, ip := range ips { reg.RegisterInterfaceEncoder(ip.i, ip.ve) } @@ -479,7 +464,7 @@ func TestRegistry(t *testing.T) { t.Parallel() ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) @@ -509,7 +494,7 @@ func TestRegistry(t *testing.T) { t.Parallel() k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(k1, fc1) reg.RegisterKindEncoder(k2, fc2) reg.RegisterKindEncoder(k1, fc3) @@ -543,7 +528,7 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(reflect.Map, codec) if reg.kindEncoders.get(reflect.Map) != codec { t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) @@ -558,7 +543,7 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(reflect.Struct, codec) if reg.kindEncoders.get(reflect.Struct) != codec { t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) @@ -573,7 +558,7 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(reflect.Slice, codec) if reg.kindEncoders.get(reflect.Slice) != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) @@ -588,7 +573,7 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(reflect.Array, codec) if reg.kindEncoders.get(reflect.Array) != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) @@ -625,10 +610,10 @@ func TestRegistry(t *testing.T) { ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil)) fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2} fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) - pc = NewPointerCodec() + pc = &pointerCodec{} ) - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterTypeEncoder(ft1, fc1) reg.RegisterTypeEncoder(ft2, fc2) reg.RegisterTypeEncoder(ti1, fc1) @@ -764,7 +749,7 @@ func TestRegistry(t *testing.T) { } allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) - comparepc := func(pc1, pc2 *PointerCodec) bool { return true } + comparepc := func(pc1, pc2 *pointerCodec) bool { return true } for _, tc := range testCases { tc := tc @@ -869,7 +854,7 @@ func TestRegistry(t *testing.T) { }) t.Run("Type Map", func(t *testing.T) { t.Parallel() - reg := 
newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) diff --git a/bson/slice_codec.go b/bson/slice_codec.go index 52449239b98..c8719dcc18d 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -10,45 +10,22 @@ import ( "errors" "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -var defaultSliceCodec = NewSliceCodec() - -// SliceCodec is the Codec used for slice values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// SliceCodec registered. -type SliceCodec struct { - // EncodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of +// sliceCodec is the Codec used for slice values. +type sliceCodec struct { + // encodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of // BSON null. - // - // Deprecated: Use bson.Encoder.NilSliceAsEmpty instead. - EncodeNilAsEmpty bool -} - -// NewSliceCodec returns a MapCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// SliceCodec registered. -func NewSliceCodec(opts ...*bsonoptions.SliceCodecOptions) *SliceCodec { - sliceOpt := bsonoptions.MergeSliceCodecOptions(opts...) - - codec := SliceCodec{} - if sliceOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *sliceOpt.EncodeNilAsEmpty - } - return &codec + encodeNilAsEmpty bool } // EncodeValue is the ValueEncoder for slice types. -func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Slice { return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } - if val.IsNil() && !sc.EncodeNilAsEmpty && !ec.nilSliceAsEmpty { + if val.IsNil() && !sc.encodeNilAsEmpty && !ec.nilSliceAsEmpty { return vw.WriteNull() } @@ -90,7 +67,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.Index(idx)) + currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -117,7 +94,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } // DecodeValue is the ValueDecoder for slice types. -func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Slice { return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } @@ -175,10 +152,9 @@ func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. 
var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) switch val.Type().Elem() { case tE: - dc.Ancestor = val.Type() - elemsFunc = defaultValueDecoders.decodeD + elemsFunc = decodeD default: - elemsFunc = defaultValueDecoders.decodeDefault + elemsFunc = decodeDefault } elems, err := elemsFunc(dc, vr, val) diff --git a/bson/string_codec.go b/bson/string_codec.go index 50fb9229fea..fcda72af902 100644 --- a/bson/string_codec.go +++ b/bson/string_codec.go @@ -9,42 +9,22 @@ package bson import ( "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// StringCodec is the Codec used for string values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StringCodec registered. -type StringCodec struct { +// stringCodec is the Codec used for string values. +type stringCodec struct { // DecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. // If false, a string made from the raw object ID bytes will be used. Defaults to true. - // - // Deprecated: Decoding object IDs as raw bytes will not be supported in Go Driver 2.0. - DecodeObjectIDAsHex bool + decodeObjectIDAsHex bool } -var ( - defaultStringCodec = NewStringCodec() - - // Assert that defaultStringCodec satisfies the typeDecoder interface, which allows it to be - // used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. - _ typeDecoder = defaultStringCodec -) - -// NewStringCodec returns a StringCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StringCodec registered. -func NewStringCodec(opts ...*bsonoptions.StringCodecOptions) *StringCodec { - stringOpt := bsonoptions.MergeStringCodecOptions(opts...) - return &StringCodec{*stringOpt.DecodeObjectIDAsHex} -} +// Assert that stringCodec satisfies the typeDecoder interface, which allows it to be +// used by collection type decoders (e.g. map, slice, etc) to set individual values in a +// collection. +var _ typeDecoder = &stringCodec{} // EncodeValue is the ValueEncoder for string types. -func (sc *StringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *stringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{ Name: "StringEncodeValue", @@ -56,7 +36,7 @@ func (sc *StringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect. return vw.WriteString(val.String()) } -func (sc *StringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (sc *stringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.String { return emptyValue, ValueDecoderError{ Name: "StringDecodeValue", @@ -78,7 +58,7 @@ func (sc *StringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Typ if err != nil { return emptyValue, err } - if sc.DecodeObjectIDAsHex { + if sc.decodeObjectIDAsHex { str = oid.Hex() } else { // TODO(GODRIVER-2796): Return an error here instead of decoding to a garbled string. @@ -115,7 +95,7 @@ func (sc *StringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Typ } // DecodeValue is the ValueDecoder for string types. 
-func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *stringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.String { return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} } diff --git a/bson/string_codec_test.go b/bson/string_codec_test.go index 75ace60c5d5..d44c0426533 100644 --- a/bson/string_codec_test.go +++ b/bson/string_codec_test.go @@ -10,7 +10,6 @@ import ( "reflect" "testing" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" ) @@ -20,21 +19,17 @@ func TestStringCodec(t *testing.T) { byteArray := [12]byte(oid) reader := &valueReaderWriter{BSONType: TypeObjectID, Return: oid} testCases := []struct { - name string - opts *bsonoptions.StringCodecOptions - hex bool - result string + name string + stringCodec *stringCodec + result string }{ - {"default", bsonoptions.StringCodec(), true, oid.Hex()}, - {"true", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(true), true, oid.Hex()}, - {"false", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false), false, string(byteArray[:])}, + {"true", &stringCodec{decodeObjectIDAsHex: true}, oid.Hex()}, + {"false", &stringCodec{decodeObjectIDAsHex: false}, string(byteArray[:])}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - stringCodec := NewStringCodec(tc.opts) - actual := reflect.New(reflect.TypeOf("")).Elem() - err := stringCodec.DecodeValue(DecodeContext{}, reader, actual) + err := tc.stringCodec.DecodeValue(DecodeContext{}, reader, actual) assert.Nil(t, err, "StringCodec.DecodeValue error: %v", err) actualString := actual.Interface().(string) diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 14337c7a2ee..18cd40140f3 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -14,8 +14,6 @@ import ( "strings" "sync" "time" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) // DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type. @@ -49,86 +47,53 @@ func (de *DecodeError) Keys() []string { return reversedKeys } -// StructCodec is the Codec used for struct values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StructCodec registered. -type StructCodec struct { - cache sync.Map // map[reflect.Type]*structDescription - parser StructTagParser +// mapElementsEncoder handles encoding of the values of an inline map. +type mapElementsEncoder interface { + encodeMapElements(EncodeContext, DocumentWriter, reflect.Value, func(string) bool) error +} + +// structCodec is the Codec used for struct values. +type structCodec struct { + cache sync.Map // map[reflect.Type]*structDescription + inlineMapEncoder mapElementsEncoder // DecodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the // destination value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: Use bson.Decoder.ZeroStructs instead. - DecodeZeroStruct bool + decodeZeroStruct bool // DecodeDeepZeroInline causes DecodeValue to delete any existing values from Go structs in the // destination value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0. 
- DecodeDeepZeroInline bool + decodeDeepZeroInline bool // EncodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g. // MyStruct{}) as empty and omit it from the marshaled BSON when the "omitempty" struct tag // option is set. - // - // Deprecated: Use bson.Encoder.OmitZeroStruct instead. - EncodeOmitDefaultStruct bool + encodeOmitDefaultStruct bool // AllowUnexportedFields allows encoding and decoding values from un-exported struct fields. - // - // Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be - // supported in Go Driver 2.0. - AllowUnexportedFields bool + allowUnexportedFields bool // OverwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is // a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The // default value is true. - // - // Deprecated: Use bson.Encoder.ErrorOnInlineDuplicates instead. - OverwriteDuplicatedInlinedFields bool + overwriteDuplicatedInlinedFields bool } -var _ ValueEncoder = &StructCodec{} -var _ ValueDecoder = &StructCodec{} - -// NewStructCodec returns a StructCodec that uses p for struct tag parsing. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StructCodec registered. -func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) { - if p == nil { - return nil, errors.New("a StructTagParser must be provided to NewStructCodec") - } - - structOpt := bsonoptions.MergeStructCodecOptions(opts...) - - codec := &StructCodec{ - parser: p, - } +var ( + _ ValueEncoder = &structCodec{} + _ ValueDecoder = &structCodec{} +) - if structOpt.DecodeZeroStruct != nil { - codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct - } - if structOpt.DecodeDeepZeroInline != nil { - codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline - } - if structOpt.EncodeOmitDefaultStruct != nil { - codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct - } - if structOpt.OverwriteDuplicatedInlinedFields != nil { - codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields - } - if structOpt.AllowUnexportedFields != nil { - codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields +// newStructCodec returns a structCodec that uses elemEncoder to encode the values of inline map fields. +func newStructCodec(elemEncoder mapElementsEncoder) *structCodec { + return &structCodec{ + inlineMapEncoder: elemEncoder, + overwriteDuplicatedInlinedFields: true, } - - return codec, nil } // EncodeValue handles encoding generic struct types.
-func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } @@ -153,7 +118,7 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect } } - desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(ec, desc.encoder, rv) + desc.encoder, rv, err = lookupElementEncoder(ec, desc.encoder, rv) if err != nil && !errors.Is(err, errInvalidValue) { return err @@ -181,14 +146,12 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect encoder := desc.encoder var empty bool - if cz, ok := encoder.(CodecZeroer); ok { - empty = cz.IsTypeZero(rv.Interface()) - } else if rv.Kind() == reflect.Interface { + if rv.Kind() == reflect.Interface { // isEmpty will not treat an interface rv as an interface, so we need to check for the // nil interface separately. empty = rv.IsNil() } else { - empty = isEmpty(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct) + empty = isEmpty(rv, sc.encodeOmitDefaultStruct || ec.omitZeroStruct) } if desc.omitEmpty && empty { continue @@ -223,7 +186,10 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect return exists } - return defaultMapCodec.mapEncodeValue(ec, dw, rv, collisionFn) + err = sc.inlineMapEncoder.encodeMapElements(ec, dw, rv, collisionFn) + if err != nil { + return err + } } return dw.WriteDocumentEnd() @@ -245,7 +211,7 @@ func newDecodeError(key string, original error) error { // DecodeValue implements the Codec interface. // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. 
-func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Struct { return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } @@ -275,10 +241,10 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect return err } - if sc.DecodeZeroStruct || dc.zeroStructs { + if sc.decodeZeroStruct || dc.zeroStructs { val.Set(reflect.Zero(val.Type())) } - if sc.DecodeDeepZeroInline && sd.inline { + if sc.decodeDeepZeroInline && sd.inline { val.Set(deepZero(val.Type())) } @@ -330,7 +296,6 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect } elem := reflect.New(inlineMap.Type().Elem()).Elem() - dc.Ancestor = inlineMap.Type() err = decoder.DecodeValue(dc, vr, elem) if err != nil { return err @@ -474,7 +439,7 @@ func (bi byIndex) Less(i, j int) bool { return len(bi[i].inline) < len(bi[j].inline) } -func (sc *StructCodec) describeStruct( +func (sc *structCodec) describeStruct( r *Registry, t reflect.Type, useJSONStructTags bool, @@ -497,7 +462,7 @@ func (sc *StructCodec) describeStruct( return ds, nil } -func (sc *StructCodec) describeStructSlow( +func (sc *structCodec) describeStructSlow( r *Registry, t reflect.Type, useJSONStructTags bool, @@ -513,7 +478,7 @@ func (sc *StructCodec) describeStructSlow( var fields []fieldDescription for i := 0; i < numFields; i++ { sf := t.Field(i) - if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) { + if sf.PkgPath != "" && (!sc.allowUnexportedFields || !sf.Anonymous) { // field is private or unexported fields aren't allowed, ignore continue } @@ -535,13 +500,13 @@ func (sc *StructCodec) describeStructSlow( decoder: decoder, } - var stags StructTags + var stags *structTags // If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser // instead of the parser defined on the codec. 
if useJSONStructTags { - stags, err = JSONFallbackStructTagParser.ParseStructTags(sf) + stags, err = parseJSONStructTags(sf) } else { - stags, err = sc.parser.ParseStructTags(sf) + stags, err = parseStructTags(sf) } if err != nil { return nil, err @@ -624,7 +589,7 @@ func (sc *StructCodec) describeStructSlow( continue } dominant, ok := dominantField(fields[i : i+advance]) - if !ok || !sc.OverwriteDuplicatedInlinedFields || errorOnDuplicates { + if !ok || !sc.overwriteDuplicatedInlinedFields || errorOnDuplicates { return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name) } sd.fl = append(sd.fl, dominant) diff --git a/bson/struct_codec_test.go b/bson/struct_codec_test.go index d1f7e5373a7..156535b14e1 100644 --- a/bson/struct_codec_test.go +++ b/bson/struct_codec_test.go @@ -14,13 +14,13 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" ) -var _ Zeroer = zeroer{} +var _ Zeroer = testZeroer{} -type zeroer struct { +type testZeroer struct { val int } -func (z zeroer) IsZero() bool { +func (z testZeroer) IsZero() bool { return z.val != 0 } @@ -84,22 +84,22 @@ func TestIsZero(t *testing.T) { }, { description: "zero struct that implements Zeroer", - value: zeroer{}, + value: testZeroer{}, want: false, }, { description: "non-zero struct that implements Zeroer", - value: &zeroer{val: 1}, + value: &testZeroer{val: 1}, want: true, }, { description: "pointer to zero struct that implements Zeroer", - value: &zeroer{}, + value: &testZeroer{}, want: false, }, { description: "pointer to non-zero struct that implements Zeroer", - value: zeroer{val: 1}, + value: testZeroer{val: 1}, want: true, }, { diff --git a/bson/struct_tag_parser.go b/bson/struct_tag_parser.go index d116c140405..26773a39e9e 100644 --- a/bson/struct_tag_parser.go +++ b/bson/struct_tag_parser.go @@ -11,25 +11,7 @@ import ( "strings" ) -// StructTagParser returns the struct tags for a given struct field. -// -// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTagParser interface { - ParseStructTags(reflect.StructField) (StructTags, error) -} - -// StructTagParserFunc is an adapter that allows a generic function to be used -// as a StructTagParser. -// -// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTagParserFunc func(reflect.StructField) (StructTags, error) - -// ParseStructTags implements the StructTagParser interface. -func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructTags, error) { - return stpf(sf) -} - -// StructTags represents the struct tag fields that the StructCodec uses during +// structTags represents the struct tag fields that the StructCodec uses during // the encoding and decoding process. // // In the case of a struct, the lowercased field name is used as the key for each exported @@ -55,7 +37,7 @@ func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructT // for the name. // // Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTags struct { +type structTags struct { Name string OmitEmpty bool MinSize bool @@ -89,9 +71,7 @@ type StructTags struct { // A struct tag either consisting entirely of '-' or with a bson key with a // value consisting entirely of '-' will return a StructTags with Skip true and // the remaining fields will be their default values. -// -// Deprecated: DefaultStructTagParser will be removed in Go Driver 2.0. 
-var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { +func parseStructTags(sf reflect.StructField) (*structTags, error) { key := strings.ToLower(sf.Name) tag, ok := sf.Tag.Lookup("bson") if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { @@ -100,11 +80,27 @@ var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (S return parseTags(key, tag) } -func parseTags(key string, tag string) (StructTags, error) { - var st StructTags +// jsonStructTagParser has the same behavior as DefaultStructTagParser +// but will also fallback to parsing the json tag instead on a field where the +// bson tag isn't available. +func parseJSONStructTags(sf reflect.StructField) (*structTags, error) { + key := strings.ToLower(sf.Name) + tag, ok := sf.Tag.Lookup("bson") + if !ok { + tag, ok = sf.Tag.Lookup("json") + } + if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { + tag = string(sf.Tag) + } + + return parseTags(key, tag) +} + +func parseTags(key string, tag string) (*structTags, error) { + var st structTags if tag == "-" { st.Skip = true - return st, nil + return &st, nil } for idx, str := range strings.Split(tag, ",") { @@ -125,24 +121,5 @@ func parseTags(key string, tag string) (StructTags, error) { st.Name = key - return st, nil -} - -// JSONFallbackStructTagParser has the same behavior as DefaultStructTagParser -// but will also fallback to parsing the json tag instead on a field where the -// bson tag isn't available. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.UseJSONStructTags] and -// [go.mongodb.org/mongo-driver/bson.Decoder.UseJSONStructTags] instead. -var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { - key := strings.ToLower(sf.Name) - tag, ok := sf.Tag.Lookup("bson") - if !ok { - tag, ok = sf.Tag.Lookup("json") - } - if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { - tag = string(sf.Tag) - } - - return parseTags(key, tag) + return &st, nil } diff --git a/bson/struct_tag_parser_test.go b/bson/struct_tag_parser_test.go index b03815488ac..3592761e891 100644 --- a/bson/struct_tag_parser_test.go +++ b/bson/struct_tag_parser_test.go @@ -17,134 +17,134 @@ func TestStructTagParsers(t *testing.T) { testCases := []struct { name string sf reflect.StructField - want StructTags - parser StructTagParserFunc + want *structTags + parser func(reflect.StructField) (*structTags, error) }{ { "default no bson tag", reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, - StructTags{Name: "bar"}, - DefaultStructTagParser, + &structTags{Name: "bar"}, + parseStructTags, }, { "default empty", reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, - StructTags{Name: "foo"}, - DefaultStructTagParser, + &structTags{Name: "foo"}, + parseStructTags, }, { "default tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, - StructTags{Skip: true}, - DefaultStructTagParser, + &structTags{Skip: true}, + parseStructTags, }, { "default bson tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, - StructTags{Skip: true}, - DefaultStructTagParser, + &structTags{Skip: true}, + parseStructTags, }, { "default all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{Name: "bar", OmitEmpty: 
true, MinSize: true, Truncate: true, Inline: true}, + parseStructTags, }, { "default all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseStructTags, }, { "default bson tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseStructTags, }, { "default bson tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseStructTags, }, { "default ignore xml", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, - StructTags{Name: "foo"}, - DefaultStructTagParser, + &structTags{Name: "foo"}, + parseStructTags, }, { "JSONFallback no bson tag", reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, - StructTags{Name: "bar"}, - JSONFallbackStructTagParser, + &structTags{Name: "bar"}, + parseStructTags, }, { "JSONFallback empty", reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, - StructTags{Name: "foo"}, - JSONFallbackStructTagParser, + &structTags{Name: "foo"}, + parseJSONStructTags, }, { "JSONFallback tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, - StructTags{Skip: true}, - JSONFallbackStructTagParser, + &structTags{Skip: true}, + parseJSONStructTags, }, { "JSONFallback bson tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, - StructTags{Skip: true}, - JSONFallbackStructTagParser, + &structTags{Skip: true}, + parseJSONStructTags, }, { "JSONFallback all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseJSONStructTags, }, { "JSONFallback all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseJSONStructTags, }, { "JSONFallback bson tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseJSONStructTags, }, { "JSONFallback bson tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, 
Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseJSONStructTags, }, { "JSONFallback json tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseJSONStructTags, }, { "JSONFallback json tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, + parseJSONStructTags, }, { "JSONFallback bson tag overrides other tags", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar" json:"qux,truncate"`)}, - StructTags{Name: "bar"}, - JSONFallbackStructTagParser, + &structTags{Name: "bar"}, + parseJSONStructTags, }, { "JSONFallback ignore xml", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, - StructTags{Name: "foo"}, - JSONFallbackStructTagParser, + &structTags{Name: "foo"}, + parseJSONStructTags, }, } diff --git a/bson/time_codec.go b/bson/time_codec.go index a168d1e7692..1c00374c19d 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -10,48 +10,23 @@ import ( "fmt" "reflect" "time" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) const ( timeFormatString = "2006-01-02T15:04:05.999Z07:00" ) -// TimeCodec is the Codec used for time.Time values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// TimeCodec registered. -type TimeCodec struct { - // UseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. - // - // Deprecated: Use bson.Decoder.UseLocalTimeZone instead. - UseLocalTimeZone bool +// timeCodec is the Codec used for time.Time values. +type timeCodec struct { + // useLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. + useLocalTimeZone bool } -var ( - defaultTimeCodec = NewTimeCodec() - - // Assert that defaultTimeCodec satisfies the typeDecoder interface, which allows it to be used - // by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. - _ typeDecoder = defaultTimeCodec -) - -// NewTimeCodec returns a TimeCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// TimeCodec registered. -func NewTimeCodec(opts ...*bsonoptions.TimeCodecOptions) *TimeCodec { - timeOpt := bsonoptions.MergeTimeCodecOptions(opts...) - - codec := TimeCodec{} - if timeOpt.UseLocalTimeZone != nil { - codec.UseLocalTimeZone = *timeOpt.UseLocalTimeZone - } - return &codec -} +// Assert that timeCodec satisfies the typeDecoder interface, which allows it to be used +// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. 
+var _ typeDecoder = &timeCodec{} -func (tc *TimeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (tc *timeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tTime { return emptyValue, ValueDecoderError{ Name: "TimeDecodeValue", @@ -102,14 +77,14 @@ func (tc *TimeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType) } - if !tc.UseLocalTimeZone && !dc.useLocalTimeZone { + if !tc.useLocalTimeZone && !dc.useLocalTimeZone { timeVal = timeVal.UTC() } return reflect.ValueOf(timeVal), nil } // DecodeValue is the ValueDecoderFunc for time.Time. -func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTime { return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} } @@ -124,7 +99,7 @@ func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *TimeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTime { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} } diff --git a/bson/time_codec_test.go b/bson/time_codec_test.go index 1f185692dae..1bb35bafd3b 100644 --- a/bson/time_codec_test.go +++ b/bson/time_codec_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -22,20 +21,18 @@ func TestTimeCodec(t *testing.T) { t.Run("UseLocalTimeZone", func(t *testing.T) { reader := &valueReaderWriter{BSONType: TypeDateTime, Return: now.UnixNano() / int64(time.Millisecond)} testCases := []struct { - name string - opts *bsonoptions.TimeCodecOptions - utc bool + name string + timeCodec *timeCodec + utc bool }{ - {"default", bsonoptions.TimeCodec(), true}, - {"false", bsonoptions.TimeCodec().SetUseLocalTimeZone(false), true}, - {"true", bsonoptions.TimeCodec().SetUseLocalTimeZone(true), false}, + {"default", &timeCodec{}, true}, + {"false", &timeCodec{useLocalTimeZone: false}, true}, + {"true", &timeCodec{useLocalTimeZone: true}, false}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - timeCodec := NewTimeCodec(tc.opts) - actual := reflect.New(reflect.TypeOf(now)).Elem() - err := timeCodec.DecodeValue(DecodeContext{}, reader, actual) + err := tc.timeCodec.DecodeValue(DecodeContext{}, reader, actual) assert.Nil(t, err, "TimeCodec.DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) @@ -69,7 +66,7 @@ func TestTimeCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := reflect.New(reflect.TypeOf(now)).Elem() - err := defaultTimeCodec.DecodeValue(DecodeContext{}, tc.reader, actual) + err := (&timeCodec{}).DecodeValue(DecodeContext{}, tc.reader, actual) assert.Nil(t, err, "DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) diff --git a/bson/types.go b/bson/types.go index 981a121356b..1213f8ddc6d 100644 --- a/bson/types.go +++ b/bson/types.go @@ -92,7 +92,6 @@ var tValueMarshaler = 
reflect.TypeOf((*ValueMarshaler)(nil)).Elem() var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem() var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem() var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem() -var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem() var tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem() var tBinary = reflect.TypeOf(Binary{}) diff --git a/bson/uint_codec.go b/bson/uint_codec.go index 73bc01966e9..b8b4be398c9 100644 --- a/bson/uint_codec.go +++ b/bson/uint_codec.go @@ -10,46 +10,21 @@ import ( "fmt" "math" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// UIntCodec is the Codec used for uint values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// UIntCodec registered. -type UIntCodec struct { - // EncodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the +// uintCodec is the Codec used for uint values. +type uintCodec struct { + // encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. - // - // Deprecated: Use bson.Encoder.IntMinSize instead. - EncodeToMinSize bool + encodeToMinSize bool } -var ( - defaultUIntCodec = NewUIntCodec() - - // Assert that defaultUIntCodec satisfies the typeDecoder interface, which allows it to be used - // by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. - _ typeDecoder = defaultUIntCodec -) - -// NewUIntCodec returns a UIntCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// UIntCodec registered. -func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec { - uintOpt := bsonoptions.MergeUIntCodecOptions(opts...) - - codec := UIntCodec{} - if uintOpt.EncodeToMinSize != nil { - codec.EncodeToMinSize = *uintOpt.EncodeToMinSize - } - return &codec -} +// Assert that uintCodec satisfies the typeDecoder interface, which allows it to be used +// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. +var _ typeDecoder = &uintCodec{} // EncodeValue is the ValueEncoder for uint types. -func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (uic *uintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Uint8, reflect.Uint16: return vw.WriteInt32(int32(val.Uint())) @@ -57,7 +32,7 @@ func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect. u64 := val.Uint() // If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32 - useMinSize := ec.MinSize || (uic.EncodeToMinSize && val.Kind() != reflect.Uint64) + useMinSize := ec.MinSize || (uic.encodeToMinSize && val.Kind() != reflect.Uint64) if u64 <= math.MaxInt32 && useMinSize { return vw.WriteInt32(int32(u64)) @@ -75,7 +50,7 @@ func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect. 
} } -func (uic *UIntCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (uic *uintCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { var i64 int64 var err error switch vrType := vr.Type(); vrType { @@ -148,11 +123,15 @@ func (uic *UIntCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Typ return reflect.ValueOf(uint64(i64)), nil case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + if i64 < 0 { + return emptyValue, fmt.Errorf("%d overflows uint", i64) + } + v := uint64(i64) + if v > math.MaxUint { // Can we fit this inside of an uint return emptyValue, fmt.Errorf("%d overflows uint", i64) } - return reflect.ValueOf(uint(i64)), nil + return reflect.ValueOf(uint(v)), nil default: return emptyValue, ValueDecoderError{ Name: "UintDecodeValue", @@ -163,7 +142,7 @@ func (uic *UIntCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Typ } // DecodeValue is the ValueDecoder for uint types. -func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (uic *uintCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ Name: "UintDecodeValue", diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 7caadc5dbc1..49ed348ea4b 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -31,7 +31,7 @@ type Unmarshaler interface { // document. To create custom BSON unmarshaling behavior for an entire BSON // document, implement the Unmarshaler interface instead. type ValueUnmarshaler interface { - UnmarshalBSONValue(Type, []byte) error + UnmarshalBSONValue(typ byte, data []byte) error } // Unmarshal parses the BSON-encoded data and stores the result in the value diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index 8d9dfb53511..fd379b5daa5 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -76,7 +76,7 @@ func TestUnmarshalValue(t *testing.T) { }, } reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), NewSliceCodec()) + reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}) for _, tc := range testCases { tc := tc @@ -111,7 +111,7 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) { }, } reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), NewSliceCodec()) + reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}) for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { diff --git a/bson/value_reader.go b/bson/value_reader.go index 27265413056..039cd1b2c7b 100644 --- a/bson/value_reader.go +++ b/bson/value_reader.go @@ -16,7 +16,7 @@ import ( "sync" ) -var _ ValueReader = (*valueReader)(nil) +var _ ValueReader = &valueReader{} var vrPool = sync.Pool{ New: func() interface{} { @@ -304,7 +304,7 @@ func (vr *valueReader) nextElementLength() (int32, error) { return length, err } -func (vr *valueReader) ReadValueBytes(dst []byte) (Type, []byte, error) { +func (vr *valueReader) readValueBytes(dst []byte) (Type, []byte, error) { switch vr.stack[vr.frame].mode { case mTopLevel: length, err := vr.peekLength() @@ -839,7 +839,7 @@ func (vr *valueReader) peekLength() (int32, error) { } idx := vr.offset - return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil + return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil } func (vr *valueReader) readLength() 
(int32, error) { return vr.readi32() } @@ -851,7 +851,7 @@ func (vr *valueReader) readi32() (int32, error) { idx := vr.offset vr.offset += 4 - return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil + return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil } func (vr *valueReader) readu32() (uint32, error) { @@ -861,7 +861,7 @@ func (vr *valueReader) readu32() (uint32, error) { idx := vr.offset vr.offset += 4 - return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil + return binary.LittleEndian.Uint32(vr.d[idx:]), nil } func (vr *valueReader) readi64() (int64, error) { @@ -871,8 +871,7 @@ func (vr *valueReader) readi64() (int64, error) { idx := vr.offset vr.offset += 8 - return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 | - int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil + return int64(binary.LittleEndian.Uint64(vr.d[idx:])), nil } func (vr *valueReader) readu64() (uint64, error) { @@ -882,6 +881,5 @@ func (vr *valueReader) readu64() (uint64, error) { idx := vr.offset vr.offset += 8 - return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 | - uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil + return binary.LittleEndian.Uint64(vr.d[idx:]), nil } diff --git a/bson/value_reader_test.go b/bson/value_reader_test.go index 21d42e74db1..c1c5de5ef55 100644 --- a/bson/value_reader_test.go +++ b/bson/value_reader_test.go @@ -1430,7 +1430,7 @@ func TestValueReader(t *testing.T) { offset: tc.startingOffset, } - _, got, err := vr.ReadValueBytes(nil) + _, got, err := vr.readValueBytes(nil) if !errequal(t, err, tc.err) { t.Errorf("Did not receive expected error; got %v; want %v", err, tc.err) } @@ -1481,7 +1481,7 @@ func TestValueReader(t *testing.T) { }, frame: 0, } - gotType, got, gotErr := vr.ReadValueBytes(nil) + gotType, got, gotErr := vr.readValueBytes(nil) if !errors.Is(gotErr, tc.wantErr) { t.Errorf("Did not receive expected error. got %v; want %v", gotErr, tc.wantErr) } @@ -1510,7 +1510,7 @@ func TestValueReader(t *testing.T) { vr := &valueReader{stack: []vrState{{mode: mTopLevel}, {mode: mDocument}}, frame: 1} wanterr := (&valueReader{stack: []vrState{{mode: mTopLevel}, {mode: mDocument}}, frame: 1}). invalidTransitionErr(0, "ReadValueBytes", []mode{mElement, mValue}) - _, _, goterr := vr.ReadValueBytes(nil) + _, _, goterr := vr.readValueBytes(nil) if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) { t.Errorf("Expected correct invalid transition error. 
got %v; want %v", goterr, wanterr) } diff --git a/bson/value_writer.go b/bson/value_writer.go index 4ae756d2164..fa04f67af3f 100644 --- a/bson/value_writer.go +++ b/bson/value_writer.go @@ -18,7 +18,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -var _ ValueWriter = (*valueWriter)(nil) +var _ ValueWriter = &valueWriter{} var vwPool = sync.Pool{ New: func() interface{} { @@ -260,7 +260,7 @@ func (vw *valueWriter) writeElementHeader(t Type, destination mode, callerName s return nil } -func (vw *valueWriter) WriteValueBytes(t Type, b []byte) error { +func (vw *valueWriter) writeValueBytes(t Type, b []byte) error { if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil { return err } diff --git a/bson/value_writer_test.go b/bson/value_writer_test.go index 99e3b7d1dcf..0db2b040ff4 100644 --- a/bson/value_writer_test.go +++ b/bson/value_writer_test.go @@ -317,7 +317,7 @@ func TestValueWriter(t *testing.T) { vw := newValueWriterFromSlice(nil) want := TransitionError{current: mTopLevel, destination: mode(0), name: "WriteValueBytes", modes: []mode{mElement, mValue}, action: "write"} - got := vw.WriteValueBytes(TypeEmbeddedDocument, nil) + got := vw.writeValueBytes(TypeEmbeddedDocument, nil) if !assert.CompareErrors(got, want) { t.Errorf("Did not received expected error. got %v; want %v", got, want) } @@ -338,7 +338,7 @@ func TestValueWriter(t *testing.T) { noerr(t, err) _, err = vw.WriteDocumentElement("foo") noerr(t, err) - err = vw.WriteValueBytes(TypeEmbeddedDocument, doc) + err = vw.writeValueBytes(TypeEmbeddedDocument, doc) noerr(t, err) err = vw.WriteDocumentEnd() noerr(t, err) diff --git a/bson/writer.go b/bson/writer.go index 08ff44e466c..380a0c4cb3a 100644 --- a/bson/writer.go +++ b/bson/writer.go @@ -58,13 +58,11 @@ type ValueWriterFlusher interface { Flush() error } -// BytesWriter is the interface used to write BSON bytes to a ValueWriter. +// bytesWriter is the interface used to write BSON bytes to a ValueWriter. // This interface is meant to be a superset of ValueWriter, so that types that // implement ValueWriter may also implement this interface. -// -// Deprecated: BytesWriter will not be supported in Go Driver 2.0. -type BytesWriter interface { - WriteValueBytes(t Type, b []byte) error +type bytesWriter interface { + writeValueBytes(t Type, b []byte) error } // SliceWriter allows a pointer to a slice of bytes to be used as an io.Writer. diff --git a/etc/run-atlas-test.sh b/etc/run-atlas-test.sh index 140b8734feb..5ddfcba78cf 100644 --- a/etc/run-atlas-test.sh +++ b/etc/run-atlas-test.sh @@ -7,5 +7,5 @@ set +x # Get the atlas secrets. . 
${DRIVERS_TOOLS}/.evergreen/secrets_handling/setup-secrets.sh drivers/atlas_connect -echo "Running cmd/testatlas/main.go" -go run ./internal/cmd/testatlas/main.go "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite +echo "Running cmd/testatlas" +go test -v -run ^TestAtlas$ go.mongodb.org/mongo-driver/internal/cmd/testatlas -args "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite diff --git a/internal/cmd/compilecheck/go.mod b/internal/cmd/compilecheck/go.mod index 444df5e2ee8..3b2992092c8 100644 --- a/internal/cmd/compilecheck/go.mod +++ b/internal/cmd/compilecheck/go.mod @@ -10,13 +10,13 @@ replace go.mongodb.org/mongo-driver => ../../../ require go.mongodb.org/mongo-driver v1.11.7 require ( - github.com/golang/snappy v0.0.1 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/klauspost/compress v1.13.6 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect - golang.org/x/text v0.7.0 // indirect + golang.org/x/crypto v0.22.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/text v0.14.0 // indirect ) diff --git a/internal/cmd/compilecheck/go.sum b/internal/cmd/compilecheck/go.sum index 83cc061005a..18d42be5dd7 100644 --- a/internal/cmd/compilecheck/go.sum +++ b/internal/cmd/compilecheck/go.sum @@ -1,6 +1,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= @@ -15,15 +15,16 @@ github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7Jul github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod 
h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -35,8 +36,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/internal/cmd/testatlas/main.go b/internal/cmd/testatlas/atlas_test.go similarity index 82% rename from internal/cmd/testatlas/main.go rename to internal/cmd/testatlas/atlas_test.go index 1bf2f8faff6..8637511f3cd 100644 --- a/internal/cmd/testatlas/main.go +++ b/internal/cmd/testatlas/atlas_test.go @@ -11,6 +11,8 @@ import ( "errors" "flag" "fmt" + "os" + "testing" "time" "go.mongodb.org/mongo-driver/bson" @@ -19,15 +21,19 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -func main() { +func TestMain(m *testing.M) { flag.Parse() + os.Exit(m.Run()) +} + +func TestAtlas(t *testing.T) { uris := flag.Args() ctx := context.Background() - fmt.Printf("Running atlas tests for %d uris\n", len(uris)) + t.Logf("Running atlas tests for %d uris\n", len(uris)) for idx, uri := range uris { - fmt.Printf("Running test %d\n", idx) + t.Logf("Running test %d\n", idx) // Set a low server selection timeout so we fail fast if there are errors. clientOpts := options.Client(). @@ -36,18 +42,18 @@ func main() { // Run basic connectivity test. if err := runTest(ctx, clientOpts); err != nil { - panic(fmt.Sprintf("error running test with TLS at index %d: %v", idx, err)) + t.Fatalf("error running test with TLS at index %d: %v", idx, err) } // Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is // disabled. 
clientOpts.TLSConfig.InsecureSkipVerify = true if err := runTest(ctx, clientOpts); err != nil { - panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err)) + t.Fatalf("error running test with tlsInsecure at index %d: %v", idx, err) } } - fmt.Println("Finished!") + t.Logf("Finished!") } func runTest(ctx context.Context, clientOpts *options.ClientOptions) error { diff --git a/internal/csot/csot.go b/internal/csot/csot.go index 678252c51a7..1e7b1901ea3 100644 --- a/internal/csot/csot.go +++ b/internal/csot/csot.go @@ -11,26 +11,74 @@ import ( "time" ) -type timeoutKey struct{} +type clientLevel struct{} -// MakeTimeoutContext returns a new context with Client-Side Operation Timeout (CSOT) feature-gated behavior -// and a Timeout set to the passed in Duration. Setting a Timeout on a single operation is not supported in -// public API. -// -// TODO(GODRIVER-2348) We may be able to remove this function once CSOT feature-gated behavior becomes the -// TODO default behavior. -func MakeTimeoutContext(ctx context.Context, to time.Duration) (context.Context, context.CancelFunc) { - // Only use the passed in Duration as a timeout on the Context if it - // is non-zero. - cancelFunc := func() {} - if to != 0 { - ctx, cancelFunc = context.WithTimeout(ctx, to) +func isClientLevel(ctx context.Context) bool { + val := ctx.Value(clientLevel{}) + if val == nil { + return false } - return context.WithValue(ctx, timeoutKey{}, true), cancelFunc + + return val.(bool) } +// IsTimeoutContext checks if the provided context has been assigned a deadline +// or has unlimited retries. func IsTimeoutContext(ctx context.Context) bool { - return ctx.Value(timeoutKey{}) != nil + _, ok := ctx.Deadline() + + return ok || isClientLevel(ctx) +} + +// WithTimeout will set the given timeout on the context, if no deadline has +// already been set. +// +// This function assumes that the timeout field is static, given that the +// timeout should be sourced from the client. Therefore, once a timeout function +// parameter has been applied to the context, it will remain for the lifetime +// of the context. +func WithTimeout(parent context.Context, timeout *time.Duration) (context.Context, context.CancelFunc) { + cancel := func() {} + + if timeout == nil || IsTimeoutContext(parent) { + // In the following conditions, do nothing: + // 1. The parent already has a deadline + // 2. The parent does not have a deadline, but a client-level timeout has + // been applied. + // 3. The parent does not have a deadline, there is no client-level + // timeout, and the timeout parameter is nil. + return parent, cancel + } + + // If a client-level timeout has not been applied, then apply it. + parent = context.WithValue(parent, clientLevel{}, true) + + dur := *timeout + + if dur == 0 { + // If the parent does not have a deadline and the timeout is zero, then + // do nothing. + return parent, cancel + } + + // If the parent does not have a deadline and the timeout is non-zero, then + // apply the timeout. + return context.WithTimeout(parent, dur) } + +// WithServerSelectionTimeout creates a context with a timeout that is the +// minimum of serverSelectionTimeoutMS and context deadline. The usage of +// non-positive values for serverSelectionTimeoutMS is an anti-pattern; such +// values are not considered in this calculation.
+func WithServerSelectionTimeout( + parent context.Context, + serverSelectionTimeout time.Duration, +) (context.Context, context.CancelFunc) { + if serverSelectionTimeout <= 0 { + return parent, func() {} + } + + return context.WithTimeout(parent, serverSelectionTimeout) } // ZeroRTTMonitor implements the RTTMonitor interface and is used internally for testing. It returns 0 for all diff --git a/internal/csot/csot_test.go b/internal/csot/csot_test.go new file mode 100644 index 00000000000..5b79f6994ad --- /dev/null +++ b/internal/csot/csot_test.go @@ -0,0 +1,249 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package csot + +import ( + "context" + "testing" + "time" + + "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/ptrutil" +) + +func newTestContext(t *testing.T, timeout time.Duration) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + + return ctx +} + +func TestWithServerSelectionTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + parent context.Context + serverSelectionTimeout time.Duration + wantTimeout time.Duration + wantOk bool + }{ + { + name: "no context deadine and ssto is zero", + parent: context.Background(), + serverSelectionTimeout: 0, + wantTimeout: 0, + wantOk: false, + }, + { + name: "no context deadline and ssto is positive", + parent: context.Background(), + serverSelectionTimeout: 1, + wantTimeout: 1, + wantOk: true, + }, + { + name: "no context deadline and ssto is negative", + parent: context.Background(), + serverSelectionTimeout: -1, + wantTimeout: 0, + wantOk: false, + }, + { + name: "context deadline is zero and ssto is positive", + parent: newTestContext(t, 0), + serverSelectionTimeout: 1, + wantTimeout: 1, + wantOk: true, + }, + { + name: "context deadline is zero and ssto is negative", + parent: newTestContext(t, 0), + serverSelectionTimeout: -1, + wantTimeout: 0, + wantOk: true, + }, + { + name: "context deadline is negative and ssto is zero", + parent: newTestContext(t, -1), + serverSelectionTimeout: 0, + wantTimeout: -1, + wantOk: true, + }, + { + name: "context deadline is negative and ssto is positive", + parent: newTestContext(t, -1), + serverSelectionTimeout: 1, + wantTimeout: 1, + wantOk: true, + }, + { + name: "context deadline is negative and ssto is negative", + parent: newTestContext(t, -1), + serverSelectionTimeout: -1, + wantTimeout: -1, + wantOk: true, + }, + { + name: "context deadline is positive and ssto is zero", + parent: newTestContext(t, 1), + serverSelectionTimeout: 0, + wantTimeout: 1, + wantOk: true, + }, + { + name: "context deadline is positive and equal to ssto", + parent: newTestContext(t, 1), + serverSelectionTimeout: 1, + wantTimeout: 1, + wantOk: true, + }, + { + name: "context deadline is positive lt ssto", + parent: newTestContext(t, 1), + serverSelectionTimeout: 2, + wantTimeout: 2, + wantOk: true, + }, + { + name: "context deadline is positive gt ssto", + parent: newTestContext(t, 2), + serverSelectionTimeout: 1, + wantTimeout: 2, + wantOk: true, + }, + { + name: "context deadline is positive and ssto is negative", + parent: newTestContext(t, -1), + serverSelectionTimeout: -1, + wantTimeout: 1, + wantOk: true, + }, + } + + for _, test := range tests { + test := test // Capture 
the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := WithServerSelectionTimeout(test.parent, test.serverSelectionTimeout) + t.Cleanup(cancel) + + deadline, gotOk := ctx.Deadline() + assert.Equal(t, test.wantOk, gotOk) + + if gotOk { + delta := time.Until(deadline) - test.wantTimeout + tolerance := 10 * time.Millisecond + + assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance) + assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance) + } + }) + } +} + +func TestWithTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + parent context.Context + timeout *time.Duration + wantTimeout time.Duration + wantDeadline bool + wantValues []interface{} + }{ + { + name: "deadline set with non-zero timeout", + parent: newTestContext(t, 1), + timeout: ptrutil.Ptr(time.Duration(2)), + wantTimeout: 1, + wantDeadline: true, + wantValues: []interface{}{}, + }, + { + name: "deadline set with zero timeout", + parent: newTestContext(t, 1), + timeout: ptrutil.Ptr(time.Duration(0)), + wantTimeout: 1, + wantDeadline: true, + wantValues: []interface{}{}, + }, + { + name: "deadline set with nil timeout", + parent: newTestContext(t, 1), + timeout: nil, + wantTimeout: 1, + wantDeadline: true, + wantValues: []interface{}{}, + }, + { + name: "deadline unset with non-zero timeout", + parent: context.Background(), + timeout: ptrutil.Ptr(time.Duration(1)), + wantTimeout: 1, + wantDeadline: true, + wantValues: []interface{}{}, + }, + { + name: "deadline unset with zero timeout", + parent: context.Background(), + timeout: ptrutil.Ptr(time.Duration(0)), + wantTimeout: 0, + wantDeadline: false, + wantValues: []interface{}{clientLevel{}}, + }, + { + name: "deadline unset with nil timeout", + parent: context.Background(), + timeout: nil, + wantTimeout: 0, + wantDeadline: false, + wantValues: []interface{}{}, + }, + { + // If "clientLevel" has been set, but a new timeout is applied + // to the context, then the constructed context should retain the old + // timeout. To simplify the code, we assume the first timeout is static. + name: "deadline unset with non-zero timeout at clientLevel", + parent: context.WithValue(context.Background(), clientLevel{}, true), + timeout: ptrutil.Ptr(time.Duration(1)), + wantTimeout: 0, + wantDeadline: false, + wantValues: []interface{}{}, + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := WithTimeout(test.parent, test.timeout) + t.Cleanup(cancel) + + deadline, gotDeadline := ctx.Deadline() + assert.Equal(t, test.wantDeadline, gotDeadline) + + if gotDeadline { + delta := time.Until(deadline) - test.wantTimeout + tolerance := 10 * time.Millisecond + + assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance) + assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance) + } + + for _, wantValue := range test.wantValues { + assert.NotNil(t, ctx.Value(wantValue), "expected context to have value %v", wantValue) + } + }) + } + +} diff --git a/internal/docexamples/examples.go b/internal/docexamples/examples.go index 2f95bce65f2..71064947a95 100644 --- a/internal/docexamples/examples.go +++ b/internal/docexamples/examples.go @@ -1978,7 +1978,6 @@ func WithTransactionExample(ctx context.Context) error { // Prereq: Create collections. 
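The helpers exercised by the new csot tests above are the building blocks for the timeout plumbing in the rest of this patch. The following is a minimal, illustrative sketch of how they compose, assuming the signatures introduced here (WithTimeout takes the client-level *time.Duration, WithServerSelectionTimeout takes a plain time.Duration). The csot package is internal to the driver, so the import is shown only to keep the sketch self-contained.

package main

import (
	"context"
	"fmt"
	"time"

	"go.mongodb.org/mongo-driver/internal/csot" // internal package; for illustration only
)

func main() {
	// Client-level timeout. Per the tests above, a nil pointer means no
	// client-level timeout was configured, while a pointer to 0 marks the
	// context as "client-level, unlimited" without attaching a deadline.
	clientTimeout := 2 * time.Second

	ctx, cancel := csot.WithTimeout(context.Background(), &clientTimeout)
	defer cancel()

	// Server selection derives its own deadline from ctx. Because the helper
	// wraps context.WithTimeout, it can shorten but never extend the parent
	// deadline, and a non-positive value leaves the parent untouched.
	selCtx, selCancel := csot.WithServerSelectionTimeout(ctx, 500*time.Millisecond)
	defer selCancel()

	if deadline, ok := selCtx.Deadline(); ok {
		fmt.Println("server selection must finish by", deadline)
	}
}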
wcMajority := writeconcern.Majority() - wcMajority.WTimeout = 1 * time.Second wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) fooColl := client.Database("mydb1").Collection("foo", wcMajorityCollectionOpts) barColl := client.Database("mydb1").Collection("bar", wcMajorityCollectionOpts) @@ -2559,7 +2558,6 @@ func CausalConsistencyExamples(client *mongo.Client) error { rc := readconcern.Majority() wc := writeconcern.Majority() - wc.WTimeout = 1000 // Use a causally-consistent session to run some operations opts := options.Session().SetDefaultReadConcern(rc).SetDefaultWriteConcern(wc) session1, err := client.StartSession(opts) diff --git a/internal/integration/client_side_encryption_prose_test.go b/internal/integration/client_side_encryption_prose_test.go index d5c829ea434..af6f4af7b66 100644 --- a/internal/integration/client_side_encryption_prose_test.go +++ b/internal/integration/client_side_encryption_prose_test.go @@ -53,16 +53,6 @@ const ( maxBsonObjSize = 16777216 // max bytes in BSON object ) -func containsSubstring(possibleSubstrings []string, str string) bool { - for _, possibleSubstring := range possibleSubstrings { - if strings.Contains(str, possibleSubstring) { - return true - } - } - - return false -} - func TestClientSideEncryptionProse(t *testing.T) { t.Parallel() @@ -150,7 +140,6 @@ func TestClientSideEncryptionProse(t *testing.T) { // Insert the copied key document into keyvault.datakeys with majority write concern. wcMajority := writeconcern.Majority() - wcMajority.WTimeout = 1 * time.Second wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) wcmColl := cse.kvClient.Database(kvDatabase).Collection(dkCollection, wcMajorityCollectionOpts) _, err = wcmColl.InsertOne(context.Background(), alteredKeydoc) @@ -1001,7 +990,7 @@ func TestClientSideEncryptionProse(t *testing.T) { if len(tc.errorSubstring) > 0 { assert.NotNil(mt, err, "expected error, got nil") - assert.True(t, containsSubstring(tc.errorSubstring, err.Error()), + assert.True(t, containsPattern(tc.errorSubstring, err.Error()), "expected tc.errorSubstring=%v to contain %v, but it didn't", tc.errorSubstring, err.Error()) return @@ -1031,7 +1020,7 @@ func TestClientSideEncryptionProse(t *testing.T) { _, err = invalidClientEncryption.CreateDataKey(context.Background(), tc.provider, invalidKeyOpts) assert.NotNil(mt, err, "expected CreateDataKey error, got nil") - assert.True(t, containsSubstring(tc.invalidClientEncryptionErrorSubstring, err.Error()), + assert.True(t, containsPattern(tc.invalidClientEncryptionErrorSubstring, err.Error()), "expected tc.invalidClientEncryptionErrorSubstring=%v to contain %v, but it didn't", tc.invalidClientEncryptionErrorSubstring, err.Error()) }) @@ -1635,7 +1624,7 @@ func TestClientSideEncryptionProse(t *testing.T) { "x509: certificate is not authorized to sign other certificates", // All others } - assert.True(t, containsSubstring(possibleErrors, err.Error()), + assert.True(t, containsPattern(possibleErrors, err.Error()), "expected possibleErrors=%v to contain %v, but it didn't", possibleErrors, err.Error()) @@ -1896,7 +1885,6 @@ func TestClientSideEncryptionProse(t *testing.T) { } wcMajority := writeconcern.Majority() - wcMajority.WTimeout = 1 * time.Second wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) wcmColl := cse.kvClient.Database(kvDatabase).Collection(dkCollection, wcMajorityCollectionOpts) _, err = wcmColl.Indexes().CreateOne(context.Background(), keyVaultIndex) @@ -2254,7 +2242,7 @@ func 
TestClientSideEncryptionProse(t *testing.T) { "Client.Timeout or context cancellation while reading body", // > 1.20 on all OS } - assert.True(t, containsSubstring(possibleErrors, err.Error()), + assert.True(t, containsPattern(possibleErrors, err.Error()), "expected possibleErrors=%v to contain %v, but it didn't", possibleErrors, err.Error()) }) diff --git a/internal/integration/client_side_encryption_test.go b/internal/integration/client_side_encryption_test.go index ed349f65e5c..e59dfa63db9 100644 --- a/internal/integration/client_side_encryption_test.go +++ b/internal/integration/client_side_encryption_test.go @@ -722,7 +722,7 @@ func TestFLEIndexView(t *testing.T) { cc.numEncryptCalls = 0 // Reset Encrypt calls from createIndexes - _, err := coll.Indexes().DropOne(context.Background(), "a_1") + err := coll.Indexes().DropOne(context.Background(), "a_1") assert.NoError(mt, err, "error dropping one index: %v", err) assert.Equal(mt, cc.numEncryptCalls, 1, "expected 1 call to Encrypt, got %v", cc.numEncryptCalls) diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index ceae58ac81a..822b517cd64 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -324,7 +324,7 @@ func TestClient(t *testing.T) { // apply the correct URI. invalidClientOpts := options.Client(). SetServerSelectionTimeout(100 * time.Millisecond).SetHosts([]string{"invalid:123"}). - SetConnectTimeout(500 * time.Millisecond).SetSocketTimeout(500 * time.Millisecond) + SetConnectTimeout(500 * time.Millisecond).SetTimeout(500 * time.Millisecond) integtest.AddTestServerAPIVersion(invalidClientOpts) client, err := mongo.Connect(invalidClientOpts) assert.Nil(mt, err, "Connect error: %v", err) @@ -517,31 +517,23 @@ func TestClient(t *testing.T) { // Assert that the minimum RTT is eventually >250ms. topo := getTopologyFromClient(mt.Client) - assert.Soon(mt, func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(100 * time.Millisecond) - - // Wait for all of the server's minimum RTTs to be >250ms. - done := true - for _, desc := range topo.Description().Servers { - server, err := topo.FindServer(desc) - assert.Nil(mt, err, "FindServer error: %v", err) - if server.RTTMonitor().Min() <= 250*time.Millisecond { - done = false - } - } - if done { - return + callback := func() bool { + // Wait for all of the server's minimum RTTs to be >250ms. + for _, desc := range topo.Description().Servers { + server, err := topo.FindServer(desc) + assert.NoError(mt, err, "FindServer error: %v", err) + if server.RTTMonitor().Min() <= 250*time.Millisecond { + return false // the tick should wait for 100ms in this case } } - }, 10*time.Second) + + return true + } + assert.Eventually(t, + callback, + 10*time.Second, + 100*time.Millisecond, + "expected that the minimum RTT is eventually >250ms") }) // Test that if the minimum RTT is greater than the remaining timeout for an operation, the @@ -565,31 +557,23 @@ func TestClient(t *testing.T) { // Assert that the minimum RTT is eventually >250ms. topo := getTopologyFromClient(mt.Client) - assert.Soon(mt, func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(100 * time.Millisecond) - - // Wait for all of the server's minimum RTTs to be >250ms. 
- done := true - for _, desc := range topo.Description().Servers { - server, err := topo.FindServer(desc) - assert.Nil(mt, err, "FindServer error: %v", err) - if server.RTTMonitor().Min() <= 250*time.Millisecond { - done = false - } - } - if done { - return + callback := func() bool { + // Wait for all of the server's minimum RTTs to be >250ms. + for _, desc := range topo.Description().Servers { + server, err := topo.FindServer(desc) + assert.NoError(mt, err, "FindServer error: %v", err) + if server.RTTMonitor().Min() <= 250*time.Millisecond { + return false } } - }, 10*time.Second) + + return true + } + assert.Eventually(t, + callback, + 10*time.Second, + 100*time.Millisecond, + "expected that the minimum RTT is eventually >250ms") // Once we've waited for the minimum RTT for the single server to be >250ms, run a bunch of // Ping operations with a timeout of 250ms and expect that they return errors. diff --git a/internal/integration/collection_test.go b/internal/integration/collection_test.go index f520e16c89e..dd1eb51e72a 100644 --- a/internal/integration/collection_test.go +++ b/internal/integration/collection_test.go @@ -11,7 +11,6 @@ import ( "errors" "strings" "testing" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" @@ -36,8 +35,7 @@ var ( // for various operations. It includes a timeout because legacy servers will wait for all W nodes to respond, // causing tests to hang. impossibleWc = &writeconcern.WriteConcern{ - W: 30, - WTimeout: time.Second, + W: 30, } ) @@ -862,7 +860,7 @@ func TestCollection(t *testing.T) { count int64 }{ {"no options", nil, 5}, - {"options", options.EstimatedDocumentCount().SetMaxTime(1 * time.Second), 5}, + {"options", options.EstimatedDocumentCount().SetComment("1"), 5}, } for _, tc := range testCases { mt.Run(tc.name, func(mt *mtest.T) { @@ -884,7 +882,7 @@ func TestCollection(t *testing.T) { }{ {"no options", bson.D{}, nil, all}, {"filter", bson.D{{"x", bson.D{{"$gt", 2}}}}, nil, all[2:]}, - {"options", bson.D{}, options.Distinct().SetMaxTime(5000000000), all}, + {"options", bson.D{}, options.Distinct().SetComment("1"), all}, } for _, tc := range testCases { mt.Run(tc.name, func(mt *mtest.T) { @@ -1166,7 +1164,6 @@ func TestCollection(t *testing.T) { SetComment(expectedComment). SetHint(indexName). SetMax(bson.D{{"x", int32(5)}}). - SetMaxTime(1 * time.Second). SetMin(bson.D{{"x", int32(0)}}). SetProjection(bson.D{{"x", int32(1)}}). SetReturnKey(false). @@ -1188,7 +1185,6 @@ func TestCollection(t *testing.T) { AppendString("comment", expectedComment). AppendString("hint", indexName). StartDocument("max").AppendInt32("x", 5).FinishDocument(). - AppendInt32("maxTimeMS", 1000). StartDocument("min").AppendInt32("x", 0).FinishDocument(). StartDocument("projection").AppendInt32("x", 1).FinishDocument(). AppendBoolean("returnKey", false). 
diff --git a/internal/integration/crud_helpers_test.go b/internal/integration/crud_helpers_test.go index 355f934add7..fdf13cfca98 100644 --- a/internal/integration/crud_helpers_test.go +++ b/internal/integration/crud_helpers_test.go @@ -174,8 +174,6 @@ func executeAggregate(mt *mtest.T, agg aggregator, sess *mongo.Session, args bso opts.SetBatchSize(val.Int32()) case "collation": opts.SetCollation(createCollation(mt, val.Document())) - case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) case "allowDiskUse": opts.SetAllowDiskUse(val.Boolean()) case "session": @@ -348,8 +346,6 @@ func setFindModifiers(modifiersDoc bson.Raw, opts *options.FindOptions) { opts.SetHint(val.Document()) case "$max": opts.SetMax(val.Document()) - case "$maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) case "$min": opts.SetMin(val.Document()) case "$returnKey": @@ -1293,7 +1289,7 @@ func executeCreateIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (string return mt.Coll.Indexes().CreateOne(context.Background(), model) } -func executeDropIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (bson.Raw, error) { +func executeDropIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) error { mt.Helper() var name string @@ -1311,14 +1307,11 @@ func executeDropIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (bson.Raw } if sess != nil { - var res bson.Raw - err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { - var indexErr error - res, indexErr = mt.Coll.Indexes().DropOne(sc, name) - return indexErr + return mongo.WithSession(context.Background(), sess, func(sc context.Context) error { + return mt.Coll.Indexes().DropOne(sc, name) }) - return res, err } + return mt.Coll.Indexes().DropOne(context.Background(), name) } diff --git a/internal/integration/csot_prose_test.go b/internal/integration/csot_prose_test.go index c8de1f68adb..9f012893989 100644 --- a/internal/integration/csot_prose_test.go +++ b/internal/integration/csot_prose_test.go @@ -89,13 +89,18 @@ func TestCSOTProse(t *testing.T) { mt.RunOpts("serverSelectionTimeoutMS honored if timeoutMS is not set", mtOpts, func(mt *mtest.T) { mt.Parallel() - callback := func(ctx context.Context) { - err := mt.Client.Ping(ctx, nil) - assert.NotNil(mt, err, "expected Ping error, got nil") + callback := func() bool { + err := mt.Client.Ping(context.Background(), nil) + assert.Error(mt, err, "expected Ping error, got nil") + return true } // Assert that Ping fails within 150ms due to server selection timeout. - assert.Soon(mt, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected ping to fail within 150ms") }) cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=100&serverSelectionTimeoutMS=200") @@ -103,13 +108,18 @@ func TestCSOTProse(t *testing.T) { mt.RunOpts("timeoutMS honored for server selection if it's lower than serverSelectionTimeoutMS", mtOpts, func(mt *mtest.T) { mt.Parallel() - callback := func(ctx context.Context) { - err := mt.Client.Ping(ctx, nil) - assert.NotNil(mt, err, "expected Ping error, got nil") + callback := func() bool { + err := mt.Client.Ping(context.Background(), nil) + assert.Error(mt, err, "expected Ping error, got nil") + return true } // Assert that Ping fails within 150ms due to timeout. 
- assert.Soon(mt, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected ping to fail within 150ms") }) cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=200&serverSelectionTimeoutMS=100") @@ -117,13 +127,18 @@ func TestCSOTProse(t *testing.T) { mt.RunOpts("serverSelectionTimeoutMS honored for server selection if it's lower than timeoutMS", mtOpts, func(mt *mtest.T) { mt.Parallel() - callback := func(ctx context.Context) { - err := mt.Client.Ping(ctx, nil) - assert.NotNil(mt, err, "expected Ping error, got nil") + callback := func() bool { + err := mt.Client.Ping(context.Background(), nil) + assert.Error(mt, err, "expected Ping error, got nil") + return true } // Assert that Ping fails within 150ms due to server selection timeout. - assert.Soon(mt, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected ping to fail within 150ms") }) cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=0&serverSelectionTimeoutMS=100") @@ -131,13 +146,18 @@ func TestCSOTProse(t *testing.T) { mt.RunOpts("serverSelectionTimeoutMS honored for server selection if timeoutMS=0", mtOpts, func(mt *mtest.T) { mt.Parallel() - callback := func(ctx context.Context) { - err := mt.Client.Ping(ctx, nil) - assert.NotNil(mt, err, "expected Ping error, got nil") + callback := func() bool { + err := mt.Client.Ping(context.Background(), nil) + assert.Error(mt, err, "expected Ping error, got nil") + return true } // Assert that Ping fails within 150ms due to server selection timeout. - assert.Soon(mt, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected ping to fail within 150ms") }) }) } diff --git a/internal/integration/errors_test.go b/internal/integration/errors_test.go index 8c6c0fb8122..b3a4094c15a 100644 --- a/internal/integration/errors_test.go +++ b/internal/integration/errors_test.go @@ -15,6 +15,7 @@ import ( "fmt" "io" "net" + "regexp" "testing" "time" @@ -46,6 +47,17 @@ func (n netErr) Temporary() bool { var _ net.Error = (*netErr)(nil) +func containsPattern(patterns []string, str string) bool { + for _, pattern := range patterns { + re := regexp.MustCompile(pattern) + if re.MatchString(str) { + return true + } + } + + return false +} + func TestErrors(t *testing.T) { mt := mtest.New(t, noClientOpts) @@ -96,39 +108,22 @@ func TestErrors(t *testing.T) { } timeoutCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - _, err = mt.Coll.Find(timeoutCtx, filter) - evt := mt.GetStartedEvent() - assert.Equal(mt, "find", evt.CommandName, "expected command 'find', got %q", evt.CommandName) - assert.True(mt, errors.Is(err, context.DeadlineExceeded), - "errors.Is failure: expected error %v to be %v", err, context.DeadlineExceeded) - }) - - mt.Run("socketTimeoutMS timeouts return network errors", func(mt *mtest.T) { - _, err := mt.Coll.InsertOne(context.Background(), bson.D{{"x", 1}}) - assert.Nil(mt, err, "InsertOne error: %v", err) + _, err = mt.Coll.Find(timeoutCtx, filter) - // Reset the test client to have a 100ms socket timeout. We do this here rather than passing it in as a - // test option using mt.RunOpts because that could cause the collection creation or InsertOne to fail. - resetClientOpts := options.Client(). 
- SetSocketTimeout(100 * time.Millisecond) - mt.ResetClient(resetClientOpts) + assert.Error(mt, err) - mt.ClearEvents() - filter := bson.M{ - "$where": "function() { sleep(1000); return false; }", + errPatterns := []string{ + context.DeadlineExceeded.Error(), + `^\(MaxTimeMSExpired\) Executor error during find command.*:: caused by :: operation exceeded time limit$`, } - _, err = mt.Coll.Find(context.Background(), filter) + + assert.True(t, containsPattern(errPatterns, err.Error()), + "expected possibleErrors=%v to contain %v, but it didn't", + errPatterns, err.Error()) evt := mt.GetStartedEvent() assert.Equal(mt, "find", evt.CommandName, "expected command 'find', got %q", evt.CommandName) - - assert.False(mt, errors.Is(err, context.DeadlineExceeded), - "errors.Is failure: expected error %v to not be %v", err, context.DeadlineExceeded) - var netErr net.Error - ok := errors.As(err, &netErr) - assert.True(mt, ok, "errors.As failure: expected error %v to be a net.Error", err) - assert.True(mt, netErr.Timeout(), "expected error %v to be a network timeout", err) }) }) mt.Run("ServerError", func(mt *mtest.T) { @@ -505,26 +500,124 @@ func TestErrors(t *testing.T) { err error result bool }{ - {"context timeout", mongo.CommandError{ - 100, "", []string{"other"}, "blah", context.DeadlineExceeded, nil}, true}, - {"deadline would be exceeded", mongo.CommandError{ - 100, "", []string{"other"}, "blah", driver.ErrDeadlineWouldBeExceeded, nil}, true}, - {"server selection timeout", mongo.CommandError{ - 100, "", []string{"other"}, "blah", topology.ErrServerSelectionTimeout, nil}, true}, - {"wait queue timeout", mongo.CommandError{ - 100, "", []string{"other"}, "blah", topology.WaitQueueTimeoutError{}, nil}, true}, - {"ServerError NetworkTimeoutError", mongo.CommandError{ - 100, "", []string{"NetworkTimeoutError"}, "blah", nil, nil}, true}, - {"ServerError ExceededTimeLimitError", mongo.CommandError{ - 100, "", []string{"ExceededTimeLimitError"}, "blah", nil, nil}, true}, - {"ServerError false", mongo.CommandError{ - 100, "", []string{"other"}, "blah", nil, nil}, false}, - {"net error true", mongo.CommandError{ - 100, "", []string{"other"}, "blah", netErr{true}, nil}, true}, - {"net error false", netErr{false}, false}, - {"wrapped error", fmt.Errorf("%w", mongo.CommandError{ - 100, "", []string{"other"}, "blah", context.DeadlineExceeded, nil}), true}, - {"other error", errors.New("foo"), false}, + { + name: "context timeout", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: context.DeadlineExceeded, + Raw: nil, + }, + result: true, + }, + { + name: "deadline would be exceeded", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: driver.ErrDeadlineWouldBeExceeded, + Raw: nil, + }, + result: true, + }, + { + name: "server selection timeout", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: context.DeadlineExceeded, + Raw: nil, + }, + result: true, + }, + { + name: "wait queue timeout", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: topology.WaitQueueTimeoutError{}, + Raw: nil, + }, + result: true, + }, + { + name: "ServerError NetworkTimeoutError", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"NetworkTimeoutError"}, + Name: "blah", + Wrapped: nil, + Raw: nil, + }, + result: true, + }, + { + name: "ServerError ExceededTimeLimitError", + 
err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"ExceededTimeLimitError"}, + Name: "blah", + Wrapped: nil, + Raw: nil, + }, + result: true, + }, + { + name: "ServerError false", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: nil, + Raw: nil, + }, + result: false, + }, + { + name: "net error true", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: netErr{true}, + Raw: nil, + }, + result: true, + }, + { + name: "net error false", + err: netErr{false}, + result: false, + }, + { + name: "wrapped error", + err: fmt.Errorf("%w", mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: context.DeadlineExceeded, + Raw: nil, + }), + result: true, + }, + { + name: "other error", + err: errors.New("foo"), + result: false, + }, } for _, tc := range testCases { mt.Run(tc.name, func(mt *mtest.T) { diff --git a/internal/integration/index_view_test.go b/internal/integration/index_view_test.go index 5b4a46e42f9..192dd9e422e 100644 --- a/internal/integration/index_view_test.go +++ b/internal/integration/index_view_test.go @@ -10,7 +10,6 @@ import ( "context" "errors" "testing" - "time" "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/bson" @@ -554,16 +553,20 @@ func TestIndexView(t *testing.T) { assert.True(mt, cmp.Equal(specs, expectedSpecs), "expected specifications to match: %v", cmp.Diff(specs, expectedSpecs)) }) mt.RunOpts("options passed to listIndexes", mtest.NewOptions().MinServerVersion("3.0"), func(mt *mtest.T) { - opts := options.ListIndexes().SetMaxTime(100 * time.Millisecond) + opts := options.ListIndexes().SetBatchSize(1) _, err := mt.Coll.Indexes().ListSpecifications(context.Background(), opts) assert.Nil(mt, err, "ListSpecifications error: %v", err) evt := mt.GetStartedEvent() assert.Equal(mt, evt.CommandName, "listIndexes", "expected %q command to be sent, got %q", "listIndexes", evt.CommandName) - maxTimeMS, ok := evt.Command.Lookup("maxTimeMS").Int64OK() - assert.True(mt, ok, "expected command %v to contain %q field", evt.Command, "maxTimeMS") - assert.Equal(mt, int64(100), maxTimeMS, "expected maxTimeMS value to be 100, got %d", maxTimeMS) + + cursorDoc, ok := evt.Command.Lookup("cursor").DocumentOK() + assert.True(mt, ok, "expected command: %v to contain a cursor document", evt.Command) + + batchSize, ok := cursorDoc.Lookup("batchSize").Int32OK() + assert.True(mt, ok, "expected command %v to contain %q field", evt.Command, "batchSize") + assert.Equal(mt, int32(1), batchSize, "expected batchSize value to be 1, got %d", batchSize) }) }) mt.Run("drop one", func(mt *mtest.T) { @@ -579,7 +582,7 @@ func TestIndexView(t *testing.T) { assert.Nil(mt, err, "CreateMany error: %v", err) assert.Equal(mt, 2, len(indexNames), "expected 2 index names, got %v", len(indexNames)) - _, err = iv.DropOne(context.Background(), indexNames[1]) + err = iv.DropOne(context.Background(), indexNames[1]) assert.Nil(mt, err, "DropOne error: %v", err) cursor, err := iv.List(context.Background()) diff --git a/internal/integration/json_helpers_test.go b/internal/integration/json_helpers_test.go index 24877da1590..194c3164136 100644 --- a/internal/integration/json_helpers_test.go +++ b/internal/integration/json_helpers_test.go @@ -111,9 +111,6 @@ func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions { case "serverSelectionTimeoutMS": sst := convertValueToMilliseconds(t, opt) 
clientOpts.SetServerSelectionTimeout(sst) - case "socketTimeoutMS": - st := convertValueToMilliseconds(t, opt) - clientOpts.SetSocketTimeout(st) case "minPoolSize": clientOpts.SetMinPoolSize(uint64(opt.AsInt64())) case "maxPoolSize": @@ -301,9 +298,6 @@ func createSessionOptions(t testing.TB, opts bson.Raw) *options.SessionOptions { if txnOpts.WriteConcern != nil { sessOpts.SetDefaultWriteConcern(txnOpts.WriteConcern) } - if txnOpts.MaxCommitTime != nil { - sessOpts.SetDefaultMaxCommitTime(txnOpts.MaxCommitTime) - } default: t.Fatalf("unrecognized session option: %v", name) } @@ -378,8 +372,7 @@ func createTransactionOptions(t testing.TB, opts bson.Raw) *options.TransactionO case "readConcern": txnOpts.SetReadConcern(createReadConcern(opt)) case "maxCommitTimeMS": - t := time.Duration(opt.Int32()) * time.Millisecond - txnOpts.SetMaxCommitTime(&t) + t.Skip("GODRIVER-2348: maxCommitTimeMS is deprecated") default: t.Fatalf("unrecognized transaction option: %v", opt) } @@ -406,9 +399,6 @@ func createWriteConcern(t testing.TB, opt bson.RawValue) *writeconcern.WriteConc val := elem.Value() switch key { - case "wtimeout": - wtimeout := convertValueToMilliseconds(t, val) - wc.WTimeout = wtimeout case "j": j := val.Boolean() wc.Journal = &j diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index 785044a42f8..affa4233df2 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -14,7 +14,6 @@ import ( "sync" "sync/atomic" "testing" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" @@ -522,7 +521,6 @@ func (t *T) ClearCollections() { // re-instantiating the collection with a majority write concern before dropping. collname := coll.created.Name() wcm := writeconcern.Majority() - wcm.WTimeout = 1 * time.Second wccoll := t.DB.Collection(collname, options.Collection().SetWriteConcern(wcm)) _ = wccoll.Drop(context.Background()) diff --git a/internal/integration/mtest/opmsg_deployment.go b/internal/integration/mtest/opmsg_deployment.go index 6a0a1021c18..bcc10275e6e 100644 --- a/internal/integration/mtest/opmsg_deployment.go +++ b/internal/integration/mtest/opmsg_deployment.go @@ -9,6 +9,7 @@ package mtest import ( "context" "errors" + "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/csot" @@ -133,6 +134,12 @@ func (md *mockDeployment) SelectServer(context.Context, description.ServerSelect return md, nil } +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for mock deployments. +func (*mockDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} + // Kind implements the Deployment interface. It always returns description.TopologyKindSingle. func (md *mockDeployment) Kind() description.TopologyKind { return description.TopologyKindSingle diff --git a/internal/integration/sdam_error_handling_test.go b/internal/integration/sdam_error_handling_test.go index 8b960694766..5f0b768cefb 100644 --- a/internal/integration/sdam_error_handling_test.go +++ b/internal/integration/sdam_error_handling_test.go @@ -75,9 +75,10 @@ func TestSDAMErrorHandling(t *testing.T) { mt.ResetClient(baseClientOpts(). SetAppName(appName). SetPoolMonitor(tpm.PoolMonitor). - // Set a 100ms socket timeout so that the saslContinue delay of 150ms causes a - // timeout during socket read (i.e. a timeout not caused by the InsertOne context). 
- SetSocketTimeout(100 * time.Millisecond)) + // Set a 100ms connect timeout so that the saslContinue delay of 150ms + // causes a timeout during a heartbeat (i.e. a timeout not caused by + // the InsertOne context). + SetConnectTimeout(100 * time.Millisecond)) // Use context.Background() so that the new connection will not time out due to an // operation-scoped timeout. @@ -85,23 +86,13 @@ func TestSDAMErrorHandling(t *testing.T) { assert.NotNil(mt, err, "expected InsertOne error, got nil") assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err) assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err) + // Assert that the pool is cleared within 2 seconds. - assert.Soon(mt, func(ctx context.Context) { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - case <-ctx.Done(): - return - } - - if tpm.IsPoolCleared() { - return - } - } - }, 2*time.Second) + assert.Eventually(t, + tpm.IsPoolCleared, + 2*time.Second, + 100*time.Millisecond, + "expected pool is cleared within 2 seconds") }) mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) { @@ -131,22 +122,11 @@ func TestSDAMErrorHandling(t *testing.T) { SetMinPoolSize(5)) // Assert that the pool is cleared within 2 seconds. - assert.Soon(mt, func(ctx context.Context) { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - case <-ctx.Done(): - return - } - - if tpm.IsPoolCleared() { - return - } - } - }, 2*time.Second) + assert.Eventually(t, + tpm.IsPoolCleared, + 2*time.Second, + 100*time.Millisecond, + "expected pool is cleared within 2 seconds") }) mt.Run("foreground", func(mt *mtest.T) { @@ -175,22 +155,11 @@ func TestSDAMErrorHandling(t *testing.T) { assert.False(mt, mongo.IsTimeout(err), "expected non-timeout error, got %v", err) // Assert that the pool is cleared within 2 seconds. - assert.Soon(mt, func(ctx context.Context) { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - case <-ctx.Done(): - return - } - - if tpm.IsPoolCleared() { - return - } - } - }, 2*time.Second) + assert.Eventually(t, + tpm.IsPoolCleared, + 2*time.Second, + 100*time.Millisecond, + "expected pool is cleared within 2 seconds") }) }) }) diff --git a/internal/integration/sdam_prose_test.go b/internal/integration/sdam_prose_test.go index 5aa33589056..d0df9360806 100644 --- a/internal/integration/sdam_prose_test.go +++ b/internal/integration/sdam_prose_test.go @@ -124,28 +124,23 @@ func TestSDAMProse(t *testing.T) { AppName: "streamingRttTest", }, }) - callback := func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: + callback := func() bool { + // We don't know which server received the failpoint command, so we wait until any of the server + // RTTs cross the threshold. + for _, serverDesc := range testTopology.Description().Servers { + if serverDesc.AverageRTT > 250*time.Millisecond { + return true } - - // We don't know which server received the failpoint command, so we wait until any of the server - // RTTs cross the threshold. - for _, serverDesc := range testTopology.Description().Servers { - if serverDesc.AverageRTT > 250*time.Millisecond { - return - } - } - - // The next update will be in ~500ms. - time.Sleep(500 * time.Millisecond) } + + // The next update will be in ~500ms. 
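These integration-test changes all follow the same conversion: the hand-rolled assert.Soon polling loops become a pure condition callback handed to assert.Eventually, which owns the retry loop, the overall timeout, and the tick interval. A rough sketch of the pattern, assuming the testify-style Eventually signature used throughout this patch (the poolMonitor type and waitForPoolClear helper are illustrative names, not part of the change):

package integration_sketch

import (
	"testing"
	"time"

	"go.mongodb.org/mongo-driver/internal/assert"
)

// poolMonitor is a stand-in for the pool monitor used by the tests above.
type poolMonitor interface {
	IsPoolCleared() bool
}

func waitForPoolClear(t *testing.T, tpm poolMonitor) {
	// Eventually re-evaluates the condition every 100ms and fails the test if
	// it is still false after 2s, replacing the manual ticker and ctx.Done()
	// plumbing that assert.Soon callbacks needed.
	assert.Eventually(t,
		tpm.IsPoolCleared,
		2*time.Second,
		100*time.Millisecond,
		"expected the pool to be cleared within 2 seconds")
}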
+ return false } - assert.Soon(t, callback, defaultCallbackTimeout) + assert.Eventually(t, + callback, + defaultCallbackTimeout, + 500*time.Millisecond, + "expected average rtt heartbeats at least within every 500 ms period") }) }) @@ -210,6 +205,7 @@ func TestServerHeartbeatStartedEvent(t *testing.T) { server := topology.NewServer( address, bson.NewObjectID(), + 1*time.Second, topology.WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return &event.ServerMonitor{ ServerHeartbeatStarted: func(e *event.ServerHeartbeatStartedEvent) { diff --git a/internal/integration/unified/client_entity.go b/internal/integration/unified/client_entity.go index 2d4c87b94b4..200b0130b49 100644 --- a/internal/integration/unified/client_entity.go +++ b/internal/integration/unified/client_entity.go @@ -612,7 +612,12 @@ func setClientOptionsFromURIOptions(clientOpts *options.ClientOptions, uriOpts b case "retrywrites": clientOpts.SetRetryWrites(value.(bool)) case "sockettimeoutms": - clientOpts.SetSocketTimeout(time.Duration(value.(int32)) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing socketTimeoutMS (a legacy timeout option + // that we have removed as of v2), then a CSOT analogue exists. Once we + // have ensured an analogue exists, extend "skippedTestDescriptions" to + // avoid this error. + return newSkipTestError("the socketTimeoutMS client option is not supported") case "w": wc.W = value wcSet = true diff --git a/internal/integration/unified/collection_operation_execution.go b/internal/integration/unified/collection_operation_execution.go index d2110139200..c1279161ff5 100644 --- a/internal/integration/unified/collection_operation_execution.go +++ b/internal/integration/unified/collection_operation_execution.go @@ -67,8 +67,6 @@ func executeAggregate(ctx context.Context, operation *operation) (*operationResu return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) - case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) case "maxAwaitTimeMS": opts.SetMaxAwaitTime(time.Duration(val.Int32()) * time.Millisecond) case "pipeline": @@ -194,7 +192,12 @@ func executeCountDocuments(ctx context.Context, operation *operation) (*operatio case "limit": opts.SetLimit(val.Int64()) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "skip": opts.SetSkip(int64(val.Int32())) default: @@ -523,7 +526,12 @@ func executeDistinct(ctx context.Context, operation *operation) (*operationResul case "filter": filter = val.Document() case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. 
+ return nil, fmt.Errorf("the maxTimeMS collection option is not supported") default: return nil, fmt.Errorf("unrecognized distinct option %q", key) } @@ -566,7 +574,12 @@ func executeDropIndex(ctx context.Context, operation *operation) (*operationResu case "name": name = val.StringValue() case "maxTimeMS": - dropIndexOpts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") default: return nil, fmt.Errorf("unrecognized dropIndex option %q", key) } @@ -575,8 +588,8 @@ func executeDropIndex(ctx context.Context, operation *operation) (*operationResu return nil, newMissingArgumentError("name") } - res, err := coll.Indexes().DropOne(ctx, name, dropIndexOpts) - return newDocumentResult(res, err), nil + err = coll.Indexes().DropOne(ctx, name, dropIndexOpts) + return newDocumentResult(nil, err), nil } func executeDropIndexes(ctx context.Context, operation *operation) (*operationResult, error) { @@ -589,11 +602,15 @@ func executeDropIndexes(ctx context.Context, operation *operation) (*operationRe elems, _ := operation.Arguments.Elements() for _, elem := range elems { key := elem.Key() - val := elem.Value() switch key { case "maxTimeMS": - dropIndexOpts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") default: return nil, fmt.Errorf("unrecognized dropIndexes option %q", key) } @@ -654,7 +671,12 @@ func executeEstimatedDocumentCount(ctx context.Context, operation *operation) (* case "comment": opts.SetComment(val) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") default: return nil, fmt.Errorf("unrecognized estimatedDocumentCount option %q", key) } @@ -731,7 +753,12 @@ func executeFindOne(ctx context.Context, operation *operation) (*operationResult } opts.SetHint(hint) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. 
+ return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "projection": opts.SetProjection(val.Document()) case "sort": @@ -790,7 +817,12 @@ func executeFindOneAndDelete(ctx context.Context, operation *operation) (*operat } opts.SetHint(hint) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "projection": opts.SetProjection(val.Document()) case "sort": @@ -856,7 +888,12 @@ func executeFindOneAndReplace(ctx context.Context, operation *operation) (*opera case "let": opts.SetLet(val.Document()) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "projection": opts.SetProjection(val.Document()) case "replacement": @@ -940,7 +977,12 @@ func executeFindOneAndUpdate(ctx context.Context, operation *operation) (*operat case "let": opts.SetLet(val.Document()) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "projection": opts.SetProjection(val.Document()) case "returnDocument": @@ -1403,7 +1445,12 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult, case "max": opts.SetMax(val.Document()) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. 
+ return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "min": opts.SetMin(val.Document()) case "noCursorTimeout": diff --git a/internal/integration/unified/common_options.go b/internal/integration/unified/common_options.go index 7c345863252..2b78466a9be 100644 --- a/internal/integration/unified/common_options.go +++ b/internal/integration/unified/common_options.go @@ -28,9 +28,8 @@ func (rc *readConcern) toReadConcernOption() *readconcern.ReadConcern { } type writeConcern struct { - Journal *bool `bson:"journal"` - W interface{} `bson:"w"` - WTimeoutMS *int32 `bson:"wtimeoutMS"` + Journal *bool `bson:"journal"` + W interface{} `bson:"w"` } func (wc *writeConcern) toWriteConcernOption() (*writeconcern.WriteConcern, error) { @@ -51,10 +50,6 @@ func (wc *writeConcern) toWriteConcernOption() (*writeconcern.WriteConcern, erro return nil, fmt.Errorf("invalid type for write concern 'w' field %T", wc.W) } } - if wc.WTimeoutMS != nil { - wTimeout := time.Duration(*wc.WTimeoutMS) * time.Millisecond - c.WTimeout = wTimeout - } return c, nil } diff --git a/internal/integration/unified/database_operation_execution.go b/internal/integration/unified/database_operation_execution.go index 269dd7e9295..156b2b29a54 100644 --- a/internal/integration/unified/database_operation_execution.go +++ b/internal/integration/unified/database_operation_execution.go @@ -284,7 +284,6 @@ func executeRunCursorCommand(ctx context.Context, operation *operation) (*operat batchSize int32 command bson.Raw comment bson.Raw - maxTime time.Duration ) opts := options.RunCmd() @@ -306,7 +305,12 @@ func executeRunCursorCommand(ctx context.Context, operation *operation) (*operat case "comment": comment = val.Document() case "maxTimeMS": - maxTime = time.Duration(val.AsInt64()) * time.Millisecond + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS database option is not supported") case "cursorTimeout": return nil, newSkipTestError("cursorTimeout not supported") case "timeoutMode": @@ -329,10 +333,6 @@ func executeRunCursorCommand(ctx context.Context, operation *operation) (*operat cursor.SetBatchSize(batchSize) } - if maxTime > 0 { - cursor.SetMaxTime(maxTime) - } - if len(comment) > 0 { cursor.SetComment(comment) } diff --git a/internal/integration/unified/gridfs_bucket_operation_execution.go b/internal/integration/unified/gridfs_bucket_operation_execution.go index d2ca0f5652c..512c5828421 100644 --- a/internal/integration/unified/gridfs_bucket_operation_execution.go +++ b/internal/integration/unified/gridfs_bucket_operation_execution.go @@ -13,7 +13,6 @@ import ( "fmt" "io" "math" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" @@ -39,7 +38,12 @@ func createBucketFindCursor(ctx context.Context, operation *operation) (*cursorR switch key { case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. 
+ return nil, fmt.Errorf("the maxTimeMS gridfs option is not supported") case "filter": filter = val.Document() default: diff --git a/internal/integration/unified/operation.go b/internal/integration/unified/operation.go index 59aa36ae8c4..989e58673c0 100644 --- a/internal/integration/unified/operation.go +++ b/internal/integration/unified/operation.go @@ -93,7 +93,16 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat // Special handling for the "timeoutMS" field because it applies to (almost) all operations. if tms, ok := op.Arguments.Lookup("timeoutMS").Int32OK(); ok { timeout := time.Duration(tms) * time.Millisecond - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, timeout) + + // Note that a 0-timeout at the operation level is not actually possible + // in Go. This would result in an immediate "context deadline exceeded" + // error. + // + // To achieve an "infinite" case, users would have to rely on either (1) + // defining a 0 timeout at the client-level, or (2) use + // context.Background() at the operation-level. + newCtx, cancelFunc := csot.WithTimeout(ctx, &timeout) + // Redefine ctx to be the new timeout-derived context. ctx = newCtx // Cancel the timeout-derived context at the end of run to avoid a context leak. diff --git a/internal/integration/unified/session_options.go b/internal/integration/unified/session_options.go index 9882073b22c..c02865d9750 100644 --- a/internal/integration/unified/session_options.go +++ b/internal/integration/unified/session_options.go @@ -8,7 +8,6 @@ package unified import ( "fmt" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" @@ -24,11 +23,10 @@ var _ bson.Unmarshaler = (*transactionOptions)(nil) func (to *transactionOptions) UnmarshalBSON(data []byte) error { var temp struct { - RC *readConcern `bson:"readConcern"` - RP *ReadPreference `bson:"readPreference"` - WC *writeConcern `bson:"writeConcern"` - MaxCommitTimeMS *int64 `bson:"maxCommitTimeMS"` - Extra map[string]interface{} `bson:",inline"` + RC *readConcern `bson:"readConcern"` + RP *ReadPreference `bson:"readPreference"` + WC *writeConcern `bson:"writeConcern"` + Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { return fmt.Errorf("error unmarshalling to temporary transactionOptions object: %v", err) @@ -38,10 +36,6 @@ func (to *transactionOptions) UnmarshalBSON(data []byte) error { } to.TransactionOptions = options.Transaction() - if temp.MaxCommitTimeMS != nil { - mctms := time.Duration(*temp.MaxCommitTimeMS) * time.Millisecond - to.SetMaxCommitTime(&mctms) - } if rc := temp.RC; rc != nil { to.SetReadConcern(rc.toReadConcernOption()) } @@ -72,11 +66,10 @@ var _ bson.Unmarshaler = (*sessionOptions)(nil) func (so *sessionOptions) UnmarshalBSON(data []byte) error { var temp struct { - Causal *bool `bson:"causalConsistency"` - MaxCommitTimeMS *int64 `bson:"maxCommitTimeMS"` - TxnOptions *transactionOptions `bson:"defaultTransactionOptions"` - Snapshot *bool `bson:"snapshot"` - Extra map[string]interface{} `bson:",inline"` + Causal *bool `bson:"causalConsistency"` + TxnOptions *transactionOptions `bson:"defaultTransactionOptions"` + Snapshot *bool `bson:"snapshot"` + Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { return fmt.Errorf("error unmarshalling to temporary sessionOptions object: %v", err) @@ -89,10 +82,6 @@ func (so *sessionOptions) UnmarshalBSON(data []byte) error { if temp.Causal != nil { 
so.SetCausalConsistency(*temp.Causal) } - if temp.MaxCommitTimeMS != nil { - mctms := time.Duration(*temp.MaxCommitTimeMS) * time.Millisecond - so.SetDefaultMaxCommitTime(&mctms) - } if temp.TxnOptions != nil { if rc := temp.TxnOptions.ReadConcern; rc != nil { so.SetDefaultReadConcern(rc) diff --git a/internal/integration/unified/testrunner_operation.go b/internal/integration/unified/testrunner_operation.go index a5dbc3e75a7..dfa9a124d5f 100644 --- a/internal/integration/unified/testrunner_operation.go +++ b/internal/integration/unified/testrunner_operation.go @@ -436,7 +436,6 @@ func waitForEvent(ctx context.Context, args waitForEventArguments) error { if args.eventCompleted(client) { return nil } - } time.Sleep(100 * time.Millisecond) diff --git a/internal/integration/unified/unified_spec_runner.go b/internal/integration/unified/unified_spec_runner.go index af71caa44af..45f2c164fa0 100644 --- a/internal/integration/unified/unified_spec_runner.go +++ b/internal/integration/unified/unified_spec_runner.go @@ -45,6 +45,13 @@ var ( "listSearchIndexes ignores read and write concern": "Sync GODRIVER-3074, but skip testing bug GODRIVER-3043", "updateSearchIndex ignores the read and write concern": "Sync GODRIVER-3074, but skip testing bug GODRIVER-3043", + // TODO(DRIVERS-2829): Create CSOT Legacy Timeout Analogues and Compatibility Field + "Reset server and pool after network timeout error during authentication": "Uses unsupported socketTimeoutMS", + "Ignore network timeout error on find": "Uses unsupported socketTimeoutMS", + "A successful find with options": "Uses unsupported maxTimeMS", + "estimatedDocumentCount with maxTimeMS": "Uses unsupported maxTimeMS", + "supports configuring getMore maxTimeMS": "Uses unsupported maxTimeMS", + // TODO(GODRIVER-3137): Implement Gossip cluster time" "unpin after TransientTransactionError error on commit": "Implement GODRIVER-3137", diff --git a/internal/integration/unified_runner_events_helper_test.go b/internal/integration/unified_runner_events_helper_test.go index 2937882dcd0..2fc9c22cbe7 100644 --- a/internal/integration/unified_runner_events_helper_test.go +++ b/internal/integration/unified_runner_events_helper_test.go @@ -88,31 +88,23 @@ func waitForEvent(mt *mtest.T, test *testCase, op *operation) { eventType := op.Arguments.Lookup("event").StringValue() expectedCount := int(op.Arguments.Lookup("count").Int32()) - callback := func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - var count int - // Spec tests only ever wait for ServerMarkedUnknown SDAM events for the time being. - if eventType == "ServerMarkedUnknownEvent" { - count = test.monitor.getServerMarkedUnknownCount() - } else { - count = test.monitor.getPoolEventCount(eventType) - } - - if count >= expectedCount { - return - } - time.Sleep(100 * time.Millisecond) + callback := func() bool { + var count int + // Spec tests only ever wait for ServerMarkedUnknown SDAM events for the time being. 
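The timeoutMS handling added to operation.go above is worth a concrete illustration: a literal 0 operation-level timeout would expire immediately in Go, so an unlimited timeout has to come from either the client options or a deadline-free context. A hedged sketch of those two routes (URI and usage are placeholders, not taken from the test suite):

package main

import (
	"context"
	"log"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
	// Route 1: a client-level timeout of 0 disables the client-wide timeout
	// for every operation run through this client.
	opts := options.Client().
		ApplyURI("mongodb://localhost:27017"). // placeholder URI
		SetTimeout(0)

	client, err := mongo.Connect(opts)
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = client.Disconnect(context.Background()) }()

	// Route 2: a context without a deadline leaves the individual operation
	// unbounded, subject only to the client-level setting above.
	if err := client.Ping(context.Background(), nil); err != nil {
		log.Fatal(err)
	}
}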
+ if eventType == "ServerMarkedUnknownEvent" { + count = test.monitor.getServerMarkedUnknownCount() + } else { + count = test.monitor.getPoolEventCount(eventType) } + + return count >= expectedCount } - assert.Soon(mt, callback, defaultCallbackTimeout) + assert.Eventually(mt, + callback, + defaultCallbackTimeout, + 100*time.Millisecond, + "expected spec tests to only wait for Server Marked Unknown SDAM events") } func assertEventCount(mt *mtest.T, testCase *testCase, op *operation) { @@ -135,23 +127,16 @@ func recordPrimary(mt *mtest.T, testCase *testCase) { } func waitForPrimaryChange(mt *mtest.T, testCase *testCase, op *operation) { - callback := func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - if getPrimaryAddress(mt, testCase.testTopology, false) != testCase.recordedPrimary { - return - } - } + callback := func() bool { + return getPrimaryAddress(mt, testCase.testTopology, false) != testCase.recordedPrimary } timeout := convertValueToMilliseconds(mt, op.Arguments.Lookup("timeoutMS")) - assert.Soon(mt, callback, timeout) + assert.Eventually(mt, + callback, + timeout, + 100*time.Millisecond, + "expected primary address to be different within the timeout period") } // getPrimaryAddress returns the address of the current primary. If failFast is true, the server selection fast path diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index bcfb5d72f83..bcc97526cbf 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -777,10 +777,9 @@ func executeCollectionOperation(mt *mtest.T, op *operation, sess *mongo.Session) } return err case "dropIndex": - res, err := executeDropIndex(mt, sess, op.Arguments) + err := executeDropIndex(mt, sess, op.Arguments) if op.opError == nil && err == nil { assert.Nil(mt, op.Result, "unexpected result for dropIndex: %v", op.Result) - assert.NotNil(mt, res, "expected result from dropIndex operation, got nil") } return err case "listIndexNames", "mapReduce": diff --git a/internal/logger/io_sink.go b/internal/logger/io_sink.go index c5ff1474b4f..0a6c1bdcabf 100644 --- a/internal/logger/io_sink.go +++ b/internal/logger/io_sink.go @@ -9,6 +9,7 @@ package logger import ( "encoding/json" "io" + "math" "sync" "time" ) @@ -36,7 +37,11 @@ func NewIOSink(out io.Writer) *IOSink { // Info will write a JSON-encoded message to the io.Writer. func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) { - kvMap := make(map[string]interface{}, len(keysAndValues)/2+2) + mapSize := len(keysAndValues) / 2 + if math.MaxInt-mapSize >= 2 { + mapSize += 2 + } + kvMap := make(map[string]interface{}, mapSize) kvMap[KeyTimestamp] = time.Now().UnixNano() kvMap[KeyMessage] = msg diff --git a/bson/bsonoptions/doc.go b/internal/ptrutil/ptr.go similarity index 57% rename from bson/bsonoptions/doc.go rename to internal/ptrutil/ptr.go index c40973c8d43..bf64aad1784 100644 --- a/bson/bsonoptions/doc.go +++ b/internal/ptrutil/ptr.go @@ -1,8 +1,12 @@ -// Copyright (C) MongoDB, Inc. 2022-present. +// Copyright (C) MongoDB, Inc. 2024-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -// Package bsonoptions defines the optional configurations for the BSON codecs. 
-package bsonoptions +package ptrutil + +// Ptr will return the memory location of the given value. +func Ptr[T any](val T) *T { + return &val +} diff --git a/mongo/batch_cursor.go b/mongo/batch_cursor.go index 9e87b00ae47..a50fa899cf0 100644 --- a/mongo/batch_cursor.go +++ b/mongo/batch_cursor.go @@ -40,13 +40,13 @@ type batchCursor interface { // the cursor that implements it. SetBatchSize(int32) - // SetMaxTime will set the maximum amount of time the server will allow + // SetMaxAwaitTime will set the maximum amount of time the server will allow // the operations to execute. The server will error if this field is set // but the cursor is not configured with awaitData=true. // // The time.Duration value passed by this setter will be converted and // rounded down to the nearest millisecond. - SetMaxTime(time.Duration) + SetMaxAwaitTime(time.Duration) // SetComment will set a user-configurable comment that can be used to // identify the operation in server logs. diff --git a/mongo/change_stream.go b/mongo/change_stream.go index cc051b5f081..f02010f53f9 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -151,6 +151,33 @@ func mergeChangeStreamOptions(opts ...*options.ChangeStreamOptions) *options.Cha return csOpts } +// validChangeStreamTimeouts will return "false" if maxAwaitTimeMS is set, +// timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or +// equal to timeoutMS. Otherwise, the timeouts are valid. +func validChangeStreamTimeouts(ctx context.Context, cs *ChangeStream) bool { + if cs.options == nil || cs.client == nil { + return true + } + + maxAwaitTime := cs.options.MaxAwaitTime + timeout := cs.client.timeout + + if maxAwaitTime == nil { + return true + } + + if deadline, ok := ctx.Deadline(); ok { + ctxTimeout := time.Until(deadline) + timeout = &ctxTimeout + } + + if timeout == nil { + return true + } + + return *timeout <= 0 || *maxAwaitTime < *timeout +} + func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { if ctx == nil { @@ -161,12 +188,14 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in cursorOpts.MarshalValueEncoderFn = newEncoderFn(config.bsonOpts, config.registry) + changeStreamOpts := mergeChangeStreamOptions(opts...) + cs := &ChangeStream{ client: config.client, bsonOpts: config.bsonOpts, registry: config.registry, streamType: config.streamType, - options: mergeChangeStreamOptions(opts...), + options: changeStreamOpts, selector: &serverselector.Composite{ Selectors: []description.ServerSelector{ &serverselector.ReadPref{ReadPref: config.readPreference}, @@ -208,7 +237,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in cs.cursorOptions.BatchSize = *cs.options.BatchSize } if cs.options.MaxAwaitTime != nil { - cs.cursorOptions.MaxTimeMS = int64(*cs.options.MaxAwaitTime / time.Millisecond) + cs.cursorOptions.SetMaxAwaitTime(*cs.options.MaxAwaitTime) } if cs.options.Custom != nil { // Marshal all custom options before passing to the initial aggregate. Return @@ -297,10 +326,18 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err var server driver.Server var conn *mnet.Connection - if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil { + // Apply the client-level timeout if the operation-level timeout is not set. 
+ ctx, cancel := csot.WithTimeout(ctx, cs.client.timeout) + defer cancel() + + connCtx, cancel := csot.WithServerSelectionTimeout(ctx, cs.client.deployment.GetServerSelectionTimeout()) + defer cancel() + + if server, cs.err = cs.client.deployment.SelectServer(connCtx, cs.selector); cs.err != nil { return cs.Err() } - if conn, cs.err = server.Connection(ctx); cs.err != nil { + + if conn, cs.err = server.Connection(connCtx); cs.err != nil { return cs.Err() } defer conn.Close() @@ -329,17 +366,6 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err cs.aggregate.Pipeline(plArr) } - // If no deadline is set on the passed-in context, cs.client.timeout is set, and context is not already - // a Timeout context, honor cs.client.timeout in new Timeout context for change stream operation execution - // and potential retry. - if _, deadlineSet := ctx.Deadline(); !deadlineSet && cs.client.timeout != nil && !csot.IsTimeoutContext(ctx) { - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *cs.client.timeout) - // Redefine ctx to be the new timeout-derived context. - ctx = newCtx - // Cancel the timeout-derived context at the end of executeOperation to avoid a context leak. - defer cancelFunc() - } - // Execute the aggregate, retrying on retryable errors once (1) if retryable reads are enabled and // infinitely (-1) if context is a Timeout context. var retries int @@ -366,16 +392,20 @@ AggregateExecuteLoop: break AggregateExecuteLoop } + connCtx, cancel := csot.WithServerSelectionTimeout(ctx, cs.client.deployment.GetServerSelectionTimeout()) + defer cancel() + // If error is retryable: subtract 1 from retries, redo server selection, checkout // a connection, and restart loop. retries-- - server, err = cs.client.deployment.SelectServer(ctx, cs.selector) + server, err = cs.client.deployment.SelectServer(connCtx, cs.selector) if err != nil { break AggregateExecuteLoop } conn.Close() - conn, err = server.Connection(ctx) + + conn, err = server.Connection(connCtx) if err != nil { break AggregateExecuteLoop } @@ -646,26 +676,35 @@ func (cs *ChangeStream) ResumeToken() bson.Raw { return cs.resumeToken } -// Next gets the next event for this change stream. It returns true if there were no errors and the next event document -// is available. +// Next gets the next event for this change stream. It returns true if there +// were no errors and the next event document is available. // -// Next blocks until an event is available, an error occurs, or ctx expires. If ctx expires, the error -// will be set to ctx.Err(). In an error case, Next will return false. +// Next blocks until an event is available, an error occurs, or ctx expires. +// If ctx expires, the error will be set to ctx.Err(). In an error case, Next +// will return false. // // If Next returns false, subsequent calls will also return false. func (cs *ChangeStream) Next(ctx context.Context) bool { return cs.next(ctx, false) } -// TryNext attempts to get the next event for this change stream. It returns true if there were no errors and the next -// event document is available. +// TryNext attempts to get the next event for this change stream. It returns +// true if there were no errors and the next event document is available. +// +// TryNext returns false if the change stream is closed by the server, an error +// occurs when getting changes from the server, the next change is not yet +// available, or ctx expires. 
// -// TryNext returns false if the change stream is closed by the server, an error occurs when getting changes from the -// server, the next change is not yet available, or ctx expires. If ctx expires, the error will be set to ctx.Err().
+// If ctx expires, the error will be set to ctx.Err(). Users can either call
+// TryNext again or close the existing change stream and create a new one. It is
+// suggested to close and re-create the stream with a higher timeout if the
+// timeout occurs before any events have been received, which is a signal that
+// the server is timing out before it can finish processing the existing oplog.
// -// If TryNext returns false and an error occurred or the change stream was closed -// (i.e. cs.Err() != nil || cs.ID() == 0), subsequent attempts will also return false. Otherwise, it is safe to call -// TryNext again until a change is available.
+// If TryNext returns false and an error occurred or the change stream was
+// closed (i.e. cs.Err() != nil || cs.ID() == 0), subsequent attempts will also
+// return false. Otherwise, it is safe to call TryNext again until a change is
+// available.
// // This method requires driver version >= 1.2.0. func (cs *ChangeStream) TryNext(ctx context.Context) bool { @@ -703,6 +742,18 @@ func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool { } func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) {
+ if !validChangeStreamTimeouts(ctx, cs) { + cs.err = fmt.Errorf("MaxAwaitTime must be less than the operation timeout") + + return + } +
+ // Apply the client-level timeout if the operation-level timeout is not set.
+ // This calculation is also done in "executeOperation" but cursor.Next is also
+ // blocking and should honor client-level timeouts.
+ ctx, cancel := csot.WithTimeout(ctx, cs.client.timeout) + defer cancel() +
 for { if cs.cursor == nil { return diff --git a/mongo/change_stream_deployment.go b/mongo/change_stream_deployment.go index 64f30095c84..b4fdbd26904 100644 --- a/mongo/change_stream_deployment.go +++ b/mongo/change_stream_deployment.go @@ -8,6 +8,7 @@ package mongo import ( "context" + "time" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/description" @@ -25,6 +26,7 @@ var _ driver.Server = (*changeStreamDeployment)(nil) var _ driver.ErrorProcessor = (*changeStreamDeployment)(nil) func (c *changeStreamDeployment) SelectServer(context.Context, description.ServerSelector) (driver.Server, error) { + return c, nil } @@ -48,3 +50,9 @@ func (c *changeStreamDeployment) ProcessError(err error, describer mnet.Describe return ep.ProcessError(err, describer) } +
+// GetServerSelectionTimeout returns zero as a server selection timeout is not
+// applicable for change stream deployments.
+func (*changeStreamDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} diff --git a/mongo/change_stream_test.go b/mongo/change_stream_test.go index fa447135932..5b1193a0daf 100644 --- a/mongo/change_stream_test.go +++ b/mongo/change_stream_test.go @@ -7,7 +7,9 @@ package mongo import ( + "context" "testing" + "time" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/mongo/options" @@ -15,7 +17,7 @@ import ( func TestChangeStream(t *testing.T) { t.Run("nil cursor", func(t *testing.T) { - cs := &ChangeStream{} + cs := &ChangeStream{client: &Client{}} id := cs.ID() assert.Equal(t, int64(0), id, "expected ID 0, got %v", id) @@ -90,3 +92,96 @@ func TestMergeChangeStreamOptions(t *testing.T) { }) } } + +func TestValidChangeStreamTimeouts(t *testing.T) { + t.Parallel() + + newDurPtr := func(dur time.Duration) *time.Duration { + return &dur + } + + tests := []struct { + name string + parent context.Context + maxAwaitTimeout, timeout *time.Duration + wantTimeout time.Duration + want bool + }{ + { + name: "no context deadline and no timeouts", + parent: context.Background(), + maxAwaitTimeout: nil, + timeout: nil, + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTimeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(1), + timeout: nil, + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and timeout", + parent: context.Background(), + maxAwaitTimeout: nil, + timeout: newDurPtr(1), + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTime gt timeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(2), + timeout: newDurPtr(1), + wantTimeout: 0, + want: false, + }, + { + name: "no context deadline and maxAwaitTime lt timeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(1), + timeout: newDurPtr(2), + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTime eq timeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(1), + timeout: newDurPtr(1), + wantTimeout: 0, + want: false, + }, + { + name: "no context deadline and maxAwaitTime with negative timeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(1), + timeout: newDurPtr(-1), + wantTimeout: 0, + want: true, + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + cs := &ChangeStream{ + options: &options.ChangeStreamOptions{ + MaxAwaitTime: test.maxAwaitTimeout, + }, + client: &Client{ + timeout: test.timeout, + }, + } + + got := validChangeStreamTimeouts(test.parent, cs) + assert.Equal(t, test.want, got) + }) + } +} diff --git a/mongo/client.go b/mongo/client.go index e68714c33f4..d3e00bef171 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -180,7 +180,6 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { if clientOpt.RetryReads != nil { client.retryReads = *clientOpt.RetryReads } - // Timeout client.timeout = clientOpt.Timeout client.httpClient = clientOpt.HTTPClient // WriteConcern @@ -206,15 +205,17 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { clientOpt.SetMaxPoolSize(defaultMaxPoolSize) } + cfg, err := topology.NewConfig(clientOpt, client.clock) if err != nil { return nil, err } - cfg, err := topology.NewConfig(clientOpt, client.clock) - if err != nil { - return nil, err + var connectTimeout time.Duration + if clientOpt.ConnectTimeout != nil { + connectTimeout = *clientOpt.ConnectTimeout } 
- client.serverAPI = topology.ServerAPIFromServerOptions(cfg.ServerOpts) + + client.serverAPI = topology.ServerAPIFromServerOptions(connectTimeout, cfg.ServerOpts) if client.deployment == nil { client.deployment, err = topology.New(cfg) @@ -396,9 +397,6 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (*Session, error) if opt.DefaultWriteConcern != nil { sopts.DefaultWriteConcern = opt.DefaultWriteConcern } - if opt.DefaultMaxCommitTime != nil { - sopts.DefaultMaxCommitTime = opt.DefaultMaxCommitTime - } if opt.Snapshot != nil { sopts.Snapshot = opt.Snapshot } @@ -423,9 +421,6 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (*Session, error) if sopts.DefaultReadPreference != nil { coreOpts.DefaultReadPreference = sopts.DefaultReadPreference } - if sopts.DefaultMaxCommitTime != nil { - coreOpts.DefaultMaxCommitTime = sopts.DefaultMaxCommitTime - } if sopts.Snapshot != nil { coreOpts.Snapshot = sopts.Snapshot } diff --git a/mongo/client_encryption.go b/mongo/client_encryption.go index 97fe9b27b76..bc34f197401 100644 --- a/mongo/client_encryption.go +++ b/mongo/client_encryption.go @@ -108,6 +108,7 @@ func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context, } r := bson.NewValueReader(efBSON) dec := bson.NewDecoder(r) + dec.DefaultDocumentM() var m bson.M err = dec.Decode(&m) if err != nil { diff --git a/mongo/client_test.go b/mongo/client_test.go index e5d08642b33..6e7607be91c 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -501,4 +501,13 @@ func TestClient(t *testing.T) { }) } }) + t.Run("negative timeout will err", func(t *testing.T) { + t.Parallel() + + copts := options.Client().SetTimeout(-1 * time.Second) + _, err := Connect(copts) + + errmsg := `invalid value "-1s" for "Timeout": value must be positive` + assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error()) + }) } diff --git a/mongo/collection.go b/mongo/collection.go index c24ab9273fa..ea75a4a0cd2 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -945,9 +945,6 @@ func mergeAggregateOptions(opts ...*options.AggregateOptions) *options.Aggregate if ao.Collation != nil { aggOpts.Collation = ao.Collation } - if ao.MaxTime != nil { - aggOpts.MaxTime = ao.MaxTime - } if ao.MaxAwaitTime != nil { aggOpts.MaxAwaitTime = ao.MaxAwaitTime } @@ -1032,8 +1029,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { Crypt(a.client.cryptFLE). ServerAPI(a.client.serverAPI). HasOutputStage(hasOutputStage). - Timeout(a.client.timeout). - MaxTime(ao.MaxTime) + Timeout(a.client.timeout) if ao.AllowDiskUse != nil { op.AllowDiskUse(*ao.AllowDiskUse) @@ -1050,7 +1046,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { op.Collation(bsoncore.Document(ao.Collation.ToDocument())) } if ao.MaxAwaitTime != nil { - cursorOpts.MaxTimeMS = int64(*ao.MaxAwaitTime / time.Millisecond) + cursorOpts.SetMaxAwaitTime(*ao.MaxAwaitTime) } if ao.Comment != nil { comment, err := marshalValue(ao.Comment, a.bsonOpts, a.registry) @@ -1147,9 +1143,6 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, if co.Limit != nil { countOpts.Limit = co.Limit } - if co.MaxTime != nil { - countOpts.MaxTime = co.MaxTime - } if co.Skip != nil { countOpts.Skip = co.Skip } @@ -1178,7 +1171,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). 
CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime) + Timeout(coll.client.timeout) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } @@ -1269,16 +1262,13 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, if opt.Comment != nil { co.Comment = opt.Comment } - if opt.MaxTime != nil { - co.MaxTime = opt.MaxTime - } } selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewCount().Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(co.MaxTime) + Timeout(coll.client.timeout) if co.Comment != nil { comment, err := marshalValue(co.Comment, coll.bsonOpts, coll.registry) @@ -1352,9 +1342,6 @@ func (coll *Collection) Distinct( if do.Comment != nil { option.Comment = do.Comment } - if do.MaxTime != nil { - option.MaxTime = do.MaxTime - } } op := operation.NewDistinct(fieldName, f). @@ -1362,7 +1349,7 @@ func (coll *Collection) Distinct( Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(option.MaxTime) + Timeout(coll.client.timeout) if option.Collation != nil { op.Collation(bsoncore.Document(option.Collation.ToDocument())) @@ -1439,9 +1426,6 @@ func mergeFindOptions(opts ...*options.FindOptions) *options.FindOptions { if opt.MaxAwaitTime != nil { fo.MaxAwaitTime = opt.MaxAwaitTime } - if opt.MaxTime != nil { - fo.MaxTime = opt.MaxTime - } if opt.Min != nil { fo.Min = opt.Min } @@ -1517,7 +1501,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, CommandMonitor(coll.client.monitor).ServerSelector(selector). ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). 
- Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger) + Timeout(coll.client.timeout).Logger(coll.client.logger) cursorOpts := coll.client.createBaseCursorOptions() @@ -1588,7 +1572,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.Max(max) } if fo.MaxAwaitTime != nil { - cursorOpts.MaxTimeMS = int64(*fo.MaxAwaitTime / time.Millisecond) + cursorOpts.SetMaxAwaitTime(*fo.MaxAwaitTime) } if fo.Min != nil { min, err := marshal(fo.Min, coll.bsonOpts, coll.registry) @@ -1656,7 +1640,6 @@ func newFindOptionsFromFindOneOptions(opts ...*options.FindOneOptions) []*option Comment: opt.Comment, Hint: opt.Hint, Max: opt.Max, - MaxTime: opt.MaxTime, Min: opt.Min, Projection: opt.Projection, ReturnKey: opt.ReturnKey, @@ -1769,9 +1752,6 @@ func mergeFindOneAndDeleteOptions(opts ...*options.FindOneAndDeleteOptions) *opt if opt.Comment != nil { fo.Comment = opt.Comment } - if opt.MaxTime != nil { - fo.MaxTime = opt.MaxTime - } if opt.Projection != nil { fo.Projection = opt.Projection } @@ -1808,8 +1788,7 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} return &SingleResult{err: err} } fod := mergeFindOneAndDeleteOptions(opts...) - op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). - MaxTime(fod.MaxTime) + op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) if fod.Collation != nil { op = op.Collation(bsoncore.Document(fod.Collation.ToDocument())) } @@ -1875,9 +1854,6 @@ func mergeFindOneAndReplaceOptions(opts ...*options.FindOneAndReplaceOptions) *o if opt.Comment != nil { fo.Comment = opt.Comment } - if opt.MaxTime != nil { - fo.MaxTime = opt.MaxTime - } if opt.Projection != nil { fo.Projection = opt.Projection } @@ -1932,7 +1908,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ fo := mergeFindOneAndReplaceOptions(opts...) op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsoncore.TypeEmbeddedDocument, Data: r}). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) } @@ -2010,9 +1986,6 @@ func mergeFindOneAndUpdateOptions(opts ...*options.FindOneAndUpdateOptions) *opt if opt.Comment != nil { fo.Comment = opt.Comment } - if opt.MaxTime != nil { - fo.MaxTime = opt.MaxTime - } if opt.Projection != nil { fo.Projection = opt.Projection } @@ -2064,8 +2037,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} } fo := mergeFindOneAndUpdateOptions(opts...) - op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). 
- MaxTime(fo.MaxTime) + op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) u, err := marshalUpdateValue(update, coll.bsonOpts, coll.registry, true) if err != nil { diff --git a/mongo/crud_examples_test.go b/mongo/crud_examples_test.go index 9ef1a63acdd..658ff2451c2 100644 --- a/mongo/crud_examples_test.go +++ b/mongo/crud_examples_test.go @@ -201,7 +201,7 @@ func ExampleCollection_Aggregate() { }}, }}, } - opts := options.Aggregate().SetMaxTime(2 * time.Second) + opts := options.Aggregate() cursor, err := coll.Aggregate( context.TODO(), mongo.Pipeline{groupStage}, @@ -264,14 +264,13 @@ func ExampleCollection_BulkWrite() { func ExampleCollection_CountDocuments() { var coll *mongo.Collection + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + // Count the number of times the name "Bob" appears in the collection. - // Specify the MaxTime option to limit the amount of time the operation can - // run on the server. - opts := options.Count().SetMaxTime(2 * time.Second) - count, err := coll.CountDocuments( - context.TODO(), - bson.D{{"name", "Bob"}}, - opts) + count, err := coll.CountDocuments(ctx, bson.D{{"name", "Bob"}}, nil) if err != nil { log.Fatal(err) } @@ -317,13 +316,15 @@ func ExampleCollection_DeleteOne() { func ExampleCollection_Distinct() { var coll *mongo.Collection + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + // Find all unique values for the "name" field for documents in which the // "age" field is greater than 25. - // Specify the MaxTime option to limit the amount of time the operation can - // run on the server. filter := bson.D{{"age", bson.D{{"$gt", 25}}}} - opts := options.Distinct().SetMaxTime(2 * time.Second) - res := coll.Distinct(context.TODO(), "name", filter, opts) + res := coll.Distinct(ctx, "name", filter) if err := res.Err(); err != nil { log.Fatal(err) } @@ -341,11 +342,13 @@ func ExampleCollection_Distinct() { func ExampleCollection_EstimatedDocumentCount() { var coll *mongo.Collection + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + // Get and print an estimated of the number of documents in the collection. - // Specify the MaxTime option to limit the amount of time the operation can - // run on the server. 
- opts := options.EstimatedDocumentCount().SetMaxTime(2 * time.Second) - count, err := coll.EstimatedDocumentCount(context.TODO(), opts) + count, err := coll.EstimatedDocumentCount(ctx, nil) if err != nil { log.Fatal(err) } @@ -1053,8 +1056,7 @@ func ExampleIndexView_CreateMany() { // Specify the MaxTime option to limit the amount of time the operation can // run on the server - opts := options.CreateIndexes().SetMaxTime(2 * time.Second) - names, err := indexView.CreateMany(context.TODO(), models, opts) + names, err := indexView.CreateMany(context.TODO(), models, nil) if err != nil { log.Fatal(err) } @@ -1065,17 +1067,19 @@ func ExampleIndexView_CreateMany() { func ExampleIndexView_List() { var indexView *mongo.IndexView - // Specify the MaxTime option to limit the amount of time the operation can - // run on the server - opts := options.ListIndexes().SetMaxTime(2 * time.Second) - cursor, err := indexView.List(context.TODO(), opts) + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + + cursor, err := indexView.List(ctx, nil) if err != nil { log.Fatal(err) } // Get a slice of all indexes returned and print them out. var results []bson.M - if err = cursor.All(context.TODO(), &results); err != nil { + if err = cursor.All(ctx, &results); err != nil { log.Fatal(err) } fmt.Println(results) diff --git a/mongo/cursor.go b/mongo/cursor.go index 8f07b1ee9b5..22be70a17cd 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -247,9 +247,6 @@ func getDecoder( if opts.BinaryAsSlice { dec.BinaryAsSlice() } - if opts.DefaultDocumentD { - dec.DefaultDocumentD() - } if opts.DefaultDocumentM { dec.DefaultDocumentM() } @@ -394,14 +391,14 @@ func (c *Cursor) SetBatchSize(batchSize int32) { c.bc.SetBatchSize(batchSize) } -// SetMaxTime will set the maximum amount of time the server will allow the +// SetMaxAwaitTime will set the maximum amount of time the server will allow the // operations to execute. The server will error if this field is set but the // cursor is not configured with awaitData=true. // // The time.Duration value passed by this setter will be converted and rounded // down to the nearest millisecond. 
-func (c *Cursor) SetMaxTime(dur time.Duration) { - c.bc.SetMaxTime(dur) +func (c *Cursor) SetMaxAwaitTime(dur time.Duration) { + c.bc.SetMaxAwaitTime(dur) } // SetComment will set a user-configurable comment that can be used to identify diff --git a/mongo/cursor_test.go b/mongo/cursor_test.go index be877b7a6cf..45a3247b150 100644 --- a/mongo/cursor_test.go +++ b/mongo/cursor_test.go @@ -95,9 +95,9 @@ func (tbc *testBatchCursor) Close(context.Context) error { return nil } -func (tbc *testBatchCursor) SetBatchSize(int32) {} -func (tbc *testBatchCursor) SetComment(interface{}) {} -func (tbc *testBatchCursor) SetMaxTime(time.Duration) {} +func (tbc *testBatchCursor) SetBatchSize(int32) {} +func (tbc *testBatchCursor) SetComment(interface{}) {} +func (tbc *testBatchCursor) SetMaxAwaitTime(time.Duration) {} func TestCursor(t *testing.T) { t.Run("loops until docs available", func(t *testing.T) {}) diff --git a/mongo/database.go b/mongo/database.go index 4748d3d2b04..36296a11b79 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/csfle" + "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" @@ -727,10 +728,14 @@ func (db *Database) createCollectionWithEncryptedFields(ctx context.Context, nam // That is OK. This wire version check is a best effort to inform users earlier if using a QEv2 driver with a QEv1 server. { const QEv2WireVersion = 21 + ctx, cancel := csot.WithServerSelectionTimeout(ctx, db.client.deployment.GetServerSelectionTimeout()) + defer cancel() + server, err := db.client.deployment.SelectServer(ctx, &serverselector.Write{}) if err != nil { return fmt.Errorf("error selecting server to check maxWireVersion: %w", err) } + conn, err := server.Connection(ctx) if err != nil { return fmt.Errorf("error getting connection to check maxWireVersion: %w", err) diff --git a/mongo/errors.go b/mongo/errors.go index f3e7bbd43dc..5b2c039898b 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -124,7 +124,6 @@ func IsDuplicateKeyError(err error) bool { var timeoutErrs = [...]error{ context.DeadlineExceeded, driver.ErrDeadlineWouldBeExceeded, - topology.ErrServerSelectionTimeout, } // IsTimeout returns true if err was caused by a timeout. For error chains, diff --git a/mongo/gridfs_bucket.go b/mongo/gridfs_bucket.go index e5016a51791..dd3661877b5 100644 --- a/mongo/gridfs_bucket.go +++ b/mongo/gridfs_bucket.go @@ -61,7 +61,8 @@ type upload struct { // filename. // // The context provided to this method controls the entire lifetime of an -// upload stream io.Writer. +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) OpenUploadStream( ctx context.Context, filename string, @@ -74,13 +75,17 @@ func (b *GridFSBucket) OpenUploadStream( // ID and filename. // // The context provided to this method controls the entire lifetime of an -// upload stream io.Writer. +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. 
func (b *GridFSBucket) OpenUploadStreamWithID( ctx context.Context, fileID interface{}, filename string, opts ...*options.UploadOptions, ) (*GridFSUploadStream, error) {
+ ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() +
 if err := b.checkFirstWrite(ctx); err != nil { return nil, err } @@ -100,7 +105,8 @@ func (b *GridFSBucket) OpenUploadStreamWithID( // bucket that also require a custom deadline. // // The context provided to this method controls the entire lifetime of an -// upload stream io.Writer.
+// upload stream io.Writer. If the context does not set a deadline, then the
+// client-level timeout will be used to cap the lifetime of the stream.
 func (b *GridFSBucket) UploadFromStream( ctx context.Context, filename string, @@ -119,7 +125,8 @@ func (b *GridFSBucket) UploadFromStream( // bucket that also require a custom deadline. // // The context provided to this method controls the entire lifetime of an -// upload stream io.Writer.
+// upload stream io.Writer. If the context does not set a deadline, then the
+// client-level timeout will be used to cap the lifetime of the stream.
 func (b *GridFSBucket) UploadFromStreamWithID( ctx context.Context, fileID interface{}, @@ -157,8 +164,9 @@ func (b *GridFSBucket) UploadFromStreamWithID( // OpenDownloadStream creates a stream from which the contents of the file can // be read. // -// The context provided to this method controls the entire lifetime of a -// download stream io.Reader.
+// The context provided to this method controls the entire lifetime of a
+// download stream io.Reader. If the context does not set a deadline, then the
+// client-level timeout will be used to cap the lifetime of the stream.
 func (b *GridFSBucket) OpenDownloadStream(ctx context.Context, fileID interface{}) (*GridFSDownloadStream, error) { return b.openDownloadStream(ctx, bson.D{{"_id", fileID}}) } @@ -171,8 +179,9 @@ func (b *GridFSBucket) OpenDownloadStream(ctx context.Context, fileID interface{ // cannot be done concurrently with other read operations operations on this // bucket that also require a custom deadline. // -// The context provided to this method controls the entire lifetime of a -// download stream io.Reader.
+// The context provided to this method controls the entire lifetime of a
+// download stream io.Reader. If the context does not set a deadline, then the
+// client-level timeout will be used to cap the lifetime of the stream.
 func (b *GridFSBucket) DownloadToStream(ctx context.Context, fileID interface{}, stream io.Writer) (int64, error) { ds, err := b.OpenDownloadStream(ctx, fileID) if err != nil { @@ -185,8 +194,9 @@ func (b *GridFSBucket) DownloadToStream(ctx context.Context, fileID interface{}, // OpenDownloadStreamByName opens a download stream for the file with the given // filename. // -// The context provided to this method controls the entire lifetime of a -// download stream io.Reader.
+// The context provided to this method controls the entire lifetime of a
+// download stream io.Reader. If the context does not set a deadline, then the
+// client-level timeout will be used to cap the lifetime of the stream.
 func (b *GridFSBucket) OpenDownloadStreamByName( ctx context.Context, filename string, @@ -227,8 +237,9 @@ func (b *GridFSBucket) OpenDownloadStreamByName( // cannot be done concurrently with other read operations operations on this // bucket that also require a custom deadline. // -// The context provided to this method controls the entire lifetime of a -// download stream io.Reader.
+// The context provided to this method controls the entire lifetime of a
+// download stream io.Reader. If the context does not set a deadline, then the
+// client-level timeout will be used to cap the lifetime of the stream.
 func (b *GridFSBucket) DownloadToStreamByName( ctx context.Context, filename string, @@ -243,23 +254,13 @@ func (b *GridFSBucket) DownloadToStreamByName( return b.downloadToStream(ds, stream) } -// Delete deletes all chunks and metadata associated with the file with the given file ID and runs the underlying -// delete operations with the provided context. -// -// Use the context parameter to time-out or cancel the delete operation. The deadline set by SetWriteDeadline is ignored.
+// Delete deletes all chunks and metadata associated with the file with the
+// given file ID and runs the underlying delete operations with the provided
+// context.
 func (b *GridFSBucket) Delete(ctx context.Context, fileID interface{}) error { - // If no deadline is set on the passed-in context, Timeout is set on the Client, and context is - // not already a Timeout context, honor Timeout in new Timeout context for operation execution to - // be shared by both delete operations. - if _, deadlineSet := ctx.Deadline(); !deadlineSet && b.db.Client().timeout != nil && !csot.IsTimeoutContext(ctx) { - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *b.db.Client().timeout) - // Redefine ctx to be the new timeout-derived context. - ctx = newCtx - // Cancel the timeout-derived context at the end of Execute to avoid a context leak. - defer cancelFunc() - } - - // Delete document in files collection and then chunks to minimize race conditions. + ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() + res, err := b.filesColl.DeleteOne(ctx, bson.D{{"_id", fileID}}) if err == nil && res.DeletedCount == 0 { err = ErrFileNotFound } @@ -272,11 +273,8 @@ func (b *GridFSBucket) Delete(ctx context.Context, fileID interface{}) error { return b.deleteChunks(ctx, fileID) } -// Find returns the files collection documents that match the given filter and runs the underlying -// find query with the provided context. -// -// Use the context parameter to time-out or cancel the find operation. The deadline set by SetReadDeadline -// is ignored.
+// Find returns the files collection documents that match the given filter and
+// runs the underlying find query with the provided context.
 func (b *GridFSBucket) Find( ctx context.Context, filter interface{}, @@ -296,9 +294,6 @@ func (b *GridFSBucket) Find( if opt.Limit != nil { gfsOpts.Limit = opt.Limit } - if opt.MaxTime != nil { - gfsOpts.MaxTime = opt.MaxTime - } if opt.NoCursorTimeout != nil { gfsOpts.NoCursorTimeout = opt.NoCursorTimeout } @@ -319,9 +314,6 @@ if gfsOpts.Limit != nil { find.SetLimit(int64(*gfsOpts.Limit)) } - if gfsOpts.MaxTime != nil { - find.SetMaxTime(*gfsOpts.MaxTime) - } if gfsOpts.NoCursorTimeout != nil { find.SetNoCursorTimeout(*gfsOpts.NoCursorTimeout) } @@ -336,11 +328,6 @@ } // Rename renames the stored file with the specified file ID. -// -// If this operation requires a custom write deadline to be set on the bucket, it cannot be done concurrently with other -// write operations operations on this bucket that also require a custom deadline -// -// Use SetWriteDeadline to set a deadline for the rename operation.
func (b *GridFSBucket) Rename(ctx context.Context, fileID interface{}, newFilename string) error { res, err := b.filesColl.UpdateOne(ctx, bson.D{{"_id", fileID}}, @@ -357,21 +344,11 @@ func (b *GridFSBucket) Rename(ctx context.Context, fileID interface{}, newFilena return nil } -// Drop drops the files and chunks collections associated with this bucket and runs the drop operations with -// the provided context. -// -// Use the context parameter to time-out or cancel the drop operation. The deadline set by SetWriteDeadline is ignored. +// Drop drops the files and chunks collections associated with this bucket and +// runs the drop operations with the provided context. func (b *GridFSBucket) Drop(ctx context.Context) error { - // If no deadline is set on the passed-in context, Timeout is set on the Client, and context is - // not already a Timeout context, honor Timeout in new Timeout context for operation execution to - // be shared by both drop operations. - if _, deadlineSet := ctx.Deadline(); !deadlineSet && b.db.Client().timeout != nil && !csot.IsTimeoutContext(ctx) { - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *b.db.Client().timeout) - // Redefine ctx to be the new timeout-derived context. - ctx = newCtx - // Cancel the timeout-derived context at the end of Execute to avoid a context leak. - defer cancelFunc() - } + ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() err := b.filesColl.Drop(ctx) if err != nil { @@ -396,6 +373,9 @@ func (b *GridFSBucket) openDownloadStream( filter interface{}, opts ...*options.FindOneOptions, ) (*GridFSDownloadStream, error) { + ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() + result := b.filesColl.FindOne(ctx, filter, opts...) // Unmarshal the data into a File instance, which can be passed to newGridFSDownloadStream. The _id value has to be @@ -425,6 +405,7 @@ func (b *GridFSBucket) openDownloadStream( if err != nil { return nil, err } + // The chunk size can be overridden for individual files, so the expected chunk size should be the "chunkSize" // field from the files collection document, not the bucket's chunk size. return newGridFSDownloadStream(ctx, chunksCursor, foundFile.ChunkSize, foundFile), nil diff --git a/mongo/index_view.go b/mongo/index_view.go index 84f4d71dc4c..231081947a2 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -13,7 +13,6 @@ import ( "fmt" "strconv" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -108,15 +107,12 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption if opt.BatchSize != nil { lio.BatchSize = opt.BatchSize } - if opt.MaxTime != nil { - lio.MaxTime = opt.MaxTime - } } if lio.BatchSize != nil { op = op.BatchSize(*lio.BatchSize) cursorOpts.BatchSize = *lio.BatchSize } - op = op.MaxTime(lio.MaxTime) + retry := driver.RetryNone if iv.coll.client.retryReads { retry = driver.RetryOncePerCommand @@ -269,9 +265,6 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. if opt == nil { continue } - if opt.MaxTime != nil { - option.MaxTime = opt.MaxTime - } if opt.CommitQuorum != nil { option.CommitQuorum = opt.CommitQuorum } @@ -281,7 +274,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock). 
Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime).Crypt(iv.coll.client.cryptFLE) + Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE) if option.CommitQuorum != nil { commitQuorum, err := marshalValue(option.CommitQuorum, iv.coll.bsonOpts, iv.coll.registry) if err != nil { @@ -383,7 +376,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum return optsDoc, nil } -func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { +func (iv IndexView) drop(ctx context.Context, name string, _ ...*options.DropIndexesOptions) error { if ctx == nil { ctx = context.Background() } @@ -396,7 +389,7 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop err := iv.coll.client.validSession(sess) if err != nil { - return nil, err + return err } wc := iv.coll.writeConcern @@ -409,48 +402,35 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop selector := makePinnedSelector(sess, iv.coll.writeSelector) - dio := options.DropIndexes() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.MaxTime != nil { - dio.MaxTime = opt.MaxTime - } - } op := operation.NewDropIndexes(name). Session(sess).WriteConcern(wc).CommandMonitor(iv.coll.client.monitor). ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime).Crypt(iv.coll.client.cryptFLE) + Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE) err = op.Execute(ctx) if err != nil { - return nil, replaceErrors(err) + return replaceErrors(err) } - // TODO: it's weird to return a bson.Raw here because we have to convert the result back to BSON - ridx, res := bsoncore.AppendDocumentStart(nil) - res = bsoncore.AppendInt32Element(res, "nIndexesWas", op.Result().NIndexesWas) - res, _ = bsoncore.AppendDocumentEnd(res, ridx) - return res, nil + return nil } -// DropOne executes a dropIndexes operation to drop an index on the collection. If the operation succeeds, this returns -// a BSON document in the form {nIndexesWas: }. The "nIndexesWas" field in the response contains the number of -// indexes that existed prior to the drop. +// DropOne executes a dropIndexes operation to drop an index on the collection. // -// The name parameter should be the name of the index to drop. If the name is "*", ErrMultipleIndexDrop will be returned -// without running the command because doing so would drop all indexes. +// The name parameter should be the name of the index to drop. If the name is +// "*", ErrMultipleIndexDrop will be returned without running the command +// because doing so would drop all indexes. // -// The opts parameter can be used to specify options for this operation (see the options.DropIndexesOptions -// documentation). +// The opts parameter can be used to specify options for this operation (see the +// options.DropIndexesOptions documentation). // -// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/dropIndexes/. 
-func (iv IndexView) DropOne(ctx context.Context, name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { +// For more information about the command, see +// https://www.mongodb.com/docs/manual/reference/command/dropIndexes/. +func (iv IndexView) DropOne(ctx context.Context, name string, opts ...*options.DropIndexesOptions) error { if name == "*" { - return nil, ErrMultipleIndexDrop + return ErrMultipleIndexDrop } return iv.drop(ctx, name, opts...) @@ -465,9 +445,7 @@ func (iv IndexView) DropOne(ctx context.Context, name string, opts ...*options.D // For more information about the command, see // https://www.mongodb.com/docs/manual/reference/command/dropIndexes/. func (iv IndexView) DropAll(ctx context.Context, opts ...*options.DropIndexesOptions) error { - _, err := iv.drop(ctx, "*", opts...) - - return err + return iv.drop(ctx, "*", opts...) } func getOrGenerateIndexName(keySpecDocument bsoncore.Document, model IndexModel) (string, error) { diff --git a/mongo/mongo.go b/mongo/mongo.go index 318c7650009..7ba0dff24e4 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -229,8 +229,8 @@ func marshalAggregatePipeline( if err != nil { return nil, false, err } - if btype != bson.TypeArray { - return nil, false, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bson.TypeArray) + if typ := bson.Type(btype); typ != bson.TypeArray { + return nil, false, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", typ, bson.TypeArray) } var hasOutputStage bool diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go index 3c92fdb6207..736594a6e61 100644 --- a/mongo/mongo_test.go +++ b/mongo/mongo_test.go @@ -616,6 +616,6 @@ type bvMarsh struct { err error } -func (b bvMarsh) MarshalBSONValue() (bson.Type, []byte, error) { - return b.t, b.data, b.err +func (b bvMarsh) MarshalBSONValue() (byte, []byte, error) { + return byte(b.t), b.data, b.err } diff --git a/mongo/options/aggregateoptions.go b/mongo/options/aggregateoptions.go index 6a8c26faab1..2c068a582e3 100644 --- a/mongo/options/aggregateoptions.go +++ b/mongo/options/aggregateoptions.go @@ -32,14 +32,6 @@ type AggregateOptions struct { // default value is nil, which means the default collation of the collection will be used. Collation *Collation - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // The maximum amount of time that the server should wait for new documents to satisfy a tailable cursor query. // This option is only valid for MongoDB versions >= 3.2 and is ignored for previous server versions. MaxAwaitTime *time.Duration @@ -95,16 +87,6 @@ func (ao *AggregateOptions) SetCollation(c *Collation) *AggregateOptions { return ao } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. 
-func (ao *AggregateOptions) SetMaxTime(d time.Duration) *AggregateOptions { - ao.MaxTime = &d - return ao -} - // SetMaxAwaitTime sets the value for the MaxAwaitTime field. func (ao *AggregateOptions) SetMaxAwaitTime(d time.Duration) *AggregateOptions { ao.MaxAwaitTime = &d diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index b5fe1931eaa..aee3df998b5 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -15,6 +15,7 @@ import ( "errors" "fmt" "io/ioutil" + "math" "net" "net/http" "strings" @@ -170,11 +171,6 @@ type BSONOptions struct { // instead of a primitive.Binary. BinaryAsSlice bool - // DefaultDocumentD causes the driver to always unmarshal documents into the - // primitive.D type. This behavior is restricted to data typed as - // "interface{}" or "map[string]interface{}". - DefaultDocumentD bool - // DefaultDocumentM causes the driver to always unmarshal documents into the // primitive.M type. This behavior is restricted to data typed as // "interface{}" or "map[string]interface{}". @@ -251,13 +247,6 @@ type ClientOptions struct { // Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any // release. Deployment driver.Deployment - - // SocketTimeout specifies the timeout to be used for the Client's socket reads and writes. - // - // NOTE(benjirewis): SocketTimeout will be deprecated in a future release. The more general Timeout option - // may be used in its place to control the amount of time that a single operation can run before returning - // an error. Setting SocketTimeout and Timeout on a single client will result in undefined behavior. - SocketTimeout *time.Duration } // Client creates a new ClientOptions instance. @@ -325,6 +314,10 @@ func (c *ClientOptions) validate() error { return fmt.Errorf("invalid server monitoring mode: %q", *mode) } + if to := c.Timeout; to != nil && *to < 0 { + return fmt.Errorf(`invalid value %q for "Timeout": value must be positive`, *to) + } + return nil } @@ -478,10 +471,6 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions { c.ServerSelectionTimeout = &cs.ServerSelectionTimeout } - if cs.SocketTimeoutSet { - c.SocketTimeout = &cs.SocketTimeout - } - if cs.SRVMaxHosts != 0 { c.SRVMaxHosts = &cs.SRVMaxHosts } @@ -531,7 +520,7 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions { c.TLSConfig = tlsConfig } - if cs.JSet || cs.WString != "" || cs.WNumberSet || cs.WTimeoutSet { + if cs.JSet || cs.WString != "" || cs.WNumberSet { c.WriteConcern = &writeconcern.WriteConcern{} if len(cs.WString) > 0 { @@ -543,10 +532,6 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions { if cs.JSet { c.WriteConcern.Journal = &cs.J } - - if cs.WTimeoutSet { - c.WriteConcern.WTimeout = cs.WTimeout - } } if cs.ZlibLevelSet { @@ -831,29 +816,19 @@ func (c *ClientOptions) SetServerSelectionTimeout(d time.Duration) *ClientOption return c } -// SetSocketTimeout specifies how long the driver will wait for a socket read or write to return before returning a -// network error. This can also be set through the "socketTimeoutMS" URI option (e.g. "socketTimeoutMS=1000"). The -// default value is 0, meaning no timeout is used and socket operations can block indefinitely. -// -// NOTE(benjirewis): SocketTimeout will be deprecated in a future release. The more general Timeout option may be used -// in its place to control the amount of time that a single operation can run before returning an error. 
Setting -// SocketTimeout and Timeout on a single client will result in undefined behavior. -func (c *ClientOptions) SetSocketTimeout(d time.Duration) *ClientOptions { - c.SocketTimeout = &d - return c -} - -// SetTimeout specifies the amount of time that a single operation run on this Client can execute before returning an error. -// The deadline of any operation run through the Client will be honored above any Timeout set on the Client; Timeout will only -// be honored if there is no deadline on the operation Context. Timeout can also be set through the "timeoutMS" URI option -// (e.g. "timeoutMS=1000"). The default value is nil, meaning operations do not inherit a timeout from the Client. +// SetTimeout specifies the amount of time that a single operation run on this +// Client can execute before returning an error. The deadline of any operation +// run through the Client will be honored above any Timeout set on the Client; +// Timeout will only be honored if there is no deadline on the operation +// Context. Timeout can also be set through the "timeoutMS" URI option +// (e.g. "timeoutMS=1000"). The default value is nil, meaning operations do not +// inherit a timeout from the Client. // -// If any Timeout is set (even 0) on the Client, the values of MaxTime on operation options, TransactionOptions.MaxCommitTime and -// SessionOptions.DefaultMaxCommitTime will be ignored. Setting Timeout and SocketTimeout or WriteConcern.wTimeout will result -// in undefined behavior. +// The value for a Timeout must be positive. // -// NOTE(benjirewis): SetTimeout represents unstable, provisional API. The behavior of the driver when a Timeout is specified is -// subject to change. +// If any Timeout is set (even 0) on the Client, the values of MaxTime on +// operation options, TransactionOptions.MaxCommitTime and +// SessionOptions.DefaultMaxCommitTime will be ignored. func (c *ClientOptions) SetTimeout(d time.Duration) *ClientOptions { c.Timeout = &d return c @@ -1088,9 +1063,6 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.Direct != nil { c.Direct = opt.Direct } - if opt.SocketTimeout != nil { - c.SocketTimeout = opt.SocketTimeout - } if opt.SRVMaxHosts != nil { c.SRVMaxHosts = opt.SRVMaxHosts } @@ -1166,7 +1138,19 @@ func addClientCertFromSeparateFiles(cfg *tls.Config, keyFile, certFile, keyPassw return "", err } - data := make([]byte, 0, len(keyData)+len(certData)+1) + keySize := len(keyData) + if keySize > 64*1024*1024 { + return "", errors.New("X.509 key must be less than 64 MiB") + } + certSize := len(certData) + if certSize > 64*1024*1024 { + return "", errors.New("X.509 certificate must be less than 64 MiB") + } + dataSize := int64(keySize) + int64(certSize) + 1 + if dataSize > math.MaxInt { + return "", errors.New("size overflow") + } + data := make([]byte, 0, int(dataSize)) data = append(data, keyData...) data = append(data, '\n') data = append(data, certData...) 
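A minimal usage sketch of the client-level Timeout behavior configured above, assuming a reachable deployment at a placeholder URI; the database name, collection name, and filter are illustrative, and the expected validation error text is taken from the new negative-timeout check:

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
	// A negative Timeout is rejected when the client is constructed.
	if _, err := mongo.Connect(options.Client().SetTimeout(-1 * time.Second)); err != nil {
		fmt.Println(err) // invalid value "-1s" for "Timeout": value must be positive
	}

	// A client-level Timeout caps every operation whose context has no
	// deadline of its own ("mongodb://localhost:27017" is a placeholder URI).
	client, err := mongo.Connect(options.Client().
		ApplyURI("mongodb://localhost:27017").
		SetTimeout(5 * time.Second))
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = client.Disconnect(context.Background()) }()

	coll := client.Database("db").Collection("coll") // illustrative names

	// Runs under the 5s client-level timeout because this context has no deadline.
	if _, err := coll.CountDocuments(context.Background(), bson.D{}); err != nil {
		log.Println(err)
	}

	// A context deadline takes precedence over the client-level Timeout.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	if _, err := coll.CountDocuments(ctx, bson.D{}); err != nil {
		log.Println(err)
	}
}

The same cap applies to any operation run without a context deadline, which is why the per-operation MaxTime setters and the SocketTimeout option are removed elsewhere in this patch in favor of Timeout and context deadlines.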
diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index beba45514f6..70131ded575 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -85,7 +85,6 @@ func TestClientOptions(t *testing.T) { {"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true}, {"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true}, {"Direct", (*ClientOptions).SetDirect, true, "Direct", true}, - {"SocketTimeout", (*ClientOptions).SetSocketTimeout, 5 * time.Second, "SocketTimeout", true}, {"TLSConfig", (*ClientOptions).SetTLSConfig, &tls.Config{}, "TLSConfig", false}, {"WriteConcern", (*ClientOptions).SetWriteConcern, writeconcern.Majority(), "WriteConcern", false}, {"ZlibLevel", (*ClientOptions).SetZlibLevel, 6, "ZlibLevel", true}, @@ -390,11 +389,6 @@ func TestClientOptions(t *testing.T) { "mongodb://localhost/?serverSelectionTimeoutMS=45000", baseClient().SetServerSelectionTimeout(45 * time.Second), }, - { - "SocketTimeout", - "mongodb://localhost/?socketTimeoutMS=15000", - baseClient().SetSocketTimeout(15 * time.Second), - }, { "TLS CACertificate", "mongodb://localhost/?ssl=true&sslCertificateAuthorityFile=testdata/ca.pem", @@ -440,11 +434,6 @@ func TestClientOptions(t *testing.T) { "mongodb://localhost/?w=3", baseClient().SetWriteConcern(&writeconcern.WriteConcern{W: 3}), }, - { - "WriteConcern WTimeout", - "mongodb://localhost/?wTimeoutMS=45000", - baseClient().SetWriteConcern(&writeconcern.WriteConcern{WTimeout: 45 * time.Second}), - }, { "ZLibLevel", "mongodb://localhost/?zlibCompressionLevel=4", diff --git a/mongo/options/countoptions.go b/mongo/options/countoptions.go index a47550f6d2e..7321bb27431 100644 --- a/mongo/options/countoptions.go +++ b/mongo/options/countoptions.go @@ -6,8 +6,6 @@ package options -import "time" - // CountOptions represents options that can be used to configure a CountDocuments operation. type CountOptions struct { // Specifies a collation to use for string comparisons during the operation. This option is only valid for MongoDB @@ -28,14 +26,6 @@ type CountOptions struct { // documents matching the filter will be counted. Limit *int64 - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there is - // no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used in - // its place to control the amount of time that a single operation can run before returning an error. MaxTime is - // ignored if Timeout is set on the client. - MaxTime *time.Duration - // The number of documents to skip before counting. The default value is 0. Skip *int64 } @@ -69,16 +59,6 @@ func (co *CountOptions) SetLimit(i int64) *CountOptions { return co } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (co *CountOptions) SetMaxTime(d time.Duration) *CountOptions { - co.MaxTime = &d - return co -} - // SetSkip sets the value for the Skip field. 
func (co *CountOptions) SetSkip(i int64) *CountOptions { co.Skip = &i diff --git a/mongo/options/distinctoptions.go b/mongo/options/distinctoptions.go index 4cfcb98526c..33efd580063 100644 --- a/mongo/options/distinctoptions.go +++ b/mongo/options/distinctoptions.go @@ -6,8 +6,6 @@ package options -import "time" - // DistinctOptions represents options that can be used to configure a Distinct operation. type DistinctOptions struct { // Specifies a collation to use for string comparisons during the operation. This option is only valid for MongoDB @@ -18,14 +16,6 @@ type DistinctOptions struct { // A string or document that will be included in server logs, profiling logs, and currentOp queries to help trace // the operation. The default value is nil, which means that no comment will be included in the logs. Comment interface{} - - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be - // used in its place to control the amount of time that a single operation can run before returning an error. - // MaxTime is ignored if Timeout is set on the client. - MaxTime *time.Duration } // Distinct creates a new DistinctOptions instance. @@ -44,13 +34,3 @@ func (do *DistinctOptions) SetComment(comment interface{}) *DistinctOptions { do.Comment = comment return do } - -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (do *DistinctOptions) SetMaxTime(d time.Duration) *DistinctOptions { - do.MaxTime = &d - return do -} diff --git a/mongo/options/estimatedcountoptions.go b/mongo/options/estimatedcountoptions.go index b7d52bef6d5..5f32ab13ba5 100644 --- a/mongo/options/estimatedcountoptions.go +++ b/mongo/options/estimatedcountoptions.go @@ -6,21 +6,11 @@ package options -import "time" - // EstimatedDocumentCountOptions represents options that can be used to configure an EstimatedDocumentCount operation. type EstimatedDocumentCountOptions struct { // A string or document that will be included in server logs, profiling logs, and currentOp queries to help trace // the operation. The default is nil, which means that no comment will be included in the logs. Comment interface{} - - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration } // EstimatedDocumentCount creates a new EstimatedDocumentCountOptions instance. @@ -33,13 +23,3 @@ func (eco *EstimatedDocumentCountOptions) SetComment(comment interface{}) *Estim eco.Comment = comment return eco } - -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. 
The more general Timeout option -// may be used in its place to control the amount of time that a single operation can run before -// returning an error. MaxTime is ignored if Timeout is set on the client. -func (eco *EstimatedDocumentCountOptions) SetMaxTime(d time.Duration) *EstimatedDocumentCountOptions { - eco.MaxTime = &d - return eco -} diff --git a/mongo/options/findoptions.go b/mongo/options/findoptions.go index 705fefc3f3a..e8c8fa4c607 100644 --- a/mongo/options/findoptions.go +++ b/mongo/options/findoptions.go @@ -58,14 +58,6 @@ type FindOptions struct { // MongoDB versions >= 3.2. For other cursor types or previous server versions, this option is ignored. MaxAwaitTime *time.Duration - // MaxTime is the maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used in its - // place to control the amount of time that a single operation can run before returning an error. MaxTime is ignored if - // Timeout is set on the client. - MaxTime *time.Duration - // Min is a document specifying the inclusive lower bound for a specific index. The default value is 0, which means that // there is no minimum value. Min interface{} @@ -171,16 +163,6 @@ func (f *FindOptions) SetMaxAwaitTime(d time.Duration) *FindOptions { return f } -// SetMaxTime specifies the max time to allow the query to run. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used used in its place to control the amount of time that a single operation -// can run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOptions) SetMaxTime(d time.Duration) *FindOptions { - f.MaxTime = &d - return f -} - // SetMin sets the value for the Min field. func (f *FindOptions) SetMin(min interface{}) *FindOptions { f.Min = min @@ -248,14 +230,6 @@ type FindOneOptions struct { // there is no maximum value. Max interface{} - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // A document specifying the inclusive lower bound for a specific index. The default value is 0, which means that // there is no minimum value. Min interface{} @@ -315,16 +289,6 @@ func (f *FindOneOptions) SetMax(max interface{}) *FindOneOptions { return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOneOptions) SetMaxTime(d time.Duration) *FindOneOptions { - f.MaxTime = &d - return f -} - // SetMin sets the value for the Min field. func (f *FindOneOptions) SetMin(min interface{}) *FindOneOptions { f.Min = min @@ -378,14 +342,6 @@ type FindOneAndReplaceOptions struct { // the operation. 
The default value is nil, which means that no comment will be included in the logs. Comment interface{} - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // A document describing which fields will be included in the document returned by the operation. The default value // is nil, which means all fields will be included. Projection interface{} @@ -441,16 +397,6 @@ func (f *FindOneAndReplaceOptions) SetComment(comment interface{}) *FindOneAndRe return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOneAndReplaceOptions) SetMaxTime(d time.Duration) *FindOneAndReplaceOptions { - f.MaxTime = &d - return f -} - // SetProjection sets the value for the Projection field. func (f *FindOneAndReplaceOptions) SetProjection(projection interface{}) *FindOneAndReplaceOptions { f.Projection = projection @@ -509,14 +455,6 @@ type FindOneAndUpdateOptions struct { // the operation. The default value is nil, which means that no comment will be included in the logs. Comment interface{} - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime is - // ignored if Timeout is set on the client. - MaxTime *time.Duration - // A document describing which fields will be included in the document returned by the operation. The default value // is nil, which means all fields will be included. Projection interface{} @@ -578,16 +516,6 @@ func (f *FindOneAndUpdateOptions) SetComment(comment interface{}) *FindOneAndUpd return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOneAndUpdateOptions) SetMaxTime(d time.Duration) *FindOneAndUpdateOptions { - f.MaxTime = &d - return f -} - // SetProjection sets the value for the Projection field. func (f *FindOneAndUpdateOptions) SetProjection(projection interface{}) *FindOneAndUpdateOptions { f.Projection = projection @@ -635,14 +563,6 @@ type FindOneAndDeleteOptions struct { // the operation. The default value is nil, which means that no comment will be included in the logs. Comment interface{} - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. 
The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // A document describing which fields will be included in the document returned by the operation. The default value // is nil, which means all fields will be included. Projection interface{} @@ -684,16 +604,6 @@ func (f *FindOneAndDeleteOptions) SetComment(comment interface{}) *FindOneAndDel return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOneAndDeleteOptions) SetMaxTime(d time.Duration) *FindOneAndDeleteOptions { - f.MaxTime = &d - return f -} - // SetProjection sets the value for the Projection field. func (f *FindOneAndDeleteOptions) SetProjection(projection interface{}) *FindOneAndDeleteOptions { f.Projection = projection diff --git a/mongo/options/gridfsoptions.go b/mongo/options/gridfsoptions.go index 10d454c89d7..c8dcf447fc9 100644 --- a/mongo/options/gridfsoptions.go +++ b/mongo/options/gridfsoptions.go @@ -7,8 +7,6 @@ package options import ( - "time" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -155,14 +153,6 @@ type GridFSFindOptions struct { // batch. The default value is 0. Limit *int32 - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // If true, the cursor created by the operation will not timeout after a period of inactivity. The default value // is false. NoCursorTimeout *bool @@ -198,16 +188,6 @@ func (f *GridFSFindOptions) SetLimit(i int32) *GridFSFindOptions { return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *GridFSFindOptions) SetMaxTime(d time.Duration) *GridFSFindOptions { - f.MaxTime = &d - return f -} - // SetNoCursorTimeout sets the value for the NoCursorTimeout field. func (f *GridFSFindOptions) SetNoCursorTimeout(b bool) *GridFSFindOptions { f.NoCursorTimeout = &b diff --git a/mongo/options/indexoptions.go b/mongo/options/indexoptions.go index 1837b1037a9..82675d6a6a7 100644 --- a/mongo/options/indexoptions.go +++ b/mongo/options/indexoptions.go @@ -6,10 +6,6 @@ package options -import ( - "time" -) - // CreateIndexesOptions represents options that can be used to configure IndexView.CreateOne and IndexView.CreateMany // operations. type CreateIndexesOptions struct { @@ -26,14 +22,6 @@ type CreateIndexesOptions struct { // is specified for MongoDB versions <= 4.2. 
The default value is nil, meaning that the server-side default will be // used. See dochub.mongodb.org/core/index-commit-quorum for more information. CommitQuorum interface{} - - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration } // CreateIndexes creates a new CreateIndexesOptions instance. @@ -41,16 +29,6 @@ func CreateIndexes() *CreateIndexesOptions { return &CreateIndexesOptions{} } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (c *CreateIndexesOptions) SetMaxTime(d time.Duration) *CreateIndexesOptions { - c.MaxTime = &d - return c -} - // SetCommitQuorumInt sets the value for the CommitQuorum field as an int32. func (c *CreateIndexesOptions) SetCommitQuorumInt(quorum int32) *CreateIndexesOptions { c.CommitQuorum = quorum @@ -77,43 +55,17 @@ func (c *CreateIndexesOptions) SetCommitQuorumVotingMembers() *CreateIndexesOpti // DropIndexesOptions represents options that can be used to configure IndexView.DropOne and IndexView.DropAll // operations. -type DropIndexesOptions struct { - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration -} +type DropIndexesOptions struct{} // DropIndexes creates a new DropIndexesOptions instance. func DropIndexes() *DropIndexesOptions { return &DropIndexesOptions{} } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (d *DropIndexesOptions) SetMaxTime(duration time.Duration) *DropIndexesOptions { - d.MaxTime = &duration - return d -} - // ListIndexesOptions represents options that can be used to configure an IndexView.List operation. type ListIndexesOptions struct { // The maximum number of documents to be included in each batch returned by the server. BatchSize *int32 - - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. 
- MaxTime *time.Duration } // ListIndexes creates a new ListIndexesOptions instance. @@ -127,16 +79,6 @@ func (l *ListIndexesOptions) SetBatchSize(i int32) *ListIndexesOptions { return l } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (l *ListIndexesOptions) SetMaxTime(d time.Duration) *ListIndexesOptions { - l.MaxTime = &d - return l -} - // IndexOptions represents options that can be used to configure a new index created through the IndexView.CreateOne // or IndexView.CreateMany operations. type IndexOptions struct { diff --git a/mongo/options/sessionoptions.go b/mongo/options/sessionoptions.go index 4e1fdb11143..d83610b173f 100644 --- a/mongo/options/sessionoptions.go +++ b/mongo/options/sessionoptions.go @@ -7,8 +7,6 @@ package options import ( - "time" - "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -36,14 +34,6 @@ type SessionOptions struct { // the write concern of the client used to start the session will be used. DefaultWriteConcern *writeconcern.WriteConcern - // The default maximum amount of time that a CommitTransaction operation executed in the session can run on the - // server. The default value is nil, which means that that there is no time limit for execution. - // - // NOTE(benjirewis): DefaultMaxCommitTime will be deprecated in a future release. The more general Timeout option - // may be used in its place to control the amount of time that a single operation can run before returning an - // error. DefaultMaxCommitTime is ignored if Timeout is set on the client. - DefaultMaxCommitTime *time.Duration - // If true, all read operations performed with this session will be read from the same snapshot. This option cannot // be set to true if CausalConsistency is set to true. Transactions and write operations are not allowed on // snapshot sessions and will error. The default value is false. @@ -79,17 +69,6 @@ func (s *SessionOptions) SetDefaultWriteConcern(wc *writeconcern.WriteConcern) * return s } -// SetDefaultMaxCommitTime sets the value for the DefaultMaxCommitTime field. -// -// NOTE(benjirewis): DefaultMaxCommitTime will be deprecated in a future release. The more -// general Timeout option may be used in its place to control the amount of time that a -// single operation can run before returning an error. DefaultMaxCommitTime is ignored if -// Timeout is set on the client. -func (s *SessionOptions) SetDefaultMaxCommitTime(mct *time.Duration) *SessionOptions { - s.DefaultMaxCommitTime = mct - return s -} - // SetSnapshot sets the value for the Snapshot field. func (s *SessionOptions) SetSnapshot(b bool) *SessionOptions { s.Snapshot = &b diff --git a/mongo/options/transactionoptions.go b/mongo/options/transactionoptions.go index 2bc4c2166c0..c346b0f63ff 100644 --- a/mongo/options/transactionoptions.go +++ b/mongo/options/transactionoptions.go @@ -7,8 +7,6 @@ package options import ( - "time" - "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -27,18 +25,6 @@ type TransactionOptions struct { // The write concern for operations in the transaction. 
The default value is nil, which means that the default // write concern of the session used to start the transaction will be used. WriteConcern *writeconcern.WriteConcern - - // The default maximum amount of time that a CommitTransaction operation executed in the session can run on the - // server. The default value is nil, meaning that there is no time limit for execution. - - // The maximum amount of time that a CommitTransaction operation can executed in the transaction can run on the - // server. The default value is nil, which means that the default maximum commit time of the session used to - // start the transaction will be used. - // - // NOTE(benjirewis): MaxCommitTime will be deprecated in a future release. The more general Timeout option may - // be used in its place to control the amount of time that a single operation can run before returning an error. - // MaxCommitTime is ignored if Timeout is set on the client. - MaxCommitTime *time.Duration } // Transaction creates a new TransactionOptions instance. @@ -63,13 +49,3 @@ func (t *TransactionOptions) SetWriteConcern(wc *writeconcern.WriteConcern) *Tra t.WriteConcern = wc return t } - -// SetMaxCommitTime sets the value for the MaxCommitTime field. -// -// NOTE(benjirewis): MaxCommitTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can run before -// returning an error. MaxCommitTime is ignored if Timeout is set on the client. -func (t *TransactionOptions) SetMaxCommitTime(mct *time.Duration) *TransactionOptions { - t.MaxCommitTime = mct - return t -} diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index ec49bb91dbe..ffa41039b16 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -13,7 +13,6 @@ import ( "path" "reflect" "testing" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" @@ -48,6 +47,7 @@ type connectionStringTest struct { Valid bool `bson:"valid"` ReadConcern bson.Raw `bson:"readConcern"` WriteConcern bson.Raw `bson:"writeConcern"` + SkipReason string `bson:"skipReason"` } type documentTestFile struct { @@ -98,6 +98,10 @@ func runConnectionStringTestFile(t *testing.T, filePath string) { } func runConnectionStringTest(t *testing.T, test connectionStringTest) { + if test.SkipReason != "" { + t.Skip(test.SkipReason) + } + cs, err := connstring.ParseAndValidate(test.URI) if !test.Valid { assert.NotNil(t, err, "expected Parse error, got nil") @@ -122,11 +126,6 @@ func runConnectionStringTest(t *testing.T, test connectionStringTest) { assert.Equal(t, expected, cs.WString, "expected w value %v, got %v", expected, cs.WString) } } - if expectedWc.timeoutSet { - assert.True(t, cs.WTimeoutSet, "expected WTimeoutSet, got false") - assert.Equal(t, expectedWc.WTimeout, cs.WTimeout, - "expected timeout value %v, got %v", expectedWc.WTimeout, cs.WTimeout) - } if expectedWc.jSet { assert.True(t, cs.JSet, "expected JSet, got false") assert.Equal(t, *expectedWc.Journal, cs.J, "expected j value %v, got %v", *expectedWc.Journal, cs.J) @@ -221,9 +220,8 @@ func readConcernFromRaw(t *testing.T, rc bson.Raw) *readconcern.ReadConcern { type writeConcern struct { *writeconcern.WriteConcern - jSet bool - wSet bool - timeoutSet bool + jSet bool + wSet bool } func writeConcernFromRaw(t *testing.T, wcRaw bson.Raw) writeConcern { @@ -247,14 +245,12 @@ func writeConcernFromRaw(t *testing.T, wcRaw bson.Raw) 
writeConcern { default: t.Fatalf("unexpected type for w: %v", val.Type) } - case "wtimeoutMS": - wc.timeoutSet = true - timeout := time.Duration(val.Int32()) * time.Millisecond - wc.WriteConcern.WTimeout = timeout case "journal": wc.jSet = true j := val.Boolean() wc.WriteConcern.Journal = &j + case "wtimeoutMS": // Do nothing, this field is deprecated + t.Skip("the wtimeoutMS write concern option is not supported") default: t.Fatalf("unrecognized write concern field: %v", key) } diff --git a/mongo/session.go b/mongo/session.go index 778abebc633..5df2d800f42 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -214,15 +214,11 @@ func (s *Session) StartTransaction(opts ...*options.TransactionOptions) error { if opt.WriteConcern != nil { topts.WriteConcern = opt.WriteConcern } - if opt.MaxCommitTime != nil { - topts.MaxCommitTime = opt.MaxCommitTime - } } coreOpts := &session.TransactionOptions{ ReadConcern: topts.ReadConcern, ReadPreference: topts.ReadPreference, WriteConcern: topts.WriteConcern, - MaxCommitTime: topts.MaxCommitTime, } return s.clientSession.StartTransaction(coreOpts) @@ -282,7 +278,7 @@ func (s *Session) CommitTransaction(ctx context.Context) error { Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment). WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)). - ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct) + ServerAPI(s.client.serverAPI) err = op.Execute(ctx) // Return error without updating transaction state if it is a timeout, as the transaction has not diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index d737ff9a072..e1880fc7d2d 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -399,19 +399,23 @@ func TestConvenientTransactions(t *testing.T) { // Insert a document within a session and manually cancel context before // "commitTransaction" can be sent. - callback := func(ctx context.Context) { - transactionCtx, cancel := context.WithCancel(ctx) - + callback := func() bool { + transactionCtx, cancel := context.WithCancel(context.Background()) _, _ = sess.WithTransaction(transactionCtx, func(ctx context.Context) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.M{"x": 1}) - assert.Nil(t, err, "InsertOne error: %v", err) + assert.NoError(t, err, "InsertOne error: %v", err) cancel() return nil, nil }) + return true } // Assert that transaction is canceled within 500ms and not 2 seconds. - assert.Soon(t, callback, 500*time.Millisecond) + assert.Eventually(t, + callback, + 500*time.Millisecond, + time.Millisecond, + "expected transaction to be canceled within 500ms") // Assert that AbortTransaction was started once and succeeded. assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted)) @@ -459,19 +463,24 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - callback := func(ctx context.Context) { + callback := func() bool { // Create transaction context with short timeout. 
- withTransactionContext, cancel := context.WithTimeout(ctx, time.Nanosecond) + withTransactionContext, cancel := context.WithTimeout(context.Background(), time.Nanosecond) defer cancel() _, _ = sess.WithTransaction(withTransactionContext, func(ctx context.Context) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.D{{}}) return nil, err }) + return true } // Assert that transaction fails within 500ms and not 2 seconds. - assert.Soon(t, callback, 500*time.Millisecond) + assert.Eventually(t, + callback, + 500*time.Millisecond, + time.Millisecond, + "expected transaction to fail within 500ms") }) t.Run("canceled context before callback does not retry", func(t *testing.T) { withTransactionTimeout = 2 * time.Second @@ -489,19 +498,24 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - callback := func(ctx context.Context) { + callback := func() bool { // Create transaction context and cancel it immediately. - withTransactionContext, cancel := context.WithTimeout(ctx, 2*time.Second) + withTransactionContext, cancel := context.WithTimeout(context.Background(), 2*time.Second) cancel() _, _ = sess.WithTransaction(withTransactionContext, func(ctx context.Context) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.D{{}}) return nil, err }) + return true } // Assert that transaction fails within 500ms and not 2 seconds. - assert.Soon(t, callback, 500*time.Millisecond) + assert.Eventually(t, + callback, + 500*time.Millisecond, + time.Millisecond, + "expected transaction to fail within 500ms") }) t.Run("slow operation in callback retries", func(t *testing.T) { withTransactionTimeout = 2 * time.Second @@ -540,8 +554,8 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - callback := func(ctx context.Context) { - _, err = sess.WithTransaction(ctx, func(ctx context.Context) (interface{}, error) { + callback := func() bool { + _, err = sess.WithTransaction(context.Background(), func(ctx context.Context) (interface{}, error) { // Set a timeout of 300ms to cause a timeout on first insertOne // and force a retry. c, cancel := context.WithTimeout(ctx, 300*time.Millisecond) @@ -550,11 +564,17 @@ func TestConvenientTransactions(t *testing.T) { _, err := coll.InsertOne(c, bson.D{{}}) return nil, err }) - assert.Nil(t, err, "WithTransaction error: %v", err) + assert.NoError(t, err, "WithTransaction error: %v", err) + return true } // Assert that transaction passes within 2 seconds. - assert.Soon(t, callback, 2*time.Second) + assert.Eventually(t, + callback, + withTransactionTimeout, + time.Millisecond, + "expected transaction to be passed within 2s") + }) } diff --git a/mongo/writeconcern/writeconcern.go b/mongo/writeconcern/writeconcern.go index 2e4d2ade16f..c8398c7f15c 100644 --- a/mongo/writeconcern/writeconcern.go +++ b/mongo/writeconcern/writeconcern.go @@ -13,7 +13,7 @@ package writeconcern import ( "errors" "fmt" - "time" + "math" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -74,17 +74,6 @@ type WriteConcern struct { // For more information about the "j" option, see // https://www.mongodb.com/docs/manual/reference/write-concern/#j-option Journal *bool - - // WTimeout specifies a time limit for the write concern. It sets the - // "wtimeout" option in a MongoDB write concern. - // - // It is only applicable for "w" values greater than 1. 
Using a WTimeout and - // setting Timeout on the Client at the same time will result in undefined - // behavior. - // - // For more information about the "wtimeout" option, see - // https://www.mongodb.com/docs/manual/reference/write-concern/#wtimeout - WTimeout time.Duration } // Unacknowledged returns a WriteConcern that requests no acknowledgment of @@ -169,6 +158,9 @@ func (wc *WriteConcern) MarshalBSONValue() (bson.Type, []byte, error) { return 0, nil, ErrInconsistent } + if w > math.MaxInt32 { + return 0, nil, fmt.Errorf("%d overflows int32", w) + } elems = bsoncore.AppendInt32Element(elems, "w", int32(w)) case string: elems = bsoncore.AppendStringElement(elems, "w", w) @@ -183,14 +175,6 @@ func (wc *WriteConcern) MarshalBSONValue() (bson.Type, []byte, error) { elems = bsoncore.AppendBooleanElement(elems, "j", *wc.Journal) } - if wc.WTimeout < 0 { - return 0, nil, ErrNegativeWTimeout - } - - if wc.WTimeout != 0 { - elems = bsoncore.AppendInt64Element(elems, "wtimeout", int64(wc.WTimeout/time.Millisecond)) - } - if len(elems) == 0 { return 0, nil, ErrEmptyWriteConcern } diff --git a/mongo/writeconcern/writeconcern_test.go b/mongo/writeconcern/writeconcern_test.go index b3486fe4f9b..07f7b9c3aee 100644 --- a/mongo/writeconcern/writeconcern_test.go +++ b/mongo/writeconcern/writeconcern_test.go @@ -7,118 +7,12 @@ package writeconcern_test import ( - "errors" "testing" - "time" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/writeconcern" ) -func TestWriteConcern_MarshalBSONValue(t *testing.T) { - t.Parallel() - - boolPtr := func(b bool) *bool { return &b } - - testCases := []struct { - name string - wc *writeconcern.WriteConcern - wantType bson.Type - wantValue bson.D - wantError error - }{ - { - name: "all fields", - wc: &writeconcern.WriteConcern{ - W: "majority", - Journal: boolPtr(false), - WTimeout: 1 * time.Minute, - }, - wantType: bson.TypeEmbeddedDocument, - wantValue: bson.D{ - {Key: "w", Value: "majority"}, - {Key: "j", Value: false}, - {Key: "wtimeout", Value: int64(60_000)}, - }, - }, - { - name: "string W", - wc: &writeconcern.WriteConcern{W: "majority"}, - wantType: bson.TypeEmbeddedDocument, - wantValue: bson.D{{Key: "w", Value: "majority"}}, - }, - { - name: "int W", - wc: &writeconcern.WriteConcern{W: 1}, - wantType: bson.TypeEmbeddedDocument, - wantValue: bson.D{{Key: "w", Value: int32(1)}}, - }, - { - name: "int32 W", - wc: &writeconcern.WriteConcern{W: int32(1)}, - wantError: errors.New("WriteConcern.W must be a string or int, but is a int32"), - }, - { - name: "bool W", - wc: &writeconcern.WriteConcern{W: false}, - wantError: errors.New("WriteConcern.W must be a string or int, but is a bool"), - }, - { - name: "W=0 and J=true", - wc: &writeconcern.WriteConcern{W: 0, Journal: boolPtr(true)}, - wantError: writeconcern.ErrInconsistent, - }, - { - name: "negative W", - wc: &writeconcern.WriteConcern{W: -1}, - wantError: writeconcern.ErrNegativeW, - }, - { - name: "negative WTimeout", - wc: &writeconcern.WriteConcern{W: 1, WTimeout: -1}, - wantError: writeconcern.ErrNegativeWTimeout, - }, - { - name: "empty", - wc: &writeconcern.WriteConcern{}, - wantError: writeconcern.ErrEmptyWriteConcern, - }, - { - name: "nil", - wc: nil, - wantError: writeconcern.ErrEmptyWriteConcern, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. 
- - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - typ, b, err := tc.wc.MarshalBSONValue() - if tc.wantError != nil { - assert.Equal(t, tc.wantError, err, "expected and actual errors do not match") - return - } - require.NoError(t, err, "bson.MarshalValue error") - - assert.Equal(t, tc.wantType, typ, "expected and actual BSON types do not match") - - rv := bson.RawValue{ - Type: typ, - Value: b, - } - var gotValue bson.D - err = rv.Unmarshal(&gotValue) - require.NoError(t, err, "error unmarshaling RawValue") - assert.Equal(t, tc.wantValue, gotValue, "expected and actual BSON values do not match") - }) - } -} - func TestWriteConcern(t *testing.T) { boolPtr := func(b bool) *bool { return &b } diff --git a/testdata/client-side-operations-timeout/retryability-legacy-timeouts.json b/testdata/client-side-operations-timeout/retryability-legacy-timeouts.json index 6cf1f4ce6e1..aded781aeed 100644 --- a/testdata/client-side-operations-timeout/retryability-legacy-timeouts.json +++ b/testdata/client-side-operations-timeout/retryability-legacy-timeouts.json @@ -6,7 +6,7 @@ "minServerVersion": "4.4", "topologies": [ "replicaset", - "sharded-replicaset" + "sharded" ] } ], @@ -73,7 +73,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -132,7 +132,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -194,7 +194,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -255,7 +255,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -319,7 +319,7 @@ "delete" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -376,7 +376,7 @@ "delete" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -436,7 +436,7 @@ "update" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -496,7 +496,7 @@ "update" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -559,7 +559,7 @@ "update" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -621,7 +621,7 @@ "update" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -686,7 +686,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -743,7 +743,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -803,7 +803,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -863,7 +863,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -926,7 +926,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -988,7 +988,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1053,7 +1053,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1118,7 +1118,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1186,7 +1186,7 @@ "listDatabases" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1243,7 +1243,7 @@ "listDatabases" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1303,7 +1303,7 @@ "listDatabases" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1357,7 +1357,7 @@ "listDatabases" ], "blockConnection": true, - "blockTimeMS": 110 + 
"blockTimeMS": 125 } } } @@ -1414,7 +1414,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1471,7 +1471,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1531,7 +1531,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1595,7 +1595,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1662,7 +1662,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1719,7 +1719,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1779,7 +1779,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1836,7 +1836,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1896,7 +1896,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1953,7 +1953,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2013,7 +2013,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2070,7 +2070,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2130,7 +2130,7 @@ "count" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2187,7 +2187,7 @@ "count" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2247,7 +2247,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2304,7 +2304,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2364,7 +2364,7 @@ "count" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2418,7 +2418,7 @@ "count" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2475,7 +2475,7 @@ "distinct" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2533,7 +2533,7 @@ "distinct" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2594,7 +2594,7 @@ "find" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2651,7 +2651,7 @@ "find" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2711,7 +2711,7 @@ "find" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2768,7 +2768,7 @@ "find" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2828,7 +2828,7 @@ "listIndexes" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2882,7 +2882,7 @@ "listIndexes" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2939,7 +2939,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2996,7 +2996,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } diff --git a/testdata/client-side-operations-timeout/retryability-legacy-timeouts.yml b/testdata/client-side-operations-timeout/retryability-legacy-timeouts.yml index de3eb9971d0..8ada5fb7917 100644 --- a/testdata/client-side-operations-timeout/retryability-legacy-timeouts.yml +++ b/testdata/client-side-operations-timeout/retryability-legacy-timeouts.yml @@ -6,7 +6,7 @@ schemaVersion: "1.9" runOnRequirements: - minServerVersion: "4.4" - topologies: 
["replicaset", "sharded-replicaset"] + topologies: ["replicaset", "sharded"] createEntities: - client: @@ -38,8 +38,8 @@ initialData: tests: # For each retryable operation, run two tests: # - # 1. Socket timeouts are retried once - Each test constructs a client entity with socketTimeoutMS=50, configures a - # fail point to block the operation once for 110ms, and expects the operation to succeed. + # 1. Socket timeouts are retried once - Each test constructs a client entity with socketTimeoutMS=100, configures a + # fail point to block the operation once for 125ms, and expects the operation to succeed. # # 2. Operations fail after two consecutive socket timeouts - Same as (1) but the fail point is configured to block # the operation twice and the test expects the operation to fail. @@ -56,7 +56,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: insertOne object: *collection arguments: @@ -87,7 +87,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: insertOne object: *collection arguments: @@ -121,7 +121,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: insertMany object: *collection arguments: @@ -153,7 +153,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: insertMany object: *collection arguments: @@ -188,7 +188,7 @@ tests: data: failCommands: ["delete"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: deleteOne object: *collection arguments: @@ -219,7 +219,7 @@ tests: data: failCommands: ["delete"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: deleteOne object: *collection arguments: @@ -253,7 +253,7 @@ tests: data: failCommands: ["update"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: replaceOne object: *collection arguments: @@ -285,7 +285,7 @@ tests: data: failCommands: ["update"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: replaceOne object: *collection arguments: @@ -320,7 +320,7 @@ tests: data: failCommands: ["update"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: updateOne object: *collection arguments: @@ -352,7 +352,7 @@ tests: data: failCommands: ["update"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: updateOne object: *collection arguments: @@ -387,7 +387,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndDelete object: *collection arguments: @@ -418,7 +418,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndDelete object: *collection arguments: @@ -452,7 +452,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndReplace object: *collection arguments: @@ -484,7 +484,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndReplace object: *collection arguments: @@ -519,7 +519,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndUpdate object: *collection arguments: @@ -551,7 +551,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndUpdate object: *collection arguments: @@ 
-586,7 +586,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: bulkWrite object: *collection arguments: @@ -619,7 +619,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: bulkWrite object: *collection arguments: @@ -655,7 +655,7 @@ tests: data: failCommands: ["listDatabases"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listDatabases object: *client arguments: @@ -686,7 +686,7 @@ tests: data: failCommands: ["listDatabases"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listDatabases object: *client arguments: @@ -720,7 +720,7 @@ tests: data: failCommands: ["listDatabases"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listDatabaseNames object: *client @@ -749,7 +749,7 @@ tests: data: failCommands: ["listDatabases"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listDatabaseNames object: *client @@ -781,7 +781,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *client arguments: @@ -812,7 +812,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *client arguments: @@ -846,7 +846,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: aggregate object: *database arguments: @@ -877,7 +877,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: aggregate object: *database arguments: @@ -911,7 +911,7 @@ tests: data: failCommands: ["listCollections"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listCollections object: *database arguments: @@ -942,7 +942,7 @@ tests: data: failCommands: ["listCollections"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listCollections object: *database arguments: @@ -976,7 +976,7 @@ tests: data: failCommands: ["listCollections"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listCollectionNames object: *database arguments: @@ -1007,7 +1007,7 @@ tests: data: failCommands: ["listCollections"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listCollectionNames object: *database arguments: @@ -1041,7 +1041,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *database arguments: @@ -1072,7 +1072,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *database arguments: @@ -1106,7 +1106,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: aggregate object: *collection arguments: @@ -1137,7 +1137,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: aggregate object: *collection arguments: @@ -1171,7 +1171,7 @@ tests: data: failCommands: ["count"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: count object: *collection arguments: @@ -1202,7 +1202,7 @@ tests: data: failCommands: ["count"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: count object: *collection arguments: @@ -1236,7 +1236,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - 
blockTimeMS: 110 + blockTimeMS: 125 - name: countDocuments object: *collection arguments: @@ -1267,7 +1267,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: countDocuments object: *collection arguments: @@ -1301,7 +1301,7 @@ tests: data: failCommands: ["count"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: estimatedDocumentCount object: *collection @@ -1330,7 +1330,7 @@ tests: data: failCommands: ["count"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: estimatedDocumentCount object: *collection @@ -1362,7 +1362,7 @@ tests: data: failCommands: ["distinct"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: distinct object: *collection arguments: @@ -1394,7 +1394,7 @@ tests: data: failCommands: ["distinct"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: distinct object: *collection arguments: @@ -1429,7 +1429,7 @@ tests: data: failCommands: ["find"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: find object: *collection arguments: @@ -1460,7 +1460,7 @@ tests: data: failCommands: ["find"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: find object: *collection arguments: @@ -1494,7 +1494,7 @@ tests: data: failCommands: ["find"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOne object: *collection arguments: @@ -1525,7 +1525,7 @@ tests: data: failCommands: ["find"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOne object: *collection arguments: @@ -1559,7 +1559,7 @@ tests: data: failCommands: ["listIndexes"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listIndexes object: *collection @@ -1588,7 +1588,7 @@ tests: data: failCommands: ["listIndexes"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listIndexes object: *collection @@ -1620,7 +1620,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *collection arguments: @@ -1651,7 +1651,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *collection arguments: diff --git a/testdata/convenient-transactions/commit-retry.json b/testdata/convenient-transactions/commit-retry.json index 02e38460d05..6257e99345b 100644 --- a/testdata/convenient-transactions/commit-retry.json +++ b/testdata/convenient-transactions/commit-retry.json @@ -150,6 +150,7 @@ }, { "description": "commitTransaction retry only overwrites write concern w option", + "skipReason": "GODRIVER-2348: wtimeout is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { @@ -429,6 +430,7 @@ }, { "description": "commit is not retried after MaxTimeMSExpired error", + "skipReason": "GODRIVER-2348: maxTimeMS is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { diff --git a/testdata/convenient-transactions/commit-retry.yml b/testdata/convenient-transactions/commit-retry.yml index 74c03dd9fbd..3ff6497ae40 100644 --- a/testdata/convenient-transactions/commit-retry.yml +++ b/testdata/convenient-transactions/commit-retry.yml @@ -99,6 +99,7 @@ tests: - { _id: 1 } - description: commitTransaction retry only overwrites write concern w option + skipReason: "GODRIVER-2348: wtimeout is deprecated" failPoint: configureFailPoint: failCommand mode: { times: 2 } @@ -260,6 +261,7 @@ tests: - { _id: 1 } - description: 
commit is not retried after MaxTimeMSExpired error + skipReason: "GODRIVER-2348: maxTimeMS is deprecated" failPoint: configureFailPoint: failCommand mode: { times: 1 } diff --git a/testdata/read-write-concern/connection-string/write-concern.json b/testdata/read-write-concern/connection-string/write-concern.json index 51bdf821c34..a81e297dae9 100644 --- a/testdata/read-write-concern/connection-string/write-concern.json +++ b/testdata/read-write-concern/connection-string/write-concern.json @@ -33,6 +33,7 @@ }, { "description": "wtimeoutMS as a valid number", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "uri": "mongodb://localhost/?wtimeoutMS=500", "valid": true, "warning": false, @@ -42,6 +43,7 @@ }, { "description": "wtimeoutMS as an invalid number", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "uri": "mongodb://localhost/?wtimeoutMS=-500", "valid": false, "warning": null @@ -66,6 +68,7 @@ }, { "description": "All options combined", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "uri": "mongodb://localhost/?w=3&wtimeoutMS=500&journal=true", "valid": true, "warning": false, @@ -96,6 +99,7 @@ }, { "description": "Unacknowledged with w and wtimeoutMS", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "uri": "mongodb://localhost/?w=0&wtimeoutMS=500", "valid": true, "warning": false, diff --git a/testdata/read-write-concern/connection-string/write-concern.yml b/testdata/read-write-concern/connection-string/write-concern.yml index ca610858651..52e09170e87 100644 --- a/testdata/read-write-concern/connection-string/write-concern.yml +++ b/testdata/read-write-concern/connection-string/write-concern.yml @@ -24,12 +24,14 @@ tests: writeConcern: { w: "majority" } - description: "wtimeoutMS as a valid number" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" uri: "mongodb://localhost/?wtimeoutMS=500" valid: true warning: false writeConcern: { wtimeoutMS: 500 } - description: "wtimeoutMS as an invalid number" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" uri: "mongodb://localhost/?wtimeoutMS=-500" valid: false warning: ~ @@ -47,6 +49,7 @@ tests: writeConcern: { journal: true } - description: "All options combined" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" uri: "mongodb://localhost/?w=3&wtimeoutMS=500&journal=true" valid: true warning: false @@ -65,6 +68,7 @@ tests: writeConcern: { w: 0, journal: false } - description: "Unacknowledged with w and wtimeoutMS" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" uri: "mongodb://localhost/?w=0&wtimeoutMS=500" valid: true warning: false diff --git a/testdata/read-write-concern/document/write-concern.json b/testdata/read-write-concern/document/write-concern.json index 64cd5d0eae2..fe81741e700 100644 --- a/testdata/read-write-concern/document/write-concern.json +++ b/testdata/read-write-concern/document/write-concern.json @@ -56,6 +56,7 @@ }, { "description": "WTimeoutMS", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "valid": true, "writeConcern": { "wtimeoutMS": 1000 @@ -68,6 +69,7 @@ }, { "description": "WTimeoutMS as an invalid number", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "valid": false, "writeConcern": { "wtimeoutMS": -1000 @@ -114,6 +116,7 @@ }, { "description": 
"Unacknowledged with wtimeoutMS", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "valid": true, "writeConcern": { "w": 0, @@ -156,6 +159,7 @@ }, { "description": "Everything", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "valid": true, "writeConcern": { "w": 3, diff --git a/testdata/read-write-concern/document/write-concern.yml b/testdata/read-write-concern/document/write-concern.yml index bd82fdd59d7..0c31f6958b3 100644 --- a/testdata/read-write-concern/document/write-concern.yml +++ b/testdata/read-write-concern/document/write-concern.yml @@ -36,6 +36,7 @@ tests: isAcknowledged: true - description: "WTimeoutMS" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" valid: true writeConcern: { wtimeoutMS: 1000 } writeConcernDocument: { wtimeout: 1000 } @@ -43,6 +44,7 @@ tests: isAcknowledged: true - description: "WTimeoutMS as an invalid number" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" valid: false writeConcern: { wtimeoutMS: -1000 } writeConcernDocument: ~ @@ -71,6 +73,7 @@ tests: isAcknowledged: false - description: "Unacknowledged with wtimeoutMS" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" valid: true writeConcern: { w: 0, wtimeoutMS: 500 } writeConcernDocument: { w: 0, wtimeout: 500 } diff --git a/testdata/transactions/legacy/error-labels.json b/testdata/transactions/legacy/error-labels.json index a57f216b9b4..8bb5af7700b 100644 --- a/testdata/transactions/legacy/error-labels.json +++ b/testdata/transactions/legacy/error-labels.json @@ -1687,6 +1687,7 @@ }, { "description": "do not add UnknownTransactionCommitResult label to MaxTimeMSExpired inside transactions", + "skipReason": "GODRIVER-2348: maxTimeMS is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { @@ -1817,6 +1818,7 @@ }, { "description": "add UnknownTransactionCommitResult label to MaxTimeMSExpired", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { @@ -1949,6 +1951,7 @@ }, { "description": "add UnknownTransactionCommitResult label to writeConcernError MaxTimeMSExpired", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { diff --git a/testdata/transactions/legacy/error-labels.yml b/testdata/transactions/legacy/error-labels.yml index 5f2c7085c1f..d9c461eadf7 100644 --- a/testdata/transactions/legacy/error-labels.yml +++ b/testdata/transactions/legacy/error-labels.yml @@ -1029,6 +1029,7 @@ tests: - _id: 1 - description: do not add UnknownTransactionCommitResult label to MaxTimeMSExpired inside transactions + skipReason: "GODRIVER-2348: maxTimeMS is deprecated" failPoint: configureFailPoint: failCommand @@ -1109,6 +1110,7 @@ tests: data: [] - description: add UnknownTransactionCommitResult label to MaxTimeMSExpired + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" failPoint: configureFailPoint: failCommand @@ -1190,6 +1192,7 @@ tests: - _id: 1 - description: add UnknownTransactionCommitResult label to writeConcernError MaxTimeMSExpired + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" failPoint: configureFailPoint: failCommand diff --git a/testdata/transactions/legacy/retryable-commit.json b/testdata/transactions/legacy/retryable-commit.json index d83a1d9f52a..dde17146031 100644 --- a/testdata/transactions/legacy/retryable-commit.json +++ 
b/testdata/transactions/legacy/retryable-commit.json @@ -161,6 +161,7 @@ }, { "description": "commitTransaction applies majority write concern on retries", + "skipReason": "GODRIVER-2348: wtimeout is deprecated", "clientOptions": { "retryWrites": false }, diff --git a/testdata/transactions/legacy/retryable-commit.yml b/testdata/transactions/legacy/retryable-commit.yml index 8e0037f28ea..f48b53f1b93 100644 --- a/testdata/transactions/legacy/retryable-commit.yml +++ b/testdata/transactions/legacy/retryable-commit.yml @@ -102,6 +102,7 @@ tests: - _id: 1 - description: commitTransaction applies majority write concern on retries + skipReason: "GODRIVER-2348: wtimeout is deprecated" clientOptions: retryWrites: false diff --git a/testdata/transactions/legacy/transaction-options.json b/testdata/transactions/legacy/transaction-options.json index 25d245dca56..d474e3773dd 100644 --- a/testdata/transactions/legacy/transaction-options.json +++ b/testdata/transactions/legacy/transaction-options.json @@ -318,6 +318,7 @@ }, { "description": "transaction options inherited from defaultTransactionOptions", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "sessionOptions": { "session0": { "defaultTransactionOptions": { @@ -479,6 +480,7 @@ }, { "description": "startTransaction options override defaults", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "clientOptions": { "readConcernLevel": "local", "w": 1 @@ -668,6 +670,7 @@ }, { "description": "defaultTransactionOptions override client options", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "clientOptions": { "readConcernLevel": "local", "w": 1 diff --git a/testdata/transactions/legacy/transaction-options.yml b/testdata/transactions/legacy/transaction-options.yml index 461e87d55f6..314e0284a6e 100644 --- a/testdata/transactions/legacy/transaction-options.yml +++ b/testdata/transactions/legacy/transaction-options.yml @@ -260,6 +260,7 @@ tests: outcome: *outcome - description: startTransaction options override defaults + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" clientOptions: readConcernLevel: local @@ -381,6 +382,7 @@ tests: outcome: *outcome - description: defaultTransactionOptions override client options + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" clientOptions: readConcernLevel: local @@ -665,6 +667,7 @@ tests: - _id: 1 - description: readPreference inherited from defaultTransactionOptions + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" clientOptions: readPreference: primary diff --git a/x/bsonx/bsoncore/bsoncore.go b/x/bsonx/bsoncore/bsoncore.go index af373e5cf13..13219ee5473 100644 --- a/x/bsonx/bsoncore/bsoncore.go +++ b/x/bsonx/bsoncore/bsoncore.go @@ -8,6 +8,7 @@ package bsoncore import ( "bytes" + "encoding/binary" "fmt" "math" "strconv" @@ -706,17 +707,16 @@ func ReserveLength(dst []byte) (int32, []byte) { // UpdateLength updates the length at index with length and returns the []byte. func UpdateLength(dst []byte, index, length int32) []byte { - dst[index] = byte(length) - dst[index+1] = byte(length >> 8) - dst[index+2] = byte(length >> 16) - dst[index+3] = byte(length >> 24) + binary.LittleEndian.PutUint32(dst[index:], uint32(length)) return dst } func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) } func appendi32(dst []byte, i32 int32) []byte { - return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24)) + b := []byte{0, 0, 0, 0} + binary.LittleEndian.PutUint32(b, uint32(i32)) + return append(dst, b...) 
} // ReadLength reads an int32 length from src and returns the length and the remaining bytes. If @@ -734,27 +734,26 @@ func readi32(src []byte) (int32, []byte, bool) { if len(src) < 4 { return 0, src, false } - return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true + return int32(binary.LittleEndian.Uint32(src)), src[4:], true } func appendi64(dst []byte, i64 int64) []byte { - return append(dst, - byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24), - byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56), - ) + b := []byte{0, 0, 0, 0, 0, 0, 0, 0} + binary.LittleEndian.PutUint64(b, uint64(i64)) + return append(dst, b...) } func readi64(src []byte) (int64, []byte, bool) { if len(src) < 8 { return 0, src, false } - i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 | - int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56) - return i64, src[8:], true + return int64(binary.LittleEndian.Uint64(src)), src[8:], true } func appendu32(dst []byte, u32 uint32) []byte { - return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24)) + b := []byte{0, 0, 0, 0} + binary.LittleEndian.PutUint32(b, u32) + return append(dst, b...) } func readu32(src []byte) (uint32, []byte, bool) { @@ -762,23 +761,20 @@ func readu32(src []byte) (uint32, []byte, bool) { return 0, src, false } - return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true + return binary.LittleEndian.Uint32(src), src[4:], true } func appendu64(dst []byte, u64 uint64) []byte { - return append(dst, - byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24), - byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56), - ) + b := []byte{0, 0, 0, 0, 0, 0, 0, 0} + binary.LittleEndian.PutUint64(b, u64) + return append(dst, b...) } func readu64(src []byte) (uint64, []byte, bool) { if len(src) < 8 { return 0, src, false } - u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 | - uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56) - return u64, src[8:], true + return binary.LittleEndian.Uint64(src), src[8:], true } // keep in sync with readcstringbytes diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index f78ef652fea..67160169240 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -45,7 +45,6 @@ type BatchCursor struct { errorProcessor ErrorProcessor // This will only be set when pinning to a connection. connection *mnet.Connection batchSize int32 - maxTimeMS int64 currentBatch *bsoncore.Iterator firstBatch bool cmdMonitor *event.CommandMonitor @@ -53,6 +52,10 @@ type BatchCursor struct { crypt Crypt serverAPI *ServerAPIOptions + // maxAwaitTime is only valid for tailable awaitData cursors. If this option + // is set, it will be used as the "maxTimeMS" field on getMore commands. + maxAwaitTime *time.Duration + // legacy server (< 3.2) fields limit int32 numReturned int32 // number of docs returned by server @@ -157,12 +160,21 @@ func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { type CursorOptions struct { BatchSize int32 Comment bsoncore.Value - MaxTimeMS int64 Limit int32 CommandMonitor *event.CommandMonitor Crypt Crypt ServerAPI *ServerAPIOptions MarshalValueEncoderFn func(io.Writer) (*bson.Encoder, error) + + // MaxAwaitTime is only valid for tailable awaitData cursors. If this option + // is set, it will be used as the "maxTimeMS" field on getMore commands. 
+ MaxAwaitTime *time.Duration +} + +// SetMaxAwaitTime will set the maxTimeMS value on getMore commands for +// tailable awaitData cursors. +func (cursorOptions *CursorOptions) SetMaxAwaitTime(dur time.Duration) { + cursorOptions.MaxAwaitTime = &dur } // NewBatchCursor creates a new BatchCursor from the provided parameters. @@ -185,7 +197,7 @@ func NewBatchCursor( connection: cr.Connection, errorProcessor: cr.ErrorProcessor, batchSize: opts.BatchSize, - maxTimeMS: opts.MaxTimeMS, + maxAwaitTime: opts.MaxAwaitTime, cmdMonitor: opts.CommandMonitor, firstBatch: true, postBatchResumeToken: cr.postBatchResumeToken, @@ -363,14 +375,15 @@ func (bc *BatchCursor) getMore(ctx context.Context) { } bc.err = Operation{ - CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id) dst = bsoncore.AppendStringElement(dst, "collection", bc.collection) if numToReturn > 0 { dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn) } - if bc.maxTimeMS > 0 { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", bc.maxTimeMS) + + if bc.maxAwaitTime != nil && *bc.maxAwaitTime > 0 { + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(*bc.maxAwaitTime)/int64(time.Millisecond)) } comment, err := codecutil.MarshalValue(bc.comment, bc.encoderFn) @@ -471,14 +484,14 @@ func (bc *BatchCursor) SetBatchSize(size int32) { bc.batchSize = size } -// SetMaxTime will set the maximum amount of time the server will allow the +// SetMaxAwaitTime will set the maximum amount of time the server will allow the // operations to execute. The server will error if this field is set but the // cursor is not configured with awaitData=true. // // The time.Duration value passed by this setter will be converted and rounded // down to the nearest millisecond. -func (bc *BatchCursor) SetMaxTime(dur time.Duration) { - bc.maxTimeMS = int64(dur / time.Millisecond) +func (bc *BatchCursor) SetMaxAwaitTime(dur time.Duration) { + bc.maxAwaitTime = &dur } // SetComment sets the comment for future getMore operations. @@ -509,7 +522,7 @@ var _ Deployment = (*loadBalancedCursorDeployment)(nil) var _ Server = (*loadBalancedCursorDeployment)(nil) var _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil) -func (lbcd *loadBalancedCursorDeployment) SelectServer(_ context.Context, _ description.ServerSelector) (Server, error) { +func (lbcd *loadBalancedCursorDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) { return lbcd, nil } @@ -529,3 +542,9 @@ func (lbcd *loadBalancedCursorDeployment) RTTMonitor() RTTMonitor { func (lbcd *loadBalancedCursorDeployment) ProcessError(err error, desc mnet.Describer) ProcessErrorResult { return lbcd.errorProcessor.ProcessError(err, desc) } + +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for load-balanced cursor deployments. 
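A minimal sketch of how the MaxAwaitTime cursor option introduced above might be wired up for a tailable awaitData cursor; the executed operation (op) and the command monitor are placeholders assumed for illustration, not part of this change:

    // op is an already-executed tailable awaitData operation; monitor is an
    // *event.CommandMonitor. Both are hypothetical stand-ins.
    opts := driver.CursorOptions{BatchSize: 100, CommandMonitor: monitor}
    opts.SetMaxAwaitTime(500 * time.Millisecond) // sent as maxTimeMS on each getMore
    cursor, err := op.Result(opts)
    if err != nil {
        // handle error
    }
    _ = cursor // iterate the cursor as usual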
+func (*loadBalancedCursorDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} diff --git a/x/mongo/driver/batch_cursor_test.go b/x/mongo/driver/batch_cursor_test.go index c57434cb83f..7c9ad38c7b5 100644 --- a/x/mongo/driver/batch_cursor_test.go +++ b/x/mongo/driver/batch_cursor_test.go @@ -8,7 +8,6 @@ package driver import ( "testing" - "time" "go.mongodb.org/mongo-driver/internal/assert" ) @@ -91,43 +90,3 @@ func TestBatchCursor(t *testing.T) { } }) } - -func TestBatchCursorSetMaxTime(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - dur time.Duration - want int64 - }{ - { - name: "empty", - dur: 0, - want: 0, - }, - { - name: "partial milliseconds are truncated", - dur: 10_900 * time.Microsecond, - want: 10, - }, - { - name: "millisecond input", - dur: 10 * time.Millisecond, - want: 10, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - bc := BatchCursor{} - bc.SetMaxTime(test.dur) - - got := bc.maxTimeMS - assert.Equal(t, test.want, got, "expected and actual maxTimeMS are different") - }) - } -} diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index d79b024b74d..d9a6c68feed 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -30,7 +30,11 @@ type CompressionOpts struct { // destination writer. It panics on any errors and should only be used at // package initialization time. func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { - enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl)) + enc, err := zstd.NewWriter( + nil, + zstd.WithWindowSize(8<<20), // Set window size to 8MB. + zstd.WithEncoderLevel(lvl), + ) if err != nil { panic(err) } @@ -105,6 +109,13 @@ func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { return dst, nil } +var zstdBufPool = sync.Pool{ + New: func() interface{} { + s := make([]byte, 0) + return &s + }, +} + // CompressPayload takes a byte slice and compresses it according to the options passed func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { switch opts.Compressor { @@ -123,7 +134,13 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { if err != nil { return nil, err } - return encoder.EncodeAll(in, nil), nil + ptr := zstdBufPool.Get().(*[]byte) + b := encoder.EncodeAll(in, *ptr) + dst := make([]byte, len(b)) + copy(dst, b) + *ptr = b[:0] + zstdBufPool.Put(ptr) + return dst, nil default: return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) } diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 43fae2fb1a2..4465afe3be0 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -187,10 +187,6 @@ type ConnString struct { ZstdLevel int ZstdLevelSet bool - WTimeout time.Duration - WTimeoutSet bool - WTimeoutSetFromOption bool - Options map[string][]string UnknownOptions map[string][]string } @@ -650,24 +646,6 @@ func (u *ConnString) addOptions(connectionArgPairs []string) error { u.WString = value u.WNumberSet = false - - case "wtimeoutms": - n, err := strconv.Atoi(value) - if err != nil || n < 0 { - return fmt.Errorf("invalid value for %q: %q", key, value) - } - u.WTimeout = time.Duration(n) * time.Millisecond - u.WTimeoutSet = true - case "wtimeout": - // Defer to wtimeoutms, but not to a manually-set option. 
- if u.WTimeoutSet { - break - } - n, err := strconv.Atoi(value) - if err != nil || n < 0 { - return fmt.Errorf("invalid value for %q: %q", key, value) - } - u.WTimeout = time.Duration(n) * time.Millisecond case "zlibcompressionlevel": level, err := strconv.Atoi(value) if err != nil || (level < -1 || level > 9) { @@ -1032,11 +1010,6 @@ func (p *parser) parse(original string) (*ConnString, error) { return nil, err } - // If WTimeout was set from manual options passed in, set WTImeoutSet to true. - if connStr.WTimeoutSetFromOption { - connStr.WTimeoutSet = true - } - return connStr, nil } diff --git a/x/mongo/driver/connstring/connstring_spec_test.go b/x/mongo/driver/connstring/connstring_spec_test.go index aea68eba71c..af7b25f3858 100644 --- a/x/mongo/driver/connstring/connstring_spec_test.go +++ b/x/mongo/driver/connstring/connstring_spec_test.go @@ -99,6 +99,8 @@ func runTestsInFile(t *testing.T, dirname string, filename string, warningsError var skipDescriptions = map[string]struct{}{ "Valid options specific to single-threaded drivers are parsed correctly": {}, + // GODRIVER-2348: the wtimeoutMS write concern option is not supported. + "Valid read and write concern are parsed correctly": {}, } var skipKeywords = []string{ @@ -106,6 +108,9 @@ var skipKeywords = []string{ "tlsAllowInvalidCertificates", "tlsDisableCertificateRevocationCheck", "serverSelectionTryOnce", + + // GODRIVER-2348: the wtimeoutMS write concern option is not supported. + "wTimeoutMS", } func runTest(t *testing.T, filename string, test testCase, warningsError bool) { @@ -277,8 +282,6 @@ func verifyConnStringOptions(t *testing.T, cs *connstring.ConnString, options ma } else { require.Equal(t, value, cs.WString) } - case "wtimeoutms": - require.Equal(t, value, float64(cs.WTimeout/time.Millisecond)) case "waitqueuetimeoutms": case "zlibcompressionlevel": require.Equal(t, value, float64(cs.ZlibLevel)) diff --git a/x/mongo/driver/connstring/connstring_test.go b/x/mongo/driver/connstring/connstring_test.go index 84c8ff1d459..001cd72fe54 100644 --- a/x/mongo/driver/connstring/connstring_test.go +++ b/x/mongo/driver/connstring/connstring_test.go @@ -564,33 +564,6 @@ func TestSocketTimeout(t *testing.T) { } } -func TestWTimeout(t *testing.T) { - tests := []struct { - s string - expected time.Duration - err bool - }{ - {s: "wtimeoutMS=10", expected: time.Duration(10) * time.Millisecond}, - {s: "wtimeoutMS=100", expected: time.Duration(100) * time.Millisecond}, - {s: "wtimeoutMS=-2", err: true}, - {s: "wtimeoutMS=gsdge", err: true}, - } - - for _, test := range tests { - s := fmt.Sprintf("mongodb://localhost/?%s", test.s) - t.Run(s, func(t *testing.T) { - cs, err := connstring.ParseAndValidate(s) - if test.err { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, test.expected, cs.WTimeout) - require.True(t, cs.WTimeoutSet) - } - }) - } -} - func TestCompressionOptions(t *testing.T) { tests := []struct { name string diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 16992b40997..b6a95e32da9 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -28,6 +28,12 @@ import ( type Deployment interface { SelectServer(context.Context, description.ServerSelector) (Server, error) Kind() description.TopologyKind + + // GetServerSelectionTimeout returns a timeout that should be used to set a + // deadline for server selection. 
This logic is not handled internally by + the ServerSelector, as a resulting deadline may be applicable to follow-up + operations such as checking out a connection. + GetServerSelectionTimeout() time.Duration } // Connector represents a type that can connect to a server. @@ -144,6 +150,12 @@ func (ssd SingleServerDeployment) SelectServer(context.Context, description.Serv // Kind implements the Deployment interface. It always returns description.TopologyKindSingle. func (SingleServerDeployment) Kind() description.TopologyKind { return description.TopologyKindSingle } +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for single server deployments. +func (SingleServerDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} + // SingleConnectionDeployment is an implementation of Deployment that always returns the same Connection. This // implementation should only be used for connection handshakes and server heartbeats as it does not implement // ErrorProcessor, which is necessary for application operations. @@ -159,6 +171,12 @@ func (scd SingleConnectionDeployment) SelectServer(context.Context, description. return scd, nil } +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for single connection deployments. +func (SingleConnectionDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} + // Kind implements the Deployment interface. It always returns description.TopologyKindSingle. func (SingleConnectionDeployment) Kind() description.TopologyKind { return description.TopologyKindSingle diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go index 3a189318cbe..b12ac5d3969 100644 --- a/x/mongo/driver/errors.go +++ b/x/mongo/driver/errors.go @@ -59,8 +59,6 @@ var ( ErrDeadlineWouldBeExceeded = fmt.Errorf( "operation not sent to server, as Timeout would be exceeded: %w", context.DeadlineExceeded) - // ErrNegativeMaxTime is returned when MaxTime on an operation is a negative value. - ErrNegativeMaxTime = errors.New("a negative value was provided for MaxTime on an operation") ) // QueryFailureError is an error representing a command failure as a document. diff --git a/x/mongo/driver/integration/aggregate_test.go b/x/mongo/driver/integration/aggregate_test.go index 824c06f9933..c7cbcfc7d54 100644 --- a/x/mongo/driver/integration/aggregate_test.go +++ b/x/mongo/driver/integration/aggregate_test.go @@ -84,9 +84,13 @@ func TestAggregate(t *testing.T) { op := operation.NewAggregate(bsoncore.BuildDocumentFromElements(nil)). Collection(collName).Database(dbName).Deployment(top).ServerSelector(&serverselector.Write{}).
CommandMonitor(monitor).BatchSize(2) - err = op.Execute(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err = op.Execute(ctx) noerr(t, err) - batchCursor, err := op.Result(driver.CursorOptions{MaxTimeMS: 10, BatchSize: 2, CommandMonitor: monitor}) + batchCursor, err := op.Result(driver.CursorOptions{BatchSize: 2, CommandMonitor: monitor}) noerr(t, err) var e *event.CommandStartedEvent diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 84bf6a9fe1d..61110e54671 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -11,6 +11,7 @@ import ( "context" "errors" "fmt" + "math" "net" "strconv" "strings" @@ -48,6 +49,12 @@ var ( ErrNonPrimaryReadPref = errors.New("read preference in a transaction must be primary") // errDatabaseNameEmpty occurs when a database name is not provided. errDatabaseNameEmpty = errors.New("database name cannot be empty") + // errEmptyWriteConcern indicates that a write concern has no fields set. + errEmptyWriteConcern = errors.New("a write concern must have at least one field set") + // errNegativeW indicates that a negative integer `w` field was specified. + errNegativeW = errors.New("write concern `w` field cannot be a negative number") + // errInconsistent indicates that an inconsistent write concern was specified. + errInconsistent = errors.New("a write concern cannot have both w=0 and j=true") ) const ( @@ -280,9 +287,6 @@ type Operation struct { // read preference will not be added to the command on wire versions < 13. IsOutputAggregate bool - // MaxTime specifies the maximum amount of time to allow the operation to run on the server. - MaxTime *time.Duration - // Timeout is the amount of time that this operation can execute before returning an error. The default value // nil, which means that the timeout of the operation's caller will be used. Timeout *time.Duration @@ -293,6 +297,10 @@ type Operation struct { // OP_MSG as well as for logging server selection data. Name string + // OmitMaxTimeMS will ensure that wire messages sent to the server in service + // of the operation do not contain a maxTimeMS field. + OmitMaxTimeMS bool + // omitReadPreference is a boolean that indicates whether to omit the // read preference from the command. This omition includes the case // where a default read preference is used when the operation @@ -408,6 +416,9 @@ func (op Operation) getServerAndConnection( requestID int32, deprioritized []description.Server, ) (Server, *mnet.Connection, error) { + ctx, cancel := csot.WithServerSelectionTimeout(ctx, op.Deployment.GetServerSelectionTimeout()) + defer cancel() + server, err := op.selectServer(ctx, requestID, deprioritized) if err != nil { if op.Client != nil && @@ -485,15 +496,8 @@ func (op Operation) Execute(ctx context.Context) error { return err } - // If no deadline is set on the passed-in context, op.Timeout is set, and context is not already - // a Timeout context, honor op.Timeout in new Timeout context for operation execution. - if _, deadlineSet := ctx.Deadline(); !deadlineSet && op.Timeout != nil && !csot.IsTimeoutContext(ctx) { - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *op.Timeout) - // Redefine ctx to be the new timeout-derived context. - ctx = newCtx - // Cancel the timeout-derived context at the end of Execute to avoid a context leak. 
- defer cancelFunc() - } + ctx, cancel := csot.WithTimeout(ctx, op.Timeout) + defer cancel() if op.Client != nil { if err := op.Client.StartCommand(); err != nil { @@ -1181,6 +1185,7 @@ func (op Operation) addBatchArray(dst []byte) []byte { } func (op Operation) createLegacyHandshakeWireMessage( + ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer, @@ -1225,7 +1230,7 @@ func (op Operation) createLegacyHandshakeWireMessage( return dst, info, err } - dst, err = op.addWriteConcern(dst, desc) + dst, err = op.addWriteConcern(ctx, dst, desc) if err != nil { return dst, info, err } @@ -1297,7 +1302,7 @@ func (op Operation) createMsgWireMessage( if err != nil { return dst, info, err } - dst, err = op.addWriteConcern(dst, desc) + dst, err = op.addWriteConcern(ctx, dst, desc) if err != nil { return dst, info, err } @@ -1364,7 +1369,7 @@ func (op Operation) createWireMessage( requestID int32, ) ([]byte, startedInformation, error) { if isLegacyHandshake(op, desc) { - return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc) + return op.createLegacyHandshakeWireMessage(ctx, maxTimeMS, dst, desc) } return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID) @@ -1474,7 +1479,54 @@ func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) return bsoncore.AppendDocumentElement(dst, "readConcern", data), nil } -func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) ([]byte, error) { +func marshalBSONWriteConcern(wc writeconcern.WriteConcern, wtimeout time.Duration) (bson.Type, []byte, error) { + var elems []byte + if wc.W != nil { + // Only support string or int values for W. That aligns with the + // documentation and the behavior of other functions, like Acknowledged. + switch w := wc.W.(type) { + case int: + if w < 0 { + return 0, nil, errNegativeW + } + + // If Journal=true and W=0, return an error because that write + // concern is ambiguous. + if wc.Journal != nil && *wc.Journal && w == 0 { + return 0, nil, errInconsistent + } + + // Check for lower and upper bounds on architecture-dependent int. 
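+ // On 64-bit platforms int is wider than the int32 written to the wire, so + // values above math.MaxInt32 are rejected here rather than silently + // truncated (negative values were already rejected above).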
+ if w > math.MaxInt32 { + return 0, nil, fmt.Errorf("WriteConcern.W overflows int32: %v", wc.W) + } + + elems = bsoncore.AppendInt32Element(elems, "w", int32(w)) + case string: + elems = bsoncore.AppendStringElement(elems, "w", w) + default: + return 0, + nil, + fmt.Errorf("WriteConcern.W must be a string or int, but is a %T", wc.W) + } + } + + if wc.Journal != nil { + elems = bsoncore.AppendBooleanElement(elems, "j", *wc.Journal) + } + + if wtimeout != 0 { + elems = bsoncore.AppendInt64Element(elems, "wtimeout", int64(wtimeout/time.Millisecond)) + } + + if len(elems) == 0 { + return 0, nil, errEmptyWriteConcern + } + + return bson.TypeEmbeddedDocument, bsoncore.BuildDocument(nil, elems), nil +} + +func (op Operation) addWriteConcern(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) { if op.MinimumWriteConcernWireVersion > 0 && (desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, op.MinimumWriteConcernWireVersion)) { @@ -1485,15 +1537,27 @@ func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) return dst, nil } - t, data, err := wc.MarshalBSONValue() - if errors.Is(err, writeconcern.ErrEmptyWriteConcern) { + // The specification for committing a transaction states: + // + // > if the modified write concern does not include a wtimeout value, drivers + // > MUST also apply wtimeout: 10000 to the write concern in order to avoid + // > waiting forever (or until a socket timeout) if the majority write concern + // > cannot be satisfied. + var wtimeout time.Duration + if _, ok := ctx.Deadline(); op.Client != nil && op.Timeout == nil && !ok { + wtimeout = op.Client.CurrentWTimeout + } + + typ, wcBSON, err := marshalBSONWriteConcern(*wc, wtimeout) + if errors.Is(err, errEmptyWriteConcern) { return dst, nil } + + if err != nil { return dst, err } - return append(bsoncore.AppendHeader(dst, bsoncore.Type(t), "writeConcern"), data...), nil + return append(bsoncore.AppendHeader(dst, bsoncore.Type(typ), "writeConcern"), wcBSON...), nil } func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]byte, error) { @@ -1557,34 +1621,29 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) // operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is // not a Timeout context, calculateMaxTimeMS returns 0. func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration, rttStats string) (uint64, error) { - if csot.IsTimeoutContext(ctx) { - if deadline, ok := ctx.Deadline(); ok { - remainingTimeout := time.Until(deadline) - - // Always round up to the next millisecond value so we never truncate the calculated - // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). - maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) - if maxTimeMS <= 0 { - return 0, fmt.Errorf( - "remaining time %v until context deadline is less than or equal to rtt minimum: %w\n%v", - remainingTimeout, - ErrDeadlineWouldBeExceeded, - rttStats) - } + if op.OmitMaxTimeMS { + return 0, nil + } - return uint64(maxTimeMS), nil - } - } else if op.MaxTime != nil { - // Users are not allowed to pass a negative value as MaxTime. A value of 0 would indicate - // no timeout and is allowed. - if *op.MaxTime < 0 { - return 0, ErrNegativeMaxTime - } - // Always round up to the next millisecond value so we never truncate the requested - // MaxTime value (e.g. 400 microseconds evaluates to 1ms, not 0ms).
- return uint64((*op.MaxTime + (time.Millisecond - 1)) / time.Millisecond), nil + deadline, ok := ctx.Deadline() + if !ok { + return 0, nil } - return 0, nil + + remainingTimeout := time.Until(deadline) + + // Always round up to the next millisecond value so we never truncate the calculated + // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). + maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) + if maxTimeMS <= 0 { + return 0, fmt.Errorf( + "remaining time %v until context deadline is less than or equal to rtt minimum: %w\n%v", + remainingTimeout, + ErrDeadlineWouldBeExceeded, + rttStats) + } + + return uint64(maxTimeMS), nil } // updateClusterTimes updates the cluster times for the session and cluster clock attached to this diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index 3fe4ca2fe31..92c0186a494 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -30,7 +30,6 @@ type Aggregate struct { collation bsoncore.Document comment bsoncore.Value hint bsoncore.Value - maxTime *time.Duration pipeline bsoncore.Document session *session.Client clock *session.ClusterClock @@ -109,7 +108,6 @@ func (a *Aggregate) Execute(ctx context.Context) error { MinimumWriteConcernWireVersion: 5, ServerAPI: a.serverAPI, IsOutputAggregate: a.hasOutputStage, - MaxTime: a.maxTime, Timeout: a.timeout, Name: driverutil.AggregateOp, }.Execute(ctx) @@ -224,16 +222,6 @@ func (a *Aggregate) Hint(hint bsoncore.Value) *Aggregate { return a } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (a *Aggregate) MaxTime(maxTime *time.Duration) *Aggregate { - if a == nil { - a = new(Aggregate) - } - - a.maxTime = maxTime - return a -} - // Pipeline determines how data is transformed for an aggregation. func (a *Aggregate) Pipeline(pipeline bsoncore.Document) *Aggregate { if a == nil { diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index 42a79e2f562..b014affd154 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -9,7 +9,6 @@ package operation import ( "context" "errors" - "time" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" @@ -22,7 +21,6 @@ import ( // CommitTransaction attempts to commit a transaction. type CommitTransaction struct { - maxTime *time.Duration recoveryToken bsoncore.Document session *session.Client clock *session.ClusterClock @@ -63,7 +61,6 @@ func (ct *CommitTransaction) Execute(ctx context.Context) error { Crypt: ct.crypt, Database: ct.database, Deployment: ct.deployment, - MaxTime: ct.maxTime, Selector: ct.selector, WriteConcern: ct.writeConcern, ServerAPI: ct.serverAPI, @@ -81,16 +78,6 @@ func (ct *CommitTransaction) command(dst []byte, _ description.SelectedServer) ( return dst, nil } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (ct *CommitTransaction) MaxTime(maxTime *time.Duration) *CommitTransaction { - if ct == nil { - ct = new(CommitTransaction) - } - - ct.maxTime = maxTime - return ct -} - // RecoveryToken sets the recovery token to use when committing or aborting a sharded transaction. 
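A quick check of the rounding kept in the new calculateMaxTimeMS above (a standalone sketch, not code from this change):

    remaining := 400 * time.Microsecond // budget left until the context deadline
    rttMin := time.Duration(0)          // assume a negligible RTT floor for the example
    maxTimeMS := int64((remaining - rttMin + time.Millisecond - 1) / time.Millisecond)
    // maxTimeMS == 1: a sub-millisecond budget rounds up instead of truncating to 0,
    // while a budget at or below rttMin makes calculateMaxTimeMS return an error.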
func (ct *CommitTransaction) RecoveryToken(recoveryToken bsoncore.Document) *CommitTransaction { if ct == nil { diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 5625b79bd90..6aac998bf56 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -24,7 +24,6 @@ import ( // Count represents a count operation. type Count struct { - maxTime *time.Duration query bsoncore.Document session *session.Client clock *session.ClusterClock @@ -120,7 +119,6 @@ func (c *Count) Execute(ctx context.Context) error { Crypt: c.crypt, Database: c.database, Deployment: c.deployment, - MaxTime: c.maxTime, ReadConcern: c.readConcern, ReadPreference: c.readPreference, Selector: c.selector, @@ -150,16 +148,6 @@ func (c *Count) command(dst []byte, _ description.SelectedServer) ([]byte, error return dst, nil } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (c *Count) MaxTime(maxTime *time.Duration) *Count { - if c == nil { - c = new(Count) - } - - c.maxTime = maxTime - return c -} - // Query determines what results are returned from find. func (c *Count) Query(query bsoncore.Document) *Count { if c == nil { diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index 0192379e2bc..06c8fd81184 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -25,7 +25,6 @@ import ( type CreateIndexes struct { commitQuorum bsoncore.Value indexes bsoncore.Document - maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -112,7 +111,6 @@ func (ci *CreateIndexes) Execute(ctx context.Context) error { Crypt: ci.crypt, Database: ci.database, Deployment: ci.deployment, - MaxTime: ci.maxTime, Selector: ci.selector, WriteConcern: ci.writeConcern, ServerAPI: ci.serverAPI, @@ -158,16 +156,6 @@ func (ci *CreateIndexes) Indexes(indexes bsoncore.Document) *CreateIndexes { return ci } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (ci *CreateIndexes) MaxTime(maxTime *time.Duration) *CreateIndexes { - if ci == nil { - ci = new(CreateIndexes) - } - - ci.maxTime = maxTime - return ci -} - // Session sets the session for this operation. func (ci *CreateIndexes) Session(session *session.Client) *CreateIndexes { if ci == nil { diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index a13bd2b7b44..a59e4ced357 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -25,7 +25,6 @@ import ( type Distinct struct { collation bsoncore.Document key *string - maxTime *time.Duration query bsoncore.Document session *session.Client clock *session.ClusterClock @@ -99,7 +98,6 @@ func (d *Distinct) Execute(ctx context.Context) error { Crypt: d.crypt, Database: d.database, Deployment: d.deployment, - MaxTime: d.maxTime, ReadConcern: d.readConcern, ReadPreference: d.readPreference, Selector: d.selector, @@ -150,16 +148,6 @@ func (d *Distinct) Key(key string) *Distinct { return d } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (d *Distinct) MaxTime(maxTime *time.Duration) *Distinct { - if d == nil { - d = new(Distinct) - } - - d.maxTime = maxTime - return d -} - // Query specifies which documents to return distinct values from. 
func (d *Distinct) Query(query bsoncore.Document) *Distinct { if d == nil { diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 597d04ac88b..a758f34970f 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -24,7 +24,6 @@ import ( // DropIndexes performs an dropIndexes operation. type DropIndexes struct { index *string - maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -95,7 +94,6 @@ func (di *DropIndexes) Execute(ctx context.Context) error { Crypt: di.crypt, Database: di.database, Deployment: di.deployment, - MaxTime: di.maxTime, Selector: di.selector, WriteConcern: di.writeConcern, ServerAPI: di.serverAPI, @@ -123,16 +121,6 @@ func (di *DropIndexes) Index(index string) *DropIndexes { return di } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (di *DropIndexes) MaxTime(maxTime *time.Duration) *DropIndexes { - if di == nil { - di = new(DropIndexes) - } - - di.maxTime = maxTime - return di -} - // Session sets the session for this operation. func (di *DropIndexes) Session(session *session.Client) *DropIndexes { if di == nil { diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 1e34b8da8a4..bdbad6d610c 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -35,7 +35,6 @@ type Find struct { let bsoncore.Document limit *int64 max bsoncore.Document - maxTime *time.Duration min bsoncore.Document noCursorTimeout *bool oplogReplay *bool @@ -100,7 +99,6 @@ func (f *Find) Execute(ctx context.Context) error { Crypt: f.crypt, Database: f.database, Deployment: f.deployment, - MaxTime: f.maxTime, ReadConcern: f.readConcern, ReadPreference: f.readPreference, Selector: f.selector, @@ -299,16 +297,6 @@ func (f *Find) Max(max bsoncore.Document) *Find { return f } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (f *Find) MaxTime(maxTime *time.Duration) *Find { - if f == nil { - f = new(Find) - } - - f.maxTime = maxTime - return f -} - // Min sets an inclusive lower bound for a specific index. func (f *Find) Min(min bsoncore.Document) *Find { if f == nil { diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 12d241f7101..51af9ffbcf2 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -29,7 +29,6 @@ type FindAndModify struct { collation bsoncore.Document comment bsoncore.Value fields bsoncore.Document - maxTime *time.Duration newDocument *bool query bsoncore.Document remove *bool @@ -137,7 +136,6 @@ func (fam *FindAndModify) Execute(ctx context.Context) error { CommandMonitor: fam.monitor, Database: fam.database, Deployment: fam.deployment, - MaxTime: fam.maxTime, Selector: fam.selector, WriteConcern: fam.writeConcern, Crypt: fam.crypt, @@ -265,16 +263,6 @@ func (fam *FindAndModify) Fields(fields bsoncore.Document) *FindAndModify { return fam } -// MaxTime specifies the maximum amount of time to allow the operation to run on the server. -func (fam *FindAndModify) MaxTime(maxTime *time.Duration) *FindAndModify { - if fam == nil { - fam = new(FindAndModify) - } - - fam.maxTime = maxTime - return fam -} - // NewDocument specifies whether to return the modified document or the original. Defaults to false (return original). 
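With the per-operation MaxTime setters removed here and in the neighboring operation files, callers bound server-side execution with a context deadline instead, as the aggregate_test.go change above already does; a condensed sketch of that pattern, where op stands in for any configured operation:

    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    defer cancel()
    err := op.Execute(ctx) // maxTimeMS is derived from the remaining context budget
    if err != nil {
        // handle error
    }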
func (fam *FindAndModify) NewDocument(newDocument bool) *FindAndModify { if fam == nil { diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 8e6c59de384..9a3993120f1 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -47,6 +47,7 @@ type Hello struct { maxAwaitTimeMS *int64 serverAPI *driver.ServerAPIOptions loadBalanced bool + omitMaxTimeMS bool res bsoncore.Document } @@ -590,7 +591,8 @@ func (h *Hello) createOperation() driver.Operation { h.res = info.ServerResponse return nil }, - ServerAPI: h.serverAPI, + ServerAPI: h.serverAPI, + OmitMaxTimeMS: h.omitMaxTimeMS, } if isLegacyHandshake(h.serverAPI, h.loadBalanced) { @@ -650,3 +652,15 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, func (h *Hello) FinishHandshake(context.Context, *mnet.Connection) error { return nil } + +// OmitMaxTimeMS will ensure maxTimeMS is not included in the wire message +// constructed to send a hello request. +func (h *Hello) OmitMaxTimeMS(val bool) *Hello { + if h == nil { + h = new(Hello) + } + + h.omitMaxTimeMS = val + + return h +} diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index d4cbe8a3375..a14873a7ac8 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -22,7 +22,6 @@ import ( // ListIndexes performs a listIndexes operation. type ListIndexes struct { batchSize *int32 - maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -76,7 +75,6 @@ func (li *ListIndexes) Execute(ctx context.Context) error { CommandMonitor: li.monitor, Database: li.database, Deployment: li.deployment, - MaxTime: li.maxTime, Selector: li.selector, Crypt: li.crypt, Legacy: driver.LegacyListIndexes, @@ -113,16 +111,6 @@ func (li *ListIndexes) BatchSize(batchSize int32) *ListIndexes { return li } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (li *ListIndexes) MaxTime(maxTime *time.Duration) *ListIndexes { - if li == nil { - li = new(ListIndexes) - } - - li.maxTime = maxTime - return li -} - // Session sets the session for this operation. func (li *ListIndexes) Session(session *session.Client) *ListIndexes { if li == nil { diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 0e3da7007ca..f209134b793 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -230,7 +230,8 @@ func TestOperation(t *testing.T) { want := bsoncore.AppendDocumentElement(nil, "writeConcern", bsoncore.BuildDocumentFromElements( nil, bsoncore.AppendStringElement(nil, "w", "majority"), )) - got, err := Operation{WriteConcern: writeconcern.Majority()}.addWriteConcern(nil, description.SelectedServer{}) + got, err := Operation{WriteConcern: writeconcern.Majority()}. + addWriteConcern(context.Background(), nil, description.SelectedServer{}) noerr(t, err) if !bytes.Equal(got, want) { t.Errorf("WriteConcern elements do not match.
got %v; want %v", got, want) @@ -270,15 +271,12 @@ func TestOperation(t *testing.T) { }) t.Run("calculateMaxTimeMS", func(t *testing.T) { var ( - timeout = 5 * time.Second - maxTime = 2 * time.Second - negMaxTime = -2 * time.Second - shortRTT = 50 * time.Millisecond - longRTT = 10 * time.Second - verShortRTT = 400 * time.Microsecond + timeout = 5 * time.Second + shortRTT = 50 * time.Millisecond + longRTT = 10 * time.Second ) - timeoutCtx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) + timeoutCtx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() testCases := []struct { @@ -293,43 +291,14 @@ func TestOperation(t *testing.T) { }{ { name: "uses context deadline and rtt90 with timeout", - op: Operation{MaxTime: &maxTime}, ctx: timeoutCtx, rttMin: shortRTT, rttStats: "", want: 5000, err: nil, }, - { - name: "uses MaxTime without timeout", - op: Operation{MaxTime: &maxTime}, - ctx: context.Background(), - rttMin: longRTT, - rttStats: "", - want: 2000, - err: nil, - }, - { - name: "errors when remaining timeout is less than rtt90", - op: Operation{MaxTime: &maxTime}, - ctx: timeoutCtx, - rttMin: timeout, - rttStats: "", - want: 0, - err: ErrDeadlineWouldBeExceeded, - }, - { - name: "errors when MaxTime is negative", - op: Operation{MaxTime: &negMaxTime}, - ctx: context.Background(), - rttMin: longRTT, - rttStats: "", - want: 0, - err: ErrNegativeMaxTime, - }, { name: "sub millisecond rtt should round up", - op: Operation{MaxTime: &verShortRTT}, ctx: context.Background(), rttMin: longRTT, rttStats: "", @@ -651,7 +620,7 @@ func TestOperation(t *testing.T) { assert.NotNil(t, err, "expected an error from Execute(), got nil") // Assert that error is just context deadline exceeded and is therefore not a driver.Error marked // with the TransientTransactionError label. 
- assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) }) t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) { conn := mnet.NewConnection(&mockConnection{}) @@ -710,18 +679,24 @@ type mockDeployment struct { selector description.ServerSelector } returns struct { - server Server - err error - retry bool - kind description.TopologyKind + server Server + err error + retry bool + kind description.TopologyKind + serverSelectionTimeout time.Duration } } func (m *mockDeployment) SelectServer(_ context.Context, desc description.ServerSelector) (Server, error) { m.params.selector = desc + return m.returns.server, m.returns.err } +func (m *mockDeployment) GetServerSelectionTimeout() time.Duration { + return m.returns.serverSelectionTimeout +} + func (m *mockDeployment) Kind() description.TopologyKind { return m.returns.kind } type mockServerSelector struct{} @@ -974,3 +949,67 @@ func TestFilterDeprioritizedServers(t *testing.T) { }) } } + +func TestMarshalBSONWriteConcern(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + writeConcern writeconcern.WriteConcern + wantBSONType bson.Type + wtimeout time.Duration + want bson.D + wantErr string + }{ + { + name: "empty", + writeConcern: writeconcern.WriteConcern{}, + wantBSONType: 0x0, + want: nil, + wtimeout: 0, + wantErr: "a write concern must have at least one field set", + }, + { + name: "journal only", + writeConcern: *writeconcern.Journaled(), + wantBSONType: bson.TypeEmbeddedDocument, + want: bson.D{{"j", true}}, + wtimeout: 0, + wantErr: "a write concern must have at least one field set", + }, + { + name: "journal and wtimout", + writeConcern: *writeconcern.Journaled(), + wtimeout: 10 * time.Millisecond, + wantBSONType: bson.TypeEmbeddedDocument, + want: bson.D{{"j", true}, {"wtimeout", int64(10 * time.Millisecond / time.Millisecond)}}, + wantErr: "a write concern must have at least one field set", + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + gotBSONType, gotBSON, gotErr := marshalBSONWriteConcern(test.writeConcern, test.wtimeout) + assert.Equal(t, test.wantBSONType, gotBSONType) + + wantBSON := []byte(nil) + + if test.want != nil { + var err error + + wantBSON, err = bson.Marshal(test.want) + assert.NoError(t, err) + } + + assert.Equal(t, wantBSON, gotBSON) + + if gotErr != nil { + assert.EqualError(t, gotErr, test.wantErr) + } + }) + } +} diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index d535ec54c9f..4228c4e98f9 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -57,6 +57,8 @@ const ( Aborted ) +const defaultWriteConcernTimeout = 10_000 * time.Millisecond + // String implements the fmt.Stringer interface. 
func (s TransactionState) String() string { switch s { @@ -104,16 +106,15 @@ type Client struct { // options for the current transaction // most recently set by transactionopt - CurrentRc *readconcern.ReadConcern - CurrentRp *readpref.ReadPref - CurrentWc *writeconcern.WriteConcern - CurrentMct *time.Duration + CurrentRc *readconcern.ReadConcern + CurrentRp *readpref.ReadPref + CurrentWc *writeconcern.WriteConcern + CurrentWTimeout time.Duration // default transaction options - transactionRc *readconcern.ReadConcern - transactionRp *readpref.ReadPref - transactionWc *writeconcern.WriteConcern - transactionMaxCommitTime *time.Duration + transactionRc *readconcern.ReadConcern + transactionRp *readpref.ReadPref + transactionWc *writeconcern.WriteConcern pool *Pool TransactionState TransactionState @@ -189,9 +190,6 @@ func NewClientSession(pool *Pool, clientID uuid.UUID, opts ...*ClientOptions) (* if mergedOpts.DefaultWriteConcern != nil { c.transactionWc = mergedOpts.DefaultWriteConcern } - if mergedOpts.DefaultMaxCommitTime != nil { - c.transactionMaxCommitTime = mergedOpts.DefaultMaxCommitTime - } if mergedOpts.Snapshot != nil { c.Snapshot = *mergedOpts.Snapshot } @@ -399,7 +397,6 @@ func (c *Client) StartTransaction(opts *TransactionOptions) error { c.CurrentRc = opts.ReadConcern c.CurrentRp = opts.ReadPreference c.CurrentWc = opts.WriteConcern - c.CurrentMct = opts.MaxCommitTime } if c.CurrentRc == nil { @@ -414,10 +411,6 @@ func (c *Client) StartTransaction(opts *TransactionOptions) error { c.CurrentWc = c.transactionWc } - if c.CurrentMct == nil { - c.CurrentMct = c.transactionMaxCommitTime - } - if !c.CurrentWc.Acknowledged() { _ = c.clearTransactionOpts() return ErrUnackWCUnsupported @@ -449,21 +442,22 @@ func (c *Client) CommitTransaction() error { return nil } -// UpdateCommitTransactionWriteConcern will set the write concern to majority and potentially set a -// w timeout of 10 seconds. This should be called after a commit transaction operation fails with a -// retryable error or after a successful commit transaction operation. +// UpdateCommitTransactionWriteConcern will set the write concern to majority. +// This should be called after a commit transaction operation fails with a +// retryable error or after a successful commit transaction operation. + +// Per the transaction specifications, when commitTransaction is retried, if +// the modified write concern does not include a "wtimeout" value, drivers +// MUST apply "wtimeout: 10000" to the write concern in order to avoid waiting +// forever (or until a socket timeout) if the majority write concern cannot be +// satisfied. This method abstracts that functionality. For more information, +// see SPEC-1185.
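Combined with the addWriteConcern and marshalBSONWriteConcern changes earlier in this diff, a retried commitTransaction on an operation with no Timeout and no context deadline should therefore carry a write concern along these lines (an inference from the surrounding code, not an excerpt from it):

    writeConcern: { w: "majority", wtimeout: 10000 }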
func (c *Client) UpdateCommitTransactionWriteConcern() { - wc := &writeconcern.WriteConcern{} - timeout := 10 * time.Second - if c.CurrentWc != nil { - *wc = *c.CurrentWc - if c.CurrentWc.WTimeout != 0 { - timeout = c.CurrentWc.WTimeout - } + c.CurrentWc = &writeconcern.WriteConcern{ + W: "majority", } - wc.W = "majority" - wc.WTimeout = timeout - c.CurrentWc = wc + + c.CurrentWTimeout = defaultWriteConcernTimeout } // CheckAbortTransaction checks to see if allowed to abort transaction and returns diff --git a/x/mongo/driver/session/options.go b/x/mongo/driver/session/options.go index ee7c301d649..67749f09cb9 100644 --- a/x/mongo/driver/session/options.go +++ b/x/mongo/driver/session/options.go @@ -7,8 +7,6 @@ package session import ( - "time" - "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -20,7 +18,6 @@ type ClientOptions struct { DefaultReadConcern *readconcern.ReadConcern DefaultWriteConcern *writeconcern.WriteConcern DefaultReadPreference *readpref.ReadPref - DefaultMaxCommitTime *time.Duration Snapshot *bool } @@ -29,7 +26,6 @@ type TransactionOptions struct { ReadConcern *readconcern.ReadConcern WriteConcern *writeconcern.WriteConcern ReadPreference *readpref.ReadPref - MaxCommitTime *time.Duration } func mergeClientOptions(opts ...*ClientOptions) *ClientOptions { @@ -50,9 +46,6 @@ func mergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.DefaultWriteConcern != nil { c.DefaultWriteConcern = opt.DefaultWriteConcern } - if opt.DefaultMaxCommitTime != nil { - c.DefaultMaxCommitTime = opt.DefaultMaxCommitTime - } if opt.Snapshot != nil { c.Snapshot = opt.Snapshot } diff --git a/x/mongo/driver/topology/CMAP_spec_test.go b/x/mongo/driver/topology/CMAP_spec_test.go index 62283d21567..d65c97ca3d3 100644 --- a/x/mongo/driver/topology/CMAP_spec_test.go +++ b/x/mongo/driver/topology/CMAP_spec_test.go @@ -208,7 +208,7 @@ func runCMAPTest(t *testing.T, testFileName string) { } })) - s := NewServer("mongodb://fake", bson.NewObjectID(), sOpts...) + s := NewServer("mongodb://fake", bson.NewObjectID(), defaultConnectionTimeout, sOpts...) s.state = serverConnected require.NoError(t, err, "error connecting connection pool") defer s.pool.close(context.Background()) @@ -274,7 +274,6 @@ func runCMAPTest(t *testing.T, testFileName string) { } checkEvents(t, test.Events, testInfo.finalEventChan, test.Ignore) - } func checkEvents(t *testing.T, expectedEvents []cmapEvent, actualEvents chan *event.PoolEvent, ignoreEvents []string) { @@ -290,7 +289,6 @@ func checkEvents(t *testing.T, expectedEvents []cmapEvent, actualEvents chan *ev } if expectedEvent.Address != nil { - if expectedEvent.Address == float64(42) { // can be any address if validEvent.Address == "" { t.Errorf("expected address in event, instead received none in %v", expectedEvent.EventType) diff --git a/x/mongo/driver/topology/cancellation_listener.go b/x/mongo/driver/topology/cancellation_listener.go deleted file mode 100644 index caca988057a..00000000000 --- a/x/mongo/driver/topology/cancellation_listener.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. 
You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package topology - -import "context" - -type cancellationListener interface { - Listen(context.Context, func()) - StopListening() bool -} diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 43d45c1515c..cd35c6f66d1 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -56,8 +56,6 @@ type connection struct { addr address.Address idleTimeout time.Duration idleDeadline atomic.Value // Stores a time.Time - readTimeout time.Duration - writeTimeout time.Duration desc description.Server helloRTT time.Duration compressor wiremessage.CompressorID @@ -65,13 +63,13 @@ type connection struct { zstdLevel int connectDone chan struct{} config *connectionConfig - cancelConnectContext context.CancelFunc connectContextMade chan struct{} canStream bool currentlyStreaming bool - connectContextMutex sync.Mutex - cancellationListener cancellationListener - serverConnectionID *int64 // the server's ID for this client's connection + cancellationListener contextListener + connectListener contextListener // Cancels blocking ops during connect + serverConnectionID *int64 // the server's ID for this client's connection + prevCanceled atomic.Value // pool related fields pool *pool @@ -90,12 +88,11 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection { id: id, addr: addr, idleTimeout: cfg.idleTimeout, - readTimeout: cfg.readTimeout, - writeTimeout: cfg.writeTimeout, connectDone: make(chan struct{}), config: cfg, connectContextMade: make(chan struct{}), - cancellationListener: newCancellListener(), + cancellationListener: newContextDoneListener(), + connectListener: newNonBlockingContextDoneListener(), } // Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered // at any point during connection establishment can be processed without the connection being considered stale. @@ -141,6 +138,7 @@ func (c *connection) connect(ctx context.Context) (err error) { return nil } + defer c.closeConnectContext() defer close(c.connectDone) // If connect returns an error, set the connection status as disconnected and close the @@ -165,35 +163,17 @@ func (c *connection) connect(ctx context.Context) (err error) { // cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket // establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid // holding the lock longer than necessary. - c.connectContextMutex.Lock() - var handshakeCtx context.Context - handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx) - c.connectContextMutex.Unlock() + ctx, cancel := context.WithCancel(ctx) + defer cancel() - dialCtx := handshakeCtx - var dialCancel context.CancelFunc - if c.config.connectTimeout != 0 { - dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout) - defer dialCancel() - } - - defer func() { - var cancelFn context.CancelFunc + go func() { + defer cancel() - c.connectContextMutex.Lock() - cancelFn = c.cancelConnectContext - c.cancelConnectContext = nil - c.connectContextMutex.Unlock() - - if cancelFn != nil { - cancelFn() - } + c.connectListener.Listen(ctx, func() {}) }() - close(c.connectContextMade) - // Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case. 
- tempNc, err := c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String()) + tempNc, err := c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String()) if err != nil { return ConnectionError{Wrapped: err, init: true} } @@ -209,7 +189,8 @@ func (c *connection) connect(ctx context.Context) (err error) { DisableEndpointChecking: c.config.disableOCSPEndpointCheck, HTTPClient: c.config.httpClient, } - tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) + tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) + if err != nil { return ConnectionError{Wrapped: err, init: true} } @@ -226,10 +207,9 @@ func (c *connection) connect(ctx context.Context) (err error) { handshakeStartTime := time.Now() iconn := initConnection{c} - handshakeConn := mnet.NewConnection(iconn) - handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn) + handshakeInfo, err = handshaker.GetHandshakeInformation(ctx, c.addr, handshakeConn) if err == nil { // We only need to retain the Description field as the connection's description. The authentication-related // fields in handshakeInfo are tracked by the handshaker if necessary. @@ -253,7 +233,7 @@ func (c *connection) connect(ctx context.Context) (err error) { // If we successfully finished the first part of the handshake and verified LB state, continue with the rest of // the handshake. - err = handshaker.FinishHandshake(handshakeCtx, handshakeConn) + err = handshaker.FinishHandshake(ctx, handshakeConn) } // We have a failed handshake here @@ -299,16 +279,8 @@ func (c *connection) wait() { } func (c *connection) closeConnectContext() { - <-c.connectContextMade - var cancelFn context.CancelFunc - - c.connectContextMutex.Lock() - cancelFn = c.cancelConnectContext - c.cancelConnectContext = nil - c.connectContextMutex.Unlock() - - if cancelFn != nil { - cancelFn() + if c.connectListener != nil { + c.connectListener.StopListening() } } @@ -347,17 +319,7 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { } } - var deadline time.Time - if c.writeTimeout != 0 { - deadline = time.Now().Add(c.writeTimeout) - } - - var contextDeadlineUsed bool - if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) { - contextDeadlineUsed = true - deadline = dl - } - + deadline, contextDeadlineUsed := ctx.Deadline() if err := c.nc.SetWriteDeadline(deadline); err != nil { return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"} } @@ -401,17 +363,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { } } - var deadline time.Time - if c.readTimeout != 0 { - deadline = time.Now().Add(c.readTimeout) - } - - var contextDeadlineUsed bool - if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) { - contextDeadlineUsed = true - deadline = dl - } - + deadline, contextDeadlineUsed := ctx.Deadline() if err := c.nc.SetReadDeadline(deadline); err != nil { return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"} } @@ -484,6 +436,12 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, } func (c *connection) close() error { + // Stop any blocking operations occurring in connect(), but await closing the + // connections directly before closing the connection context. 
This ensures + // that closing a connection will manifest as an io.EOF error, avoiding + // non-deterministic connection closure errors. + defer c.closeConnectContext() + // Overwrite the connection state as the first step so only the first close call will execute. if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) { return nil @@ -535,11 +493,6 @@ func (c *connection) getCurrentlyStreaming() bool { return c.currentlyStreaming } -func (c *connection) setSocketTimeout(timeout time.Duration) { - c.readTimeout = timeout - c.writeTimeout = timeout -} - func (c *connection) ID() string { return c.id } @@ -548,6 +501,14 @@ func (c *connection) ServerConnectionID() *int64 { return c.serverConnectionID } +func (c *connection) previousCanceled() bool { + if val := c.prevCanceled.Load(); val != nil { + return val.(bool) + } + + return false +} + // initConnection is an adapter used during connection initialization. It has the minimum // functionality necessary to implement the driver.Connection interface, which is required to pass a // *connection to a Handshaker. @@ -851,47 +812,3 @@ func configureTLS(ctx context.Context, } return client, nil } - -// TODO: Naming? - -// cancellListener listens for context cancellation and notifies listeners via a -// callback function. -type cancellListener struct { - aborted bool - done chan struct{} -} - -// newCancellListener constructs a cancellListener. -func newCancellListener() *cancellListener { - return &cancellListener{ - done: make(chan struct{}), - } -} - -// Listen blocks until the provided context is cancelled or listening is aborted -// via the StopListening function. If this detects that the context has been -// cancelled (i.e. errors.Is(ctx.Err(), context.Canceled), the provided callback is -// called to abort in-progress work. Even if the context expires, this function -// will block until StopListening is called. -func (c *cancellListener) Listen(ctx context.Context, abortFn func()) { - c.aborted = false - - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.Canceled) { - c.aborted = true - abortFn() - } - - <-c.done - case <-c.done: - } -} - -// StopListening stops the in-progress Listen call. This blocks if there is no -// in-progress Listen call. This function will return true if the provided abort -// callback was called when listening for cancellation on the previous context. 
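// Aside (illustrative sketch): with readTimeout/writeTimeout and
// setSocketTimeout gone, the socket deadline comes solely from the operation
// context. A caller now bounds wire-message I/O by passing a context with a
// deadline; ctx.Deadline() yields the zero time.Time when unset, which
// net.Conn treats as "no deadline", matching the read/write paths above.

package sketch

import (
	"context"
	"net"
	"time"
)

func writeWithContextDeadline(ctx context.Context, nc net.Conn, wm []byte) error {
	deadline, _ := ctx.Deadline() // zero time.Time when the context has no deadline
	if err := nc.SetWriteDeadline(deadline); err != nil {
		return err
	}
	_, err := nc.Write(wm)
	return err
}

func boundedWrite(nc net.Conn, wm []byte) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return writeWithContextDeadline(ctx, nc, wm)
}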
-func (c *cancellListener) StopListening() bool { - c.done <- struct{}{} - return c.aborted -} diff --git a/x/mongo/driver/topology/connection_options.go b/x/mongo/driver/topology/connection_options.go index 41533a149a1..f45da5d4608 100644 --- a/x/mongo/driver/topology/connection_options.go +++ b/x/mongo/driver/topology/connection_options.go @@ -48,13 +48,10 @@ type Handshaker = driver.Handshaker type generationNumberFn func(serviceID *bson.ObjectID) uint64 type connectionConfig struct { - connectTimeout time.Duration dialer Dialer handshaker Handshaker idleTimeout time.Duration cmdMonitor *event.CommandMonitor - readTimeout time.Duration - writeTimeout time.Duration tlsConfig *tls.Config httpClient *http.Client compressors []string @@ -69,7 +66,6 @@ type connectionConfig struct { func newConnectionConfig(opts ...ConnectionOption) *connectionConfig { cfg := &connectionConfig{ - connectTimeout: 30 * time.Second, dialer: nil, tlsConnectionSource: defaultTLSConnectionSource, httpClient: httputil.DefaultHTTPClient, @@ -107,14 +103,6 @@ func WithCompressors(fn func([]string) []string) ConnectionOption { } } -// WithConnectTimeout configures the maximum amount of time a dial will wait for a -// Connect to complete. The default is 30 seconds. -func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption { - return func(c *connectionConfig) { - c.connectTimeout = fn(c.connectTimeout) - } -} - // WithDialer configures the Dialer to use when making a new connection to MongoDB. func WithDialer(fn func(Dialer) Dialer) ConnectionOption { return func(c *connectionConfig) { @@ -137,20 +125,6 @@ func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption { } } -// WithReadTimeout configures the maximum read time for a connection. -func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption { - return func(c *connectionConfig) { - c.readTimeout = fn(c.readTimeout) - } -} - -// WithWriteTimeout configures the maximum write time for a connection. -func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption { - return func(c *connectionConfig) { - c.writeTimeout = fn(c.writeTimeout) - } -} - // WithTLSConfig configures the TLS options for a connection. func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption { return func(c *connectionConfig) { diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 51bba47419a..b5158c596d0 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -118,7 +118,6 @@ func TestConnection(t *testing.T) { err := conn.connect(context.Background()) assert.Nil(t, err, "error establishing connection: %v", err) - assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared") }) t.Run("connect cancelled", func(t *testing.T) { // In the case where connection establishment is cancelled, the closeConnectContext function @@ -149,7 +148,6 @@ func TestConnection(t *testing.T) { // Simulate cancelling connection establishment and assert that this clears the CancelFunc. 
conn.closeConnectContext() - assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared") close(doneChan) wg.Wait() }) @@ -203,144 +201,6 @@ func TestConnection(t *testing.T) { } }) }) - t.Run("connectTimeout is applied correctly", func(t *testing.T) { - testCases := []struct { - name string - contextTimeout time.Duration - connectTimeout time.Duration - maxConnectTime time.Duration - }{ - // The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for - // both of the tests declared below. Both tests also specify a 50ms max connect time to provide - // a large buffer for lag and avoid test flakiness. - - {"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 50 * time.Millisecond}, - {"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 50 * time.Millisecond}, - } - - for _, tc := range testCases { - t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) { - // Ensure the initial connection dial can be timed out and the connection propagates the error - // from the dialer in this case. - - connOpts := []ConnectionOption{ - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) { - <-ctx.Done() - return nil, ctx.Err() - }) - }), - WithConnectTimeout(func(time.Duration) time.Duration { - return tc.connectTimeout - }), - } - conn := newConnection("", connOpts...) - - var connectErr error - callback := func(ctx context.Context) { - connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout) - defer cancel() - - connectErr = conn.connect(connectCtx) - } - assert.Soon(t, callback, tc.maxConnectTime) - - ce, ok := connectErr.(ConnectionError) - assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{}) - assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v", - context.DeadlineExceeded, ce.Unwrap()) - }) - t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) { - // Ensure the TLS handshake can be timed out and the connection propagates the error from the - // tlsConn in this case. - - // Start a TCP listener on a random port and use the listener address as the - // target for connections. The listener will act as a source of connections - // that never respond, allowing the timeout logic to always trigger. - l, err := net.Listen("tcp", "localhost:0") - assert.Nil(t, err, "net.Listen() error: %q", err) - defer l.Close() - - connOpts := []ConnectionOption{ - WithConnectTimeout(func(time.Duration) time.Duration { - return tc.connectTimeout - }), - WithTLSConfig(func(*tls.Config) *tls.Config { - return &tls.Config{ServerName: "test"} - }), - } - conn := newConnection(address.Address(l.Addr().String()), connOpts...) 
- - var connectErr error - callback := func(ctx context.Context) { - connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout) - defer cancel() - - connectErr = conn.connect(connectCtx) - } - assert.Soon(t, callback, tc.maxConnectTime) - - ce, ok := connectErr.(ConnectionError) - assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{}) - - isTimeout := func(err error) bool { - if errors.Is(err, context.DeadlineExceeded) { - return true - } - if ne, ok := err.(net.Error); ok { - return ne.Timeout() - } - return false - } - assert.True(t, - isTimeout(ce.Unwrap()), - "expected wrapped error to be a timeout error, but got %q", - ce.Unwrap()) - }) - t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) { - // Ensure that no additional timeout is applied to the handshake after the connection has been - // established. - - var getInfoCtx, finishCtx context.Context - handshaker := &testHandshaker{ - getHandshakeInformation: func(ctx context.Context, _ address.Address, _ *mnet.Connection) (driver.HandshakeInformation, error) { - getInfoCtx = ctx - return driver.HandshakeInformation{}, nil - }, - finishHandshake: func(ctx context.Context, _ *mnet.Connection) error { - finishCtx = ctx - return nil - }, - } - - connOpts := []ConnectionOption{ - WithConnectTimeout(func(time.Duration) time.Duration { - return tc.connectTimeout - }), - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return &net.TCPConn{}, nil - }) - }), - WithHandshaker(func(Handshaker) Handshaker { - return handshaker - }), - } - conn := newConnection("", connOpts...) - - err := conn.connect(context.Background()) - assert.Nil(t, err, "connect error: %v", err) - - assertNoContextTimeout := func(t *testing.T, ctx context.Context) { - t.Helper() - dl, ok := ctx.Deadline() - assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl) - } - assertNoContextTimeout(t, getInfoCtx) - assertNoContextTimeout(t, finishCtx) - }) - } - }) }) t.Run("writeWireMessage", func(t *testing.T) { t.Run("closed connection", func(t *testing.T) { @@ -355,14 +215,10 @@ func TestConnection(t *testing.T) { testCases := []struct { name string ctxDeadline time.Duration - timeout time.Duration deadline time.Time }{ - {"no deadline", 0, 0, time.Now().Add(1 * time.Second)}, - {"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)}, - {"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)}, - {"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)}, - {"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)}, + {"no deadline", 0, time.Now().Add(1 * time.Second)}, + {"ctx deadline", 5 * time.Second, time.Now().Add(6 * time.Second)}, } for _, tc := range testCases { @@ -379,7 +235,7 @@ func TestConnection(t *testing.T) { message: "failed to set write deadline", } tnc := &testNetConn{deadlineerr: errors.New("set writeDeadline error")} - conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, state: connConnected} + conn := &connection{id: "foobar", nc: tnc, state: connConnected} got := conn.writeWireMessage(ctx, []byte{}) if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("errors do not match. 
got %v; want %v", got, want) @@ -484,14 +340,10 @@ func TestConnection(t *testing.T) { testCases := []struct { name string ctxDeadline time.Duration - timeout time.Duration deadline time.Time }{ - {"no deadline", 0, 0, time.Now().Add(1 * time.Second)}, - {"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)}, - {"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)}, - {"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)}, - {"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)}, + {"no deadline", 0, time.Now().Add(1 * time.Second)}, + {"ctx deadline", 5 * time.Second, time.Now().Add(6 * time.Second)}, } for _, tc := range testCases { @@ -508,7 +360,7 @@ func TestConnection(t *testing.T) { message: "failed to set read deadline", } tnc := &testNetConn{deadlineerr: errors.New("set readDeadline error")} - conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, state: connConnected} + conn := &connection{id: "foobar", nc: tnc, state: connConnected} _, got := conn.readWireMessage(ctx) if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("errors do not match. got %v; want %v", got, want) @@ -776,7 +628,8 @@ func TestConnection(t *testing.T) { addr := bootstrapConnections(t, numConns, func(nc net.Conn) {}) pool := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := pool.ready() assert.Nil(t, err, "pool.connect() error: %v", err) @@ -1195,7 +1048,7 @@ func (d *dialer) lenclosed() int { } type testCancellationListener struct { - listener *cancellListener + listener *contextDoneListener numListen int numStopListening int aborted bool @@ -1205,7 +1058,7 @@ type testCancellationListener struct { // returned by the StopListening method. func newTestCancellationListener(aborted bool) *testCancellationListener { return &testCancellationListener{ - listener: newCancellListener(), + listener: newContextDoneListener(), aborted: aborted, } } diff --git a/x/mongo/driver/topology/context_listener.go b/x/mongo/driver/topology/context_listener.go new file mode 100644 index 00000000000..99c252c87c9 --- /dev/null +++ b/x/mongo/driver/topology/context_listener.go @@ -0,0 +1,91 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package topology + +import ( + "context" + "errors" + "sync/atomic" +) + +type contextListener interface { + Listen(context.Context, func()) + StopListening() bool +} + +// contextDoneListener listens for context-ending eventsand notifies listeners +// via a callback function. +type contextDoneListener struct { + aborted atomic.Value + done chan struct{} + blockOnDone bool +} + +var _ contextListener = &contextDoneListener{} + +// newContextDoneListener constructs a contextDoneListener that will block +// when a context is done until StopListening is called. +func newContextDoneListener() *contextDoneListener { + return &contextDoneListener{ + done: make(chan struct{}), + blockOnDone: true, + } +} + +// newNonBlockingContextDoneLIstener constructs a contextDoneListener that +// will not block when a context is done. In this case there are two ways to +// unblock the listener: a finished context or a call to StopListening. 
+func newNonBlockingContextDoneListener() *contextDoneListener { + return &contextDoneListener{ + done: make(chan struct{}), + blockOnDone: false, + } +} + +// Listen blocks until the provided context is cancelled or listening is aborted +// via the StopListening function. If this detects that the context has been +// cancelled (i.e. errors.Is(ctx.Err(), context.Canceled), the provided callback +// is called to abort in-progress work. If blockOnDone is true, this function +// will block until StopListening is called, even if the context expires. +func (c *contextDoneListener) Listen(ctx context.Context, abortFn func()) { + c.aborted.Store(false) + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + c.aborted.Store(true) + + abortFn() + } + + if c.blockOnDone { + <-c.done + } + case <-c.done: + } +} + +// StopListening stops the in-progress Listen call. If blockOnDone is true, then +// this blocks if there is no in-progress Listen call. This function will return +// true if the provided abort callback was called when listening for +// cancellation on the previous context. +func (c *contextDoneListener) StopListening() bool { + if c.blockOnDone { + c.done <- struct{}{} + } else { + select { + case c.done <- struct{}{}: + default: + } + } + + if aborted := c.aborted.Load(); aborted != nil { + return aborted.(bool) + } + + return false +} diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 122e13111c8..4a1b82b4313 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -78,6 +78,7 @@ type poolConfig struct { PoolMonitor *event.PoolMonitor Logger *logger.Logger handshakeErrFn func(error, uint64, *bson.ObjectID) + ConnectTimeout time.Duration } type pool struct { @@ -122,9 +123,10 @@ type pool struct { conns map[int64]*connection // conns holds all currently open connections. newConnWait wantConnQueue // newConnWait holds all wantConn requests for new connections. - idleMu sync.Mutex // idleMu guards idleConns, idleConnWait - idleConns []*connection // idleConns holds all idle connections. - idleConnWait wantConnQueue // idleConnWait holds all wantConn requests for idle connections. + idleMu sync.Mutex // idleMu guards idleConns, idleConnWait + idleConns []*connection // idleConns holds all idle connections. + idleConnWait wantConnQueue // idleConnWait holds all wantConn requests for idle connections. + connectTimeout time.Duration } // getState returns the current state of the pool. Callers must not hold the stateMu lock. @@ -221,6 +223,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool { createConnectionsCond: sync.NewCond(&sync.Mutex{}), conns: make(map[int64]*connection, config.MaxPoolSize), idleConns: make([]*connection, 0, config.MaxPoolSize), + connectTimeout: config.ConnectTimeout, } // minSize must not exceed maxSize if maxSize is not 0 if pool.maxSize != 0 && pool.minSize > pool.maxSize { @@ -1108,9 +1111,26 @@ func (p *pool) createConnections(ctx context.Context, wg *sync.WaitGroup) { } start := time.Now() - // Pass the createConnections context to connect to allow pool close to cancel connection - // establishment so shutdown doesn't block indefinitely if connectTimeout=0. - err := conn.connect(ctx) + // Pass the createConnections context to connect to allow pool close to + // cancel connection establishment so shutdown doesn't block indefinitely if + // connectTimeout=0. 
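// Aside (illustrative only; assumes package topology and a "context" import):
// how the two constructors differ in practice. The blocking flavour parks
// Listen after cancellation until StopListening drains it, so StopListening
// can report whether the abort callback ran; the non-blocking flavour returns
// as soon as its context is done, and a later StopListening never hangs,
// which is what the connect listener needs.

func listenerFlavors(ctx context.Context, cancel context.CancelFunc) {
	blocking := newContextDoneListener()
	go blocking.Listen(ctx, func() { /* abort in-progress work */ })

	nonBlocking := newNonBlockingContextDoneListener()
	go nonBlocking.Listen(ctx, func() { /* abort in-progress work */ })

	cancel() // both listeners observe context.Canceled and run their callbacks

	_ = blocking.StopListening()    // unblocks the first Listen; reports the abort
	_ = nonBlocking.StopListening() // safe even if its Listen already returned
}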
+ // + // Per the specifications, an explicit value of connectTimeout=0 means the + // timeout is "infinite". + + var cancel context.CancelFunc + + connctx := context.Background() + if p.connectTimeout != 0 { + connctx, cancel = context.WithTimeout(ctx, p.connectTimeout) + } + + err := conn.connect(connctx) + + if cancel != nil { + cancel() + } + if err != nil { w.tryDeliver(nil, err) diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 3001aa9b1b8..69a7cce7267 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -67,7 +67,8 @@ func TestPool(t *testing.T) { }) p1 := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p1.ready() noerr(t, err) @@ -92,7 +93,9 @@ func TestPool(t *testing.T) { t.Run("calling close multiple times does not panic", func(t *testing.T) { t.Parallel() - p := newPool(poolConfig{}) + p := newPool(poolConfig{ + ConnectTimeout: defaultConnectionTimeout, + }) err := p.ready() noerr(t, err) @@ -112,7 +115,8 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -148,7 +152,8 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -183,7 +188,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -229,7 +235,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -284,7 +291,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -313,7 +321,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -369,7 +378,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -407,7 +417,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -456,7 +467,9 @@ func TestPool(t *testing.T) { t.Parallel() dialErr := errors.New("create new connection error") - p := newPool(poolConfig{}, WithDialer(func(Dialer) Dialer { + p := newPool(poolConfig{ + ConnectTimeout: defaultConnectionTimeout, + }, WithDialer(func(Dialer) Dialer { return DialerFunc(func(context.Context, string, string) (net.Conn, 
error) { return nil, dialErr }) @@ -493,8 +506,9 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool( poolConfig{ - Address: address.Address(addr.String()), - MaxIdleTime: time.Millisecond, + Address: address.Address(addr.String()), + MaxIdleTime: time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d }), ) @@ -538,7 +552,8 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -565,7 +580,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -583,7 +599,9 @@ func TestPool(t *testing.T) { t.Parallel() p := newPool( - poolConfig{}, + poolConfig{ + ConnectTimeout: defaultConnectionTimeout, + }, WithHandshaker(func(Handshaker) Handshaker { return operation.NewHello() }), @@ -632,8 +650,9 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MaxPoolSize: 1, + Address: address.Address(addr.String()), + MaxPoolSize: 1, + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -672,8 +691,9 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MaxPoolSize: 1, + Address: address.Address(addr.String()), + MaxPoolSize: 1, + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -727,8 +747,9 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool( poolConfig{ - Address: address.Address(addr.String()), - MaxPoolSize: 2, + Address: address.Address(addr.String()), + MaxPoolSize: 2, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d }), ) @@ -794,8 +815,9 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MaxPoolSize: 1, + Address: address.Address(addr.String()), + MaxPoolSize: 1, + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -834,7 +856,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -867,7 +890,8 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -897,7 +921,8 @@ func TestPool(t *testing.T) { }) p1 := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p1.ready() noerr(t, err) @@ -927,8 +952,9 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MaxIdleTime: 100 * time.Millisecond, + Address: address.Address(addr.String()), + MaxIdleTime: 100 * time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -960,9 +986,10 @@ func 
TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MinPoolSize: 3, - MaxIdleTime: 10 * time.Millisecond, + Address: address.Address(addr.String()), + MinPoolSize: 3, + MaxIdleTime: 10 * time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -1000,8 +1027,9 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MinPoolSize: 3, + Address: address.Address(addr.String()), + MinPoolSize: 3, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -1024,9 +1052,10 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MinPoolSize: 20, - MaxPoolSize: 2, + Address: address.Address(addr.String()), + MinPoolSize: 20, + MaxPoolSize: 2, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -1052,6 +1081,7 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), // Set the pool's maintain interval to 10ms so that it allows the test to run quickly. MaintainInterval: 10 * time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -1102,6 +1132,7 @@ func TestPool(t *testing.T) { MinPoolSize: 3, // Set the pool's maintain interval to 10ms so that it allows the test to run quickly. MaintainInterval: 10 * time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 21eafd18f24..03bcc06aa90 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -29,12 +29,9 @@ type rttConfig struct { // the operation takes longer than the interval. interval time.Duration - // The timeout applied to running the "hello" operation. If the timeout is reached while running - // the operation, the RTT sample is discarded. The default is 1 minute. - timeout time.Duration - minRTTWindow time.Duration createConnectionFn func() *connection + connectTimeout time.Duration createOperationFn func(*mnet.Connection) *operation.Hello } @@ -115,7 +112,11 @@ func (r *rttMonitor) start() { for { conn := r.cfg.createConnectionFn() - err := conn.connect(r.ctx) + + ctx, cancel := context.WithTimeout(r.ctx, r.cfg.connectTimeout) + defer cancel() + + err := conn.connect(ctx) // Add an RTT sample from the new connection handshake and start a runHellos() loop if we // successfully established the new connection. Otherwise, close the connection and try to @@ -161,11 +162,7 @@ func (r *rttMonitor) runHellos(conn *connection) { // server or a proxy stops responding to requests on the RTT connection but does not close // the TCP socket, effectively creating an operation that will never complete. We expect // that "connectTimeoutMS" provides at least enough time for a single round trip. 
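// Aside (hypothetical helper, not part of the driver API): the pool and the
// RTT monitor now share one convention, shown above and below — connectTimeout
// bounds each individual connection attempt, and an explicit 0 means "no
// per-attempt timeout" rather than "fail immediately".

package sketch

import (
	"context"
	"time"
)

// withConnectTimeout derives the context for a single connect attempt.
func withConnectTimeout(parent context.Context, connectTimeout time.Duration) (context.Context, context.CancelFunc) {
	if connectTimeout == 0 {
		// connectTimeout=0 is "infinite": inherit only the parent's lifetime
		// (for example, pool close or monitor shutdown).
		return context.WithCancel(parent)
	}
	return context.WithTimeout(parent, connectTimeout)
}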
- timeout := r.cfg.timeout - if timeout <= 0 { - timeout = conn.config.connectTimeout - } - ctx, cancel := context.WithTimeout(r.ctx, timeout) + ctx, cancel := context.WithTimeout(r.ctx, r.cfg.connectTimeout) start := time.Now() iconn := mnet.NewConnection(initConnection{conn}) diff --git a/x/mongo/driver/topology/rtt_monitor_test.go b/x/mongo/driver/topology/rtt_monitor_test.go index 7abfe024fc5..f2677c89799 100644 --- a/x/mongo/driver/topology/rtt_monitor_test.go +++ b/x/mongo/driver/topology/rtt_monitor_test.go @@ -91,7 +91,8 @@ func TestRTTMonitor(t *testing.T) { return newMockSlowConn(makeHelloReply(), 10*time.Millisecond), nil }) rtt := newRTTMonitor(&rttConfig{ - interval: 10 * time.Millisecond, + interval: 10 * time.Millisecond, + connectTimeout: defaultConnectionTimeout, createConnectionFn: func() *connection { return newConnection("", WithDialer(func(Dialer) Dialer { return dialer })) }, @@ -150,7 +151,8 @@ func TestRTTMonitor(t *testing.T) { return newMockSlowConn(makeHelloReply(), 10*time.Millisecond), nil }) rtt := newRTTMonitor(&rttConfig{ - interval: 10 * time.Millisecond, + connectTimeout: defaultConnectionTimeout, + interval: 10 * time.Millisecond, createConnectionFn: func() *connection { return newConnection("", WithDialer(func(Dialer) Dialer { return dialer })) }, @@ -248,8 +250,8 @@ func TestRTTMonitor(t *testing.T) { }() rtt := newRTTMonitor(&rttConfig{ - interval: 10 * time.Millisecond, - timeout: 100 * time.Millisecond, + interval: 10 * time.Millisecond, + connectTimeout: 100 * time.Millisecond, createConnectionFn: func() *connection { return newConnection(address.Address(l.Addr().String())) }, diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 862f9c6d48e..8d53dfd62ee 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -133,16 +133,12 @@ type Server struct { currentSubscriberID uint64 subscriptionsClosed bool - // heartbeat and cancellation related fields - // globalCtx should be created in NewServer and cancelled in Disconnect to signal that the server is shutting down. - // heartbeatCtx should be used for individual heartbeats and should be a child of globalCtx so that it will be - // cancelled automatically during shutdown. - heartbeatLock sync.Mutex - conn *connection - globalCtx context.Context - globalCtxCancel context.CancelFunc - heartbeatCtx context.Context - heartbeatCtxCancel context.CancelFunc + conn *connection + + // Calling StopListening on the heartbeatListner will cancel the context + // passed to the heartbeat check. This will result in the current connection + // being closed. + heartbeatListener contextListener processErrorLock sync.Mutex rttMonitor *rttMonitor @@ -160,9 +156,10 @@ func ConnectServer( addr address.Address, updateCallback updateTopologyCallback, topologyID bson.ObjectID, + connectTimeout time.Duration, opts ...ServerOption, ) (*Server, error) { - srvr := NewServer(addr, topologyID, opts...) + srvr := NewServer(addr, topologyID, connectTimeout, opts...) err := srvr.Connect(updateCallback) if err != nil { return nil, err @@ -172,9 +169,14 @@ func ConnectServer( // NewServer creates a new server. The mongodb server at the address will be monitored // on an internal monitoring goroutine. -func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOption) *Server { - cfg := newServerConfig(opts...) 
- globalCtx, globalCtxCancel := context.WithCancel(context.Background()) +func NewServer( + addr address.Address, + topologyID bson.ObjectID, + connectTimeout time.Duration, + opts ...ServerOption, +) *Server { + cfg := newServerConfig(connectTimeout, opts...) + s := &Server{ state: serverDisconnected, @@ -187,9 +189,8 @@ func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOpt topologyID: topologyID, - subscribers: make(map[uint64]chan description.Server), - globalCtx: globalCtx, - globalCtxCancel: globalCtxCancel, + subscribers: make(map[uint64]chan description.Server), + heartbeatListener: newNonBlockingContextDoneListener(), } s.desc.Store(newDefaultServerDescription(addr)) rttCfg := &rttConfig{ @@ -197,6 +198,7 @@ func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOpt minRTTWindow: 5 * time.Minute, createConnectionFn: s.createConnection, createOperationFn: s.createBaseOperation, + connectTimeout: connectTimeout, } s.rttMonitor = newRTTMonitor(rttCfg) @@ -211,6 +213,7 @@ func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOpt PoolMonitor: cfg.poolMonitor, Logger: cfg.logger, handshakeErrFn: s.ProcessHandshakeError, + ConnectTimeout: connectTimeout, } connectionOpts := copyConnectionOpts(cfg.connectionOpts) @@ -299,13 +302,9 @@ func (s *Server) Disconnect(ctx context.Context) error { s.updateTopologyCallback.Store((updateTopologyCallback)(nil)) - // Cancel the global context so any new contexts created from it will be automatically cancelled. Close the done - // channel so the update() routine will know that it can stop. Cancel any in-progress monitoring checks at the end. - // The done channel is closed before cancelling the check so the update routine() will immediately detect that it - // can stop rather than trying to create new connections until the read from done succeeds. - s.globalCtxCancel() close(s.done) - s.cancelCheck() + + s.heartbeatListener.StopListening() s.pool.close(ctx) @@ -380,7 +379,7 @@ func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint6 // checking logic above has already determined that this description is not stale. s.updateDescription(newServerDescriptionFromError(s.address, wrappedConnErr, nil)) s.pool.clear(err, serviceID) - s.cancelCheck() + s.heartbeatListener.StopListening() } // Description returns a description of the server as of the last heartbeat. @@ -559,10 +558,65 @@ func (s *Server) ProcessError(err error, describer mnet.Describer) driver.Proces // updateDescription. s.updateDescription(newServerDescriptionFromError(s.address, err, nil)) s.pool.clear(err, serviceID) - s.cancelCheck() + s.heartbeatListener.StopListening() return driver.ConnectionPoolCleared } +type serverChecker interface { + check(ctx context.Context) (description.Server, error) +} + +var _ serverChecker = &Server{} + +// checkServerWithSignal will run the server heartbeat check, canceling if the +// sig channel's buffer is emptied or is closed. +func checkServerWithSignal( + checker serverChecker, + conn *connection, + listener contextListener, +) (description.Server, error) { + // Create a context for the blocking operations associated with checking the + // status of a server. + // + // The Server Monitoring spec already mandates that drivers set and + // dynamically update the read/write timeout of the dedicated connections + // used in monitoring threads, so we rely on that to time out commands + // rather than adding complexity to the behavior of timeoutMS. 
+ ctx, cancel := context.WithCancel(context.Background()) + + defer listener.StopListening() + defer cancel() + + go func(conn *connection) { + defer cancel() + + var aborted bool + listener.Listen(ctx, func() { + aborted = true + }) + + // Close the connection if the listener was stopped before + // checkServerWithSignal terminates. + if !aborted { + if conn == nil { + return + } + + // If the connection exists, we need to wait for it to be connected + // because conn.connect() and conn.close() cannot be called concurrently. + // If the connection wasn't successfully opened, its state was set back + // to disconnected, so calling conn.close() will be a no-op. + conn.closeConnectContext() + conn.wait() + conn.prevCanceled.Store(true) + _ = conn.close() + } + + }(conn) + + return checker.check(ctx) +} + // update handle performing heartbeats and updating any subscribers of the // newest description.Server retrieved. func (s *Server) update() { @@ -587,8 +641,6 @@ func (s *Server) update() { s.subscriptionsClosed = true s.subLock.Unlock() - // We don't need to take s.heartbeatLock here because closeServer is called synchronously when the select checks - // below detect that the server is being closed, so we can be sure that the connection isn't being used. if s.conn != nil { _ = s.conn.close() } @@ -626,8 +678,9 @@ func (s *Server) update() { previousDescription := s.Description() - // Perform the next check. - desc, err := s.check() + desc, err := checkServerWithSignal(s, s.conn, s.heartbeatListener) + + // The only error returned from checkServerWithSignal is errCheckCancelled. if errors.Is(err, errCheckCancelled) { if atomic.LoadInt64(&s.state) != serverConnected { continue @@ -754,11 +807,6 @@ func (s *Server) updateDescription(desc description.Server) { func (s *Server) createConnection() *connection { opts := copyConnectionOpts(s.cfg.connectionOpts) opts = append(opts, - WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - // We override whatever handshaker is currently attached to the options with a basic - // one because need to make sure we don't do auth. WithHandshaker(func(h Handshaker) Handshaker { return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts). ServerAPI(s.cfg.serverAPI) @@ -776,48 +824,19 @@ func copyConnectionOpts(opts []ConnectionOption) []ConnectionOption { return optsCopy } -func (s *Server) setupHeartbeatConnection() error { +func (s *Server) setupHeartbeatConnection(ctx context.Context) error { conn := s.createConnection() - // Take the lock when assigning the context and connection because they're accessed by cancelCheck. - s.heartbeatLock.Lock() - if s.heartbeatCtxCancel != nil { - // Ensure the previous context is cancelled to avoid a leak. - s.heartbeatCtxCancel() - } - s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx) s.conn = conn - s.heartbeatLock.Unlock() - - return s.conn.connect(s.heartbeatCtx) -} - -// cancelCheck cancels in-progress connection dials and reads. It does not set any fields on the server. -func (s *Server) cancelCheck() { - var conn *connection - // Take heartbeatLock for mutual exclusion with the checks in the update function. 
- s.heartbeatLock.Lock() - if s.heartbeatCtx != nil { - s.heartbeatCtxCancel() - } - conn = s.conn - s.heartbeatLock.Unlock() + if s.cfg.connectTimeout != 0 { + var cancelFn context.CancelFunc + ctx, cancelFn = context.WithTimeout(ctx, s.cfg.connectTimeout) - if conn == nil { - return + defer cancelFn() } - // If the connection exists, we need to wait for it to be connected because conn.connect() and - // conn.close() cannot be called concurrently. If the connection wasn't successfully opened, its - // state was set back to disconnected, so calling conn.close() will be a no-op. - conn.closeConnectContext() - conn.wait() - _ = conn.close() -} - -func (s *Server) checkWasCancelled() bool { - return s.heartbeatCtx.Err() != nil + return s.conn.connect(ctx) } func (s *Server) createBaseOperation(conn *mnet.Connection) *operation.Hello { @@ -843,24 +862,119 @@ func isStreamable(srv *Server) bool { return srv.Description().Kind != description.Unknown && srv.Description().TopologyVersion != nil } -func (s *Server) check() (description.Server, error) { +func (s *Server) streamable() bool { + return isStreamingEnabled(s) && isStreamable(s) +} + +// getHeartbeatTimeout will return the maximum allowable duration for streaming +// or polling a hello command during server monitoring. +func getHeartbeatTimeout(srv *Server) time.Duration { + if srv.conn.getCurrentlyStreaming() || srv.streamable() { + // If connectTimeoutMS=0, the operation timeout should be infinite. + // Otherwise, it is connectTimeoutMS + heartbeatFrequencyMS to account for + // the fact that the query will block for heartbeatFrequencyMS + // server-side. + streamingTO := srv.cfg.connectTimeout + if streamingTO != 0 { + streamingTO += srv.cfg.heartbeatInterval + } + + return streamingTO + } + + // The server doesn't support the awaitable protocol. Set the timeout to + // connectTimeoutMS and execute a regular heartbeat without any additional + // parameters. + return srv.cfg.connectTimeout +} + +// withHeartbeatTimeout will apply the appropriate timeout to the parent context +// for server monitoring. +func withHeartbeatTimeout(parent context.Context, srv *Server) (context.Context, context.CancelFunc) { + var cancel context.CancelFunc + + timeout := getHeartbeatTimeout(srv) + if timeout == 0 { + return parent, cancel + } + + return context.WithTimeout(parent, timeout) +} + +// doHandshake will construct the hello operation use to execute a handshake +// between the client and a server. Depending on the configuration and version, +// this function will either poll, stream, or resume streaming. +func doHandshake(ctx context.Context, srv *Server) (description.Server, error) { + heartbeatConn := mnet.NewConnection(initConnection{srv.conn}) + handshakeOp := srv.createBaseOperation(heartbeatConn) + + // Using timeoutMS in the monitoring and RTT calculation threads would require + // another special case in the code that derives maxTimeMS from timeoutMS + // because the awaitable hello requests sent to 4.4+ servers already have a + // maxAwaitTimeMS field. Adding maxTimeMS also does not help for non-awaitable + // hello commands because we expect them to execute quickly on the server. The + // Server Monitoring spec already mandates that drivers set and dynamically + // update the read/write timeout of the dedicated connections used in + // monitoring threads, so we rely on that to time out commands rather than + // adding complexity to the behavior of timeoutMS. + handshakeOp = handshakeOp.OmitMaxTimeMS(true) + + // Apply monitoring timeout. 
+ ctx, cancel := withHeartbeatTimeout(ctx, srv) + defer cancel() + + // If we are currently streaming, get more data and return the result. + if srv.conn.getCurrentlyStreaming() { + if err := handshakeOp.StreamResponse(ctx, heartbeatConn); err != nil { + return description.Server{}, err + } + + return handshakeOp.Result(srv.address), nil + } + + // If the server supports streaming, update it so the next handshake streams + // the response. + if srv.streamable() { + srv.conn.setCanStream(true) + + maxAwaitTimeMS := int64(srv.cfg.heartbeatInterval) / 1e6 + + handshakeOp = handshakeOp. + TopologyVersion(srv.Description().TopologyVersion). + MaxAwaitTimeMS(maxAwaitTimeMS) + } + + // Perform the handshake. + if err := handshakeOp.Execute(ctx); err != nil { + return description.Server{}, err + } + + return handshakeOp.Result(srv.address), nil +} + +func (s *Server) check(ctx context.Context) (description.Server, error) { var descPtr *description.Server var err error - var duration time.Duration + var execDuration time.Duration start := time.Now() + var previousCanceled bool + if s.conn != nil { + previousCanceled = s.conn.previousCanceled() + } + // Create a new connection if this is the first check, the connection was closed after an error during the previous // check, or the previous check was cancelled. - if s.conn == nil || s.conn.closed() || s.checkWasCancelled() { + if s.conn == nil || s.conn.closed() || previousCanceled { connID := "0" if s.conn != nil { connID = s.conn.ID() } s.publishServerHeartbeatStartedEvent(connID, false) // Create a new connection and add it's handshake RTT as a sample. - err = s.setupHeartbeatConnection() - duration = time.Since(start) + err = s.setupHeartbeatConnection(ctx) + execDuration = time.Since(start) connID = "0" if s.conn != nil { connID = s.conn.ID() @@ -869,80 +983,47 @@ func (s *Server) check() (description.Server, error) { // Use the description from the connection handshake as the value for this check. s.rttMonitor.addSample(s.conn.helloRTT) descPtr = &s.conn.desc - s.publishServerHeartbeatSucceededEvent(connID, duration, s.conn.desc, false) + s.publishServerHeartbeatSucceededEvent(connID, execDuration, s.conn.desc, false) } else { err = unwrapConnectionError(err) - s.publishServerHeartbeatFailedEvent(connID, duration, err, false) + s.publishServerHeartbeatFailedEvent(connID, execDuration, err, false) } } else { - // An existing connection is being used. Use the server description properties to execute the right heartbeat. - - // Wrap conn in a type that implements driver.StreamerConnection. - iconn := initConnection{s.conn} - heartbeatConn := mnet.NewConnection(iconn) + // An existing connection is being used. Use the server description + // properties to execute the right heartbeat. - baseOperation := s.createBaseOperation(heartbeatConn) - previousDescription := s.Description() streamable := isStreamingEnabled(s) && isStreamable(s) s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) - switch { - case s.conn.getCurrentlyStreaming(): - // The connection is already in a streaming state, so we stream the next response. - err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn) - case streamable: - // The server supports the streamable protocol. Set the socket timeout to - // connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable hello request. Set conn.canStream so - // the wire message will advertise streaming support to the server. 
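// Aside (worked example of the getHeartbeatTimeout arithmetic above; the
// no-op-cancel variant is an assumption, not shown driver code): with
// connectTimeoutMS=10s and heartbeatFrequencyMS=10s, an awaitable (streaming)
// hello gets a 20s budget, a plain polling hello gets 10s, and
// connectTimeoutMS=0 leaves the operation unbounded. Returning a no-op
// CancelFunc in the unbounded case keeps a caller's `defer cancel()` safe
// even when no timeout is applied.

package sketch

import (
	"context"
	"time"
)

func heartbeatBudget(streaming bool, connectTimeout, heartbeatInterval time.Duration) time.Duration {
	if streaming && connectTimeout != 0 {
		return connectTimeout + heartbeatInterval // account for the server-side wait
	}
	return connectTimeout // 0 means "no timeout"
}

func withBudget(parent context.Context, d time.Duration) (context.Context, context.CancelFunc) {
	if d == 0 {
		return parent, func() {} // no-op cancel so `defer cancel()` never calls nil
	}
	return context.WithTimeout(parent, d)
}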
- - // Calculation for maxAwaitTimeMS is taken from time.Duration.Milliseconds (added in Go 1.13). - maxAwaitTimeMS := int64(s.cfg.heartbeatInterval) / 1e6 - // If connectTimeoutMS=0, the socket timeout should be infinite. Otherwise, it is connectTimeoutMS + - // heartbeatFrequencyMS to account for the fact that the query will block for heartbeatFrequencyMS - // server-side. - socketTimeout := s.cfg.heartbeatTimeout - if socketTimeout != 0 { - socketTimeout += s.cfg.heartbeatInterval - } - s.conn.setSocketTimeout(socketTimeout) - baseOperation = baseOperation.TopologyVersion(previousDescription.TopologyVersion). - MaxAwaitTimeMS(maxAwaitTimeMS) - s.conn.setCanStream(true) - err = baseOperation.Execute(s.heartbeatCtx) - default: - // The server doesn't support the awaitable protocol. Set the socket timeout to connectTimeoutMS and - // execute a regular heartbeat without any additional parameters. - - s.conn.setSocketTimeout(s.cfg.heartbeatTimeout) - err = baseOperation.Execute(s.heartbeatCtx) - } + var tempDesc description.Server + tempDesc, err = doHandshake(ctx, s) // Perform a handshake with the server - duration = time.Since(start) + execDuration = time.Since(start) // We need to record an RTT sample in the polling case so that if the server // is < 4.4, or if polling is specified by the user, then the // RTT-short-circuit feature of CSOT is not disabled. if !streamable { - s.rttMonitor.addSample(duration) + s.rttMonitor.addSample(execDuration) } if err == nil { - tempDesc := baseOperation.Result(s.address) descPtr = &tempDesc - s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, tempDesc, s.conn.getCurrentlyStreaming() || streamable) + s.publishServerHeartbeatSucceededEvent(s.conn.ID(), execDuration, + tempDesc, s.conn.getCurrentlyStreaming() || streamable) } else { // Close the connection here rather than below so we ensure we're not closing a connection that wasn't // successfully created. if s.conn != nil { _ = s.conn.close() } - s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, s.conn.getCurrentlyStreaming() || streamable) + s.publishServerHeartbeatFailedEvent(s.conn.ID(), execDuration, err, s.conn.getCurrentlyStreaming() || streamable) } } if descPtr != nil { - // The check was successful. Set the average RTT and the 90th percentile RTT and return. + // The check was successful. Set the average RTT and return. desc := *descPtr desc.AverageRTT = s.rttMonitor.EWMA() desc.AverageRTTSet = true @@ -951,7 +1032,7 @@ func (s *Server) check() (description.Server, error) { return desc, nil } - if s.checkWasCancelled() { + if previousCanceled { // If the previous check was cancelled, we don't want to clear the pool. Return a sentinel error so the caller // will know that an actual error didn't occur. 
return emptyDescription, errCheckCancelled diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index c02600e232a..bfd1218d121 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -25,7 +25,7 @@ type serverConfig struct { connectionOpts []ConnectionOption appname string heartbeatInterval time.Duration - heartbeatTimeout time.Duration + connectTimeout time.Duration serverMonitoringMode string serverMonitor *event.ServerMonitor registry *bson.Registry @@ -43,10 +43,10 @@ type serverConfig struct { poolMaintainInterval time.Duration } -func newServerConfig(opts ...ServerOption) *serverConfig { +func newServerConfig(connectTimeout time.Duration, opts ...ServerOption) *serverConfig { cfg := &serverConfig{ heartbeatInterval: 10 * time.Second, - heartbeatTimeout: 10 * time.Second, + connectTimeout: connectTimeout, registry: defaultRegistry, } @@ -65,8 +65,8 @@ type ServerOption func(*serverConfig) // ServerAPIFromServerOptions will return the server API options if they have been functionally set on the ServerOption // slice. -func ServerAPIFromServerOptions(opts []ServerOption) *driver.ServerAPIOptions { - return newServerConfig(opts...).serverAPI +func ServerAPIFromServerOptions(connectTimeout time.Duration, opts []ServerOption) *driver.ServerAPIOptions { + return newServerConfig(connectTimeout, opts...).serverAPI } func withMonitoringDisabled(fn func(bool) bool) ServerOption { @@ -103,14 +103,6 @@ func WithHeartbeatInterval(fn func(time.Duration) time.Duration) ServerOption { } } -// WithHeartbeatTimeout configures how long to wait for a heartbeat socket to -// connection. -func WithHeartbeatTimeout(fn func(time.Duration) time.Duration) ServerOption { - return func(cfg *serverConfig) { - cfg.heartbeatTimeout = fn(cfg.heartbeatTimeout) - } -} - // WithMaxConnections configures the maximum number of connections to allow for // a given server. If max is 0, then maximum connection pool size is not limited. 
func WithMaxConnections(fn func(uint64) uint64) ServerOption { diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 1fd20fdadbf..8b18d2408cc 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -33,6 +33,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" + "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" @@ -166,6 +167,7 @@ func TestServerHeartbeatTimeout(t *testing.T) { server := NewServer( address.Address("localhost:27017"), bson.NewObjectID(), + defaultConnectionTimeout, WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return tpm.PoolMonitor }), @@ -218,6 +220,7 @@ func TestServerConnectionTimeout(t *testing.T) { desc: "successful connection should not clear the pool", expectErr: false, expectPoolCleared: false, + connectTimeout: defaultConnectionTimeout, }, { desc: "timeout error during dialing should clear the pool", @@ -262,6 +265,7 @@ func TestServerConnectionTimeout(t *testing.T) { }, expectErr: true, expectPoolCleared: true, + connectTimeout: defaultConnectionTimeout, }, { desc: "operation context timeout with unrelated dial errors should clear the pool", @@ -300,15 +304,13 @@ func TestServerConnectionTimeout(t *testing.T) { server := NewServer( address.Address(l.Addr().String()), bson.NewObjectID(), + tc.connectTimeout, WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return tpm.PoolMonitor }), // Replace the default dialer and handshaker with the test dialer and handshaker, if // present. WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption { - if tc.connectTimeout > 0 { - opts = append(opts, WithConnectTimeout(func(time.Duration) time.Duration { return tc.connectTimeout })) - } if tc.dialer != nil { opts = append(opts, WithDialer(tc.dialer)) } @@ -381,6 +383,7 @@ func TestServer(t *testing.T) { s := NewServer( address.Address("localhost"), bson.NewObjectID(), + defaultConnectionTimeout, WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption { return append(connOpts, WithHandshaker(func(Handshaker) Handshaker { @@ -567,7 +570,13 @@ func TestServer(t *testing.T) { WithMaxConnecting(func(uint64) uint64 { return 1 }), } - server, err := ConnectServer(address.Address("localhost:27017"), nil, bson.NewObjectID(), serverOpts...) 
+ server, err := ConnectServer( + address.Address("localhost:27017"), + nil, + bson.NewObjectID(), + defaultConnectionTimeout, + serverOpts..., + ) assert.Nil(t, err, "ConnectServer error: %v", err) defer func() { _ = server.Disconnect(context.Background()) @@ -601,6 +610,7 @@ func TestServer(t *testing.T) { d := newdialer(&net.Dialer{}) s := NewServer(address.Address(addr.String()), bson.NewObjectID(), + defaultConnectionTimeout, WithConnectionOptions(func(option ...ConnectionOption) []ConnectionOption { return []ConnectionOption{WithDialer(func(_ Dialer) Dialer { return d })} }), @@ -648,7 +658,14 @@ func TestServer(t *testing.T) { updated.Store(true) return desc } - s, err := ConnectServer(address.Address("localhost"), updateCallback, bson.NewObjectID()) + + s, err := ConnectServer( + address.Address("localhost"), + updateCallback, + bson.NewObjectID(), + defaultConnectionTimeout, + ) + require.NoError(t, err) s.updateDescription(description.Server{Addr: s.address}) require.True(t, updated.Load().(bool)) @@ -663,10 +680,10 @@ func TestServer(t *testing.T) { return append(connOpts, dialerOpt) }) - s := NewServer(address.Address("localhost:27017"), bson.NewObjectID(), serverOpt) + s := NewServer(address.Address("localhost:27017"), bson.NewObjectID(), defaultConnectionTimeout, serverOpt) // do a heartbeat with a nil connection so a new one will be dialed - _, err := s.check() + _, err := s.check(context.Background()) assert.Nil(t, err, "check error: %v", err) assert.NotNil(t, s.conn, "no connection dialed in check") @@ -683,7 +700,7 @@ func TestServer(t *testing.T) { if err = channelConn.AddResponse(makeHelloReply()); err != nil { t.Fatalf("error adding response: %v", err) } - _, err = s.check() + _, err = s.check(context.Background()) assert.Nil(t, err, "check error: %v", err) wm = channelConn.GetWrittenMessage() @@ -727,10 +744,10 @@ func TestServer(t *testing.T) { WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return sdam }), } - s := NewServer(address.Address("localhost:27017"), bson.NewObjectID(), serverOpts...) + s := NewServer(address.Address("localhost:27017"), bson.NewObjectID(), defaultConnectionTimeout, serverOpts...) 
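// Aside (illustrative only; assumes package topology and its imports): the
// constructor signatures exercised by these tests now take the connect
// timeout as a required argument instead of WithConnectTimeout /
// WithHeartbeatTimeout options, so a single value flows to the pool, the RTT
// monitor, and heartbeat checks.

func newMonitoredServer(addr address.Address) (*Server, error) {
	const connectTimeout = 10 * time.Second

	return ConnectServer(
		addr,
		nil,                // updateTopologyCallback
		bson.NewObjectID(), // topology ID
		connectTimeout,
		WithHeartbeatInterval(func(time.Duration) time.Duration { return 10 * time.Second }),
	)
}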
// set up heartbeat connection, which doesn't send events - _, err := s.check() + _, err := s.check(context.Background()) assert.Nil(t, err, "check error: %v", err) channelConn := s.conn.nc.(*drivertest.ChannelNetConn) @@ -742,7 +759,7 @@ func TestServer(t *testing.T) { if err = channelConn.AddResponse(makeHelloReply()); err != nil { t.Fatalf("error adding response: %v", err) } - _, err = s.check() + _, err = s.check(context.Background()) _ = channelConn.GetWrittenMessage() assert.Nil(t, err, "check error: %v", err) @@ -764,7 +781,7 @@ func TestServer(t *testing.T) { // do a heartbeat with a non-nil connection readErr := errors.New("error") channelConn.ReadErr <- readErr - _, err = s.check() + _, err = s.check(context.Background()) _ = channelConn.GetWrittenMessage() assert.Nil(t, err, "check error: %v", err) @@ -787,65 +804,10 @@ func TestServer(t *testing.T) { s := NewServer(address.Address("localhost"), bson.NewObjectID(), + defaultConnectionTimeout, WithServerAppName(func(string) string { return name })) require.Equal(t, name, s.cfg.appname, "expected appname to be: %v, got: %v", name, s.cfg.appname) }) - t.Run("createConnection overwrites WithSocketTimeout", func(t *testing.T) { - socketTimeout := 40 * time.Second - - s := NewServer( - address.Address("localhost"), - bson.NewObjectID(), - WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption { - return append( - connOpts, - WithReadTimeout(func(time.Duration) time.Duration { return socketTimeout }), - WithWriteTimeout(func(time.Duration) time.Duration { return socketTimeout }), - ) - }), - ) - - conn := s.createConnection() - assert.Equal(t, s.cfg.heartbeatTimeout, 10*time.Second, "expected heartbeatTimeout to be: %v, got: %v", 10*time.Second, s.cfg.heartbeatTimeout) - assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout) - assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout) - }) - t.Run("heartbeat contexts are not leaked", func(t *testing.T) { - // The context created for heartbeats should be cancelled when it is no longer needed to avoid leaks. - - server, err := ConnectServer( - address.Address("invalid"), - nil, - bson.NewObjectID(), - withMonitoringDisabled(func(bool) bool { - return true - }), - ) - assert.Nil(t, err, "ConnectServer error: %v", err) - - // Expect check to return an error in the server description because the server address doesn't exist. This is - // OK because we just want to ensure the heartbeat context is created. - desc, err := server.check() - assert.Nil(t, err, "check error: %v", err) - assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") - assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil") - assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err()) - - // Override heartbeatCtxCancel with a wrapper that records whether or not it was called. - oldCancelFn := server.heartbeatCtxCancel - var previousCtxCancelled bool - server.heartbeatCtxCancel = func() { - previousCtxCancelled = true - oldCancelFn() - } - - // The second check call should attempt to create a new heartbeat connection and should cancel the previous - // heartbeatCtx during the process. 
- desc, err = server.check() - assert.Nil(t, err, "check error: %v", err) - assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") - assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not") - }) } func TestServer_ProcessError(t *testing.T) { @@ -1188,7 +1150,7 @@ func TestServer_ProcessError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - server := NewServer(address.Address(""), bson.NewObjectID()) + server := NewServer(address.Address(""), bson.NewObjectID(), defaultConnectionTimeout) server.state = serverConnected err := server.pool.ready() require.Nil(t, err, "pool.ready() error: %v", err) @@ -1213,6 +1175,82 @@ func TestServer_ProcessError(t *testing.T) { } } +func TestServer_getSocketTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + enableStreaming bool + connectTimeout time.Duration + heartbeatInterval time.Duration + want time.Duration + }{ + { + name: "server is streamable with connectTimeout and no heartbeat interval", + enableStreaming: true, + connectTimeout: 1, + heartbeatInterval: 0, + want: 1, + }, + { + name: "server is streamable with connectTimeout and heartbeat interval", + enableStreaming: true, + connectTimeout: 1, + heartbeatInterval: 1, + want: 2, + }, + { + name: "server is streamable with no connectTimeout and heartbeat interval", + enableStreaming: true, + connectTimeout: 0, + heartbeatInterval: 1, + want: 0, + }, + { + name: "server is streamable with no connectTimeout and no heartbeat interval", + enableStreaming: true, + connectTimeout: 0, + heartbeatInterval: 0, + want: 0, + }, + { + name: "server is not streamable", + enableStreaming: false, + connectTimeout: 1, + heartbeatInterval: 0, + want: 1, + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + srv := &Server{ + cfg: &serverConfig{ + connectTimeout: test.connectTimeout, + heartbeatInterval: test.heartbeatInterval, + }, + conn: &connection{}, + } + + srv.desc.Store(description.Server{ + Kind: description.ServerKind(description.TopologyKindReplicaSet), + TopologyVersion: &description.TopologyVersion{}, + }) + + if test.enableStreaming { + srv.cfg.serverMonitoringMode = connstring.ServerMonitoringModeStream + } + + got := getHeartbeatTimeout(srv) + assert.Equal(t, test.want, got) + }) + } +} + // includesClientMetadata will return true if the wire message includes the // "client" field. 
 func includesClientMetadata(t *testing.T, wm []byte) bool {
@@ -1303,3 +1341,46 @@ func newServerDescription(
         LastError: lastError,
     }
 }
+
+type mockServerChecker struct {
+    sleep time.Duration
+}
+
+var _ serverChecker = &mockServerChecker{}
+
+func (checker *mockServerChecker) check(ctx context.Context) (description.Server, error) {
+    select {
+    case <-ctx.Done():
+        return description.Server{}, ctx.Err()
+    case <-time.After(checker.sleep):
+    }
+
+    return description.Server{}, nil
+}
+
+func TestCheckServerWithSignal(t *testing.T) {
+    t.Run("check finishes before signal", func(t *testing.T) {
+        listener := newNonBlockingContextDoneListener()
+        go func() {
+            defer listener.StopListening()
+
+            time.Sleep(105 * time.Millisecond)
+        }()
+
+        _, err := checkServerWithSignal(&mockServerChecker{sleep: 100 * time.Millisecond}, &connection{}, listener)
+        assert.NoError(t, err)
+    })
+
+    t.Run("check finishes after signal", func(t *testing.T) {
+        listener := newNonBlockingContextDoneListener()
+        go func() {
+            defer listener.StopListening()
+
+            time.Sleep(100 * time.Millisecond)
+        }()
+
+        _, err := checkServerWithSignal(&mockServerChecker{sleep: 1 * time.Second}, &connection{}, listener)
+        assert.Error(t, err)
+        assert.ErrorIs(t, err, context.Canceled)
+    })
+}
diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go
index d9e9de1f50b..60077cea856 100644
--- a/x/mongo/driver/topology/topology.go
+++ b/x/mongo/driver/topology/topology.go
@@ -63,10 +63,6 @@ var ErrTopologyClosed = errors.New("topology is closed")
 // already connected Topology.
 var ErrTopologyConnected = errors.New("topology is connected or connecting")

-// ErrServerSelectionTimeout is returned from server selection when the server
-// selection process took longer than allowed by the timeout.
-var ErrServerSelectionTimeout = errors.New("server selection timeout")
-
 // MonitorMode represents the way in which a server is monitored.
 type MonitorMode uint8
@@ -126,18 +122,6 @@ var (
     _ driver.Subscriber = &Topology{}
 )

-type serverSelectionState struct {
-    selector    description.ServerSelector
-    timeoutChan <-chan time.Time
-}
-
-func newServerSelectionState(selector description.ServerSelector, timeoutChan <-chan time.Time) serverSelectionState {
-    return serverSelectionState{
-        selector:    selector,
-        timeoutChan: timeoutChan,
-    }
-}
-
 // New creates a new topology. A "nil" config is interpreted as the default configuration.
 func New(cfg *Config) (*Topology, error) {
     if cfg == nil {
@@ -503,9 +487,8 @@ func (t *Topology) RequestImmediateCheck() {
     t.serversLock.Unlock()
 }

-// SelectServer selects a server with given a selector. SelectServer complies with the
-// server selection spec, and will time out after serverSelectionTimeout or when the
-// parent context is done.
+// SelectServer selects a server given a selector, returning the remaining
+// computedServerSelectionTimeout.
 func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) {
     if atomic.LoadInt64(&t.state) != topologyConnected {
         if mustLogServerSelection(t, logger.LevelDebug) {
@@ -514,17 +497,9 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect
         return nil, ErrTopologyClosed
     }

-    var ssTimeoutCh <-chan time.Time
-
-    if t.cfg.ServerSelectionTimeout > 0 {
-        ssTimeout := time.NewTimer(t.cfg.ServerSelectionTimeout)
-        ssTimeoutCh = ssTimeout.C
-        defer ssTimeout.Stop()
-    }

     var doneOnce bool
     var sub *driver.Subscription
-    selectionState := newServerSelectionState(ss, ssTimeoutCh)

     // Record the start time.
     startTime := time.Now()
@@ -539,7 +514,7 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect
             // for the first pass, select a server from the current description.
             // this improves selection speed for up-to-date topology descriptions.
-            suitable, selectErr = t.selectServerFromDescription(t.Description(), selectionState)
+            suitable, selectErr = t.selectServerFromDescription(t.Description(), ss)
             doneOnce = true
         } else {
             // if the first pass didn't select a server, the previous description did not contain a suitable server, so
@@ -557,7 +532,7 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect
                 defer func() { _ = t.Unsubscribe(sub) }()
             }

-            suitable, selectErr = t.selectServerFromSubscription(ctx, sub.Updates, selectionState)
+            suitable, selectErr = t.selectServerFromSubscription(ctx, sub.Updates, ss)
         }
         if selectErr != nil {
             if mustLogServerSelection(t, logger.LevelDebug) {
@@ -704,20 +679,22 @@ func (t *Topology) FindServer(selected description.Server) (*SelectedServer, err
 // selectServerFromSubscription loops until a topology description is available for server selection. It returns
 // when the given context expires, server selection timeout is reached, or a description containing a selectable
 // server is available.
-func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptionCh <-chan description.Topology,
-    selectionState serverSelectionState) ([]description.Server, error) {
+func (t *Topology) selectServerFromSubscription(
+    ctx context.Context,
+    subscriptionCh <-chan description.Topology,
+    srvSelector description.ServerSelector,
+) ([]description.Server, error) {
     current := t.Description()
     for {
         select {
         case <-ctx.Done():
             return nil, ServerSelectionError{Wrapped: ctx.Err(), Desc: current}
-        case <-selectionState.timeoutChan:
-            return nil, ServerSelectionError{Wrapped: ErrServerSelectionTimeout, Desc: current}
         case current = <-subscriptionCh:
+        default:
         }

-        suitable, err := t.selectServerFromDescription(current, selectionState)
+        suitable, err := t.selectServerFromDescription(current, srvSelector)
         if err != nil {
             return nil, err
         }
@@ -730,8 +707,10 @@ func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptio
 }

 // selectServerFromDescription process the given topology description and returns a slice of suitable servers.
-func (t *Topology) selectServerFromDescription(desc description.Topology,
-    selectionState serverSelectionState) ([]description.Server, error) {
+func (t *Topology) selectServerFromDescription(
+    desc description.Topology,
+    srvSelector description.ServerSelector,
+) ([]description.Server, error) {
     // Unlike selectServerFromSubscription, this code path does not check ctx.Done or selectionState.timeoutChan because
     // selecting a server from a description is not a blocking operation.
@@ -759,7 +738,7 @@ func (t *Topology) selectServerFromDescription(desc description.Topology,
         allowed[i] = desc.Servers[idx]
     }

-    suitable, err := selectionState.selector.SelectServer(desc, allowed)
+    suitable, err := srvSelector.SelectServer(desc, allowed)
     if err != nil {
         return nil, ServerSelectionError{Wrapped: err, Desc: desc}
     }
@@ -769,7 +748,7 @@ func (t *Topology) selectServerFromDescription(desc description.Topology,
 func (t *Topology) pollSRVRecords(hosts string) {
     defer t.pollingwg.Done()

-    serverConfig := newServerConfig(t.cfg.ServerOpts...)
+    serverConfig := newServerConfig(t.cfg.ConnectTimeout, t.cfg.ServerOpts...)
     heartbeatInterval := serverConfig.heartbeatInterval

     pollTicker := time.NewTicker(t.rescanSRVInterval)
@@ -992,7 +971,7 @@ func (t *Topology) addServer(addr address.Address) error {
         return nil
     }

-    svr, err := ConnectServer(addr, t.updateCallback, t.id, t.cfg.ServerOpts...)
+    svr, err := ConnectServer(addr, t.updateCallback, t.id, t.cfg.ConnectTimeout, t.cfg.ServerOpts...)
     if err != nil {
         return err
     }
@@ -1104,6 +1083,16 @@ func (t *Topology) publishTopologyClosedEvent() {
     }
 }

+// GetServerSelectionTimeout returns the server selection timeout defined on
+// the client options.
+func (t *Topology) GetServerSelectionTimeout() time.Duration {
+    if t.cfg == nil {
+        return 0
+    }
+
+    return t.cfg.ServerSelectionTimeout
+}
+
 func newEventServerDescription(srv description.Server) event.ServerDescription {
     evtSrv := event.ServerDescription{
         Addr: srv.Addr,
diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go
index c959fe5cf92..612735bd3b0 100644
--- a/x/mongo/driver/topology/topology_errors_test.go
+++ b/x/mongo/driver/topology/topology_errors_test.go
@@ -47,15 +47,20 @@ func TestTopologyErrors(t *testing.T) {
         assert.Nil(t, err, "error creating topology: %v", err)

         var serverSelectionErr error
-        callback := func(ctx context.Context) {
-            selectServerCtx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
+        callback := func() bool {
+            selectServerCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
             defer cancel()

-            state := newServerSelectionState(selectNone, make(<-chan time.Time))
             subCh := make(<-chan description.Topology)
-            _, serverSelectionErr = topo.selectServerFromSubscription(selectServerCtx, subCh, state)
+            _, serverSelectionErr = topo.selectServerFromSubscription(selectServerCtx, subCh, selectNone)
+            return true
         }
-        assert.Soon(t, callback, 150*time.Millisecond)
+        assert.Eventually(t,
+            callback,
+            150*time.Millisecond,
+            time.Millisecond,
+            "expected context deadline to fail within 150ms")
+
         assert.True(t, errors.Is(serverSelectionErr, context.DeadlineExceeded),
             "expected %v, received %v", context.DeadlineExceeded, serverSelectionErr)
     })
diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go
index 705ec3f7e15..dd35ec80fbd 100644
--- a/x/mongo/driver/topology/topology_options.go
+++ b/x/mongo/driver/topology/topology_options.go
@@ -24,6 +24,7 @@ import (
 )

 const defaultServerSelectionTimeout = 30 * time.Second
+const defaultConnectionTimeout = 30 * time.Second

 // Config is used to construct a topology.
 type Config struct {
@@ -32,6 +33,8 @@ type Config struct {
     SeedList               []string
     ServerOpts             []ServerOption
     URI                    string
+    ConnectTimeout         time.Duration
+    Timeout                *time.Duration
     ServerSelectionTimeout time.Duration
     ServerMonitor          *event.ServerMonitor
     SRVMaxHosts            int
@@ -82,11 +85,16 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config,
     var connOpts []ConnectionOption
     var serverOpts []ServerOption

-    cfgp := &Config{}
+    cfgp := &Config{
+        Timeout: co.Timeout,
+    }

     // Set the default "ServerSelectionTimeout" to 30 seconds.
     cfgp.ServerSelectionTimeout = defaultServerSelectionTimeout

+    // Set the default "ConnectTimeout" to 30 seconds.
+    cfgp.ConnectTimeout = defaultConnectionTimeout
+
     // Set the default "SeedList" to localhost.
     cfgp.SeedList = []string{"localhost:27017"}

@@ -204,15 +212,7 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config,
         }
     }
     connOpts = append(connOpts, WithHandshaker(handshaker))
-    // ConnectTimeout
-    if co.ConnectTimeout != nil {
-        serverOpts = append(serverOpts, WithHeartbeatTimeout(
-            func(time.Duration) time.Duration { return *co.ConnectTimeout },
-        ))
-        connOpts = append(connOpts, WithConnectTimeout(
-            func(time.Duration) time.Duration { return *co.ConnectTimeout },
-        ))
-    }
+
     // Dialer
     if co.Dialer != nil {
         connOpts = append(connOpts, WithDialer(
@@ -292,13 +292,9 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config,
     if co.ServerSelectionTimeout != nil {
         cfgp.ServerSelectionTimeout = *co.ServerSelectionTimeout
     }
-    // SocketTimeout
-    if co.SocketTimeout != nil {
-        connOpts = append(
-            connOpts,
-            WithReadTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }),
-            WithWriteTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }),
-        )
+    // ConnectTimeout
+    if co.ConnectTimeout != nil {
+        cfgp.ConnectTimeout = *co.ConnectTimeout
     }
     // TLSConfig
     if co.TLSConfig != nil {
diff --git a/x/mongo/driver/topology/topology_options_test.go b/x/mongo/driver/topology/topology_options_test.go
index e57c75bcb00..1b6140f5dca 100644
--- a/x/mongo/driver/topology/topology_options_test.go
+++ b/x/mongo/driver/topology/topology_options_test.go
@@ -73,7 +73,7 @@ func TestLoadBalancedFromConnString(t *testing.T) {
             assert.Nil(t, err, "topology.New error: %v", err)
             assert.Equal(t, tc.loadBalanced, topo.cfg.LoadBalanced, "expected loadBalanced %v, got %v", tc.loadBalanced, topo.cfg.LoadBalanced)

-            srvr := NewServer("", topo.id, topo.cfg.ServerOpts...)
+            srvr := NewServer("", topo.id, defaultConnectionTimeout, topo.cfg.ServerOpts...)
             assert.Equal(t, tc.loadBalanced, srvr.cfg.loadBalanced, "expected loadBalanced %v, got %v", tc.loadBalanced, srvr.cfg.loadBalanced)

             conn := newConnection("", srvr.cfg.connectionOpts...)
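The topology_options.go hunks above replace the old per-connection ConnectTimeout/SocketTimeout wiring with a single Config.ConnectTimeout that defaults to 30 seconds, is overridden from ClientOptions.ConnectTimeout, and is then passed positionally to NewServer/ConnectServer. A minimal caller-side sketch of that behavior, using only the exported NewConfig and Config shown above and assuming a nil cluster clock is acceptable (illustrative, not part of the patch):

package main

import (
	"fmt"
	"time"

	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)

func main() {
	// ConnectTimeout set on the client options is copied onto Config.ConnectTimeout;
	// when it is left unset, the 30-second defaultConnectionTimeout applies instead.
	opts := options.Client().
		ApplyURI("mongodb://localhost:27017").
		SetConnectTimeout(5 * time.Second)

	cfg, err := topology.NewConfig(opts, nil)
	if err != nil {
		panic(err)
	}

	fmt.Println(cfg.ConnectTimeout) // 5s rather than the 30s default
}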
diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go
index 937824d4dd5..0e4920d88bf 100644
--- a/x/mongo/driver/topology/topology_test.go
+++ b/x/mongo/driver/topology/topology_test.go
@@ -9,7 +9,6 @@ package topology
 import (
     "context"
     "encoding/json"
-    "errors"
     "fmt"
     "io/ioutil"
     "path"
@@ -25,7 +24,6 @@ import (
     "go.mongodb.org/mongo-driver/mongo/address"
     "go.mongodb.org/mongo-driver/mongo/options"
     "go.mongodb.org/mongo-driver/mongo/readpref"
-    "go.mongodb.org/mongo-driver/x/mongo/driver"
     "go.mongodb.org/mongo-driver/x/mongo/driver/description"
 )
@@ -65,10 +63,6 @@ func TestServerSelection(t *testing.T) {
     var selectNone serverselector.Func = func(description.Topology, []description.Server) ([]description.Server, error) {
         return []description.Server{}, nil
     }
-    var errSelectionError = errors.New("encountered an error in the selector")
-    var selectError serverselector.Func = func(description.Topology, []description.Server) ([]description.Server, error) {
-        return nil, errSelectionError
-    }

     t.Run("Success", func(t *testing.T) {
         topo, err := New(nil)
@@ -83,8 +77,7 @@ func TestServerSelection(t *testing.T) {
         subCh := make(chan description.Topology, 1)
         subCh <- desc

-        state := newServerSelectionState(selectFirst, nil)
-        srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
+        srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, selectFirst)
         noerr(t, err)
         if len(srvs) != 1 {
             t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1)
@@ -148,8 +141,7 @@ func TestServerSelection(t *testing.T) {
         resp := make(chan []description.Server)
         go func() {
-            state := newServerSelectionState(selectFirst, nil)
-            srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
+            srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, selectFirst)
             noerr(t, err)
             resp <- srvs
         }()
@@ -196,8 +188,7 @@ func TestServerSelection(t *testing.T) {
         resp := make(chan error)
         ctx, cancel := context.WithCancel(context.Background())
         go func() {
-            state := newServerSelectionState(selectNone, nil)
-            _, err := topo.selectServerFromSubscription(ctx, subCh, state)
+            _, err := topo.selectServerFromSubscription(ctx, subCh, selectNone)
             resp <- err
         }()
@@ -218,77 +209,11 @@ func TestServerSelection(t *testing.T) {
         want := ServerSelectionError{Wrapped: context.Canceled, Desc: desc}
         assert.Equal(t, err, want, "Incorrect error received. got %v; want %v", err, want)
     })
-    t.Run("Timeout", func(t *testing.T) {
-        desc := description.Topology{
-            Servers: []description.Server{
-                {Addr: address.Address("one"), Kind: description.ServerKindStandalone},
-                {Addr: address.Address("two"), Kind: description.ServerKindStandalone},
-                {Addr: address.Address("three"), Kind: description.ServerKindStandalone},
-            },
-        }
-        topo, err := New(nil)
-        noerr(t, err)
-        subCh := make(chan description.Topology, 1)
-        subCh <- desc
-        resp := make(chan error)
-        timeout := make(chan time.Time)
-        go func() {
-            state := newServerSelectionState(selectNone, timeout)
-            _, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
-            resp <- err
-        }()
-
-        select {
-        case err := <-resp:
-            t.Errorf("Received error from server selection too soon: %v", err)
-        case timeout <- time.Now():
-        }
-
-        select {
-        case err = <-resp:
-        case <-time.After(100 * time.Millisecond):
-            t.Errorf("Timed out while trying to retrieve selected servers")
-        }
-
-        if err == nil {
-            t.Fatalf("did not receive error from server selection")
-        }
-    })
-    t.Run("Error", func(t *testing.T) {
-        desc := description.Topology{
-            Servers: []description.Server{
-                {Addr: address.Address("one"), Kind: description.ServerKindStandalone},
-                {Addr: address.Address("two"), Kind: description.ServerKindStandalone},
-                {Addr: address.Address("three"), Kind: description.ServerKindStandalone},
-            },
-        }
-        topo, err := New(nil)
-        noerr(t, err)
-        subCh := make(chan description.Topology, 1)
-        subCh <- desc
-        resp := make(chan error)
-        timeout := make(chan time.Time)
-        go func() {
-            state := newServerSelectionState(selectError, timeout)
-            _, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
-            resp <- err
-        }()
-
-        select {
-        case err = <-resp:
-        case <-time.After(100 * time.Millisecond):
-            t.Errorf("Timed out while trying to retrieve selected servers")
-        }
-
-        if err == nil {
-            t.Fatalf("did not receive error from server selection")
-        }
-    })
     t.Run("findServer returns topology kind", func(t *testing.T) {
         topo, err := New(nil)
         noerr(t, err)
         atomic.StoreInt64(&topo.state, topologyConnected)
-        srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id)
+        srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id, defaultConnectionTimeout)
         noerr(t, err)
         topo.servers[address.Address("one")] = srvr
         desc := topo.desc.Load().(description.Topology)
@@ -303,71 +228,6 @@ func TestServerSelection(t *testing.T) {
             t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.TopologyKindSingle)
         }
     })
-    t.Run("Update on not primary error", func(t *testing.T) {
-        topo, err := New(nil)
-        noerr(t, err)
-        atomic.StoreInt64(&topo.state, topologyConnected)
-
-        addr1 := address.Address("one")
-        addr2 := address.Address("two")
-        addr3 := address.Address("three")
-        desc := description.Topology{
-            Servers: []description.Server{
-                {Addr: addr1, Kind: description.ServerKindRSPrimary},
-                {Addr: addr2, Kind: description.ServerKindRSSecondary},
-                {Addr: addr3, Kind: description.ServerKindRSSecondary},
-            },
-        }
-
-        // manually add the servers to the topology
-        for _, srv := range desc.Servers {
-            s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id)
-            noerr(t, err)
-            topo.servers[srv.Addr] = s
-        }
-
-        // Send updated description
-        desc = description.Topology{
-            Servers: []description.Server{
-                {Addr: addr1, Kind: description.ServerKindRSSecondary},
-                {Addr: addr2, Kind: description.ServerKindRSPrimary},
-                {Addr: addr3, Kind: description.ServerKindRSSecondary},
-            },
-        }
-
-        subCh := make(chan description.Topology, 1)
-        subCh <- desc
-
-        // send a not primary error to the server forcing an update
-        serv, err := topo.FindServer(desc.Servers[0])
-        noerr(t, err)
-        atomic.StoreInt64(&serv.state, serverConnected)
-        _ = serv.ProcessError(driver.Error{Message: driver.LegacyNotPrimaryErrMsg}, initConnection{})
-
-        resp := make(chan []description.Server)
-
-        go func() {
-            // server selection should discover the new topology
-            state := newServerSelectionState(&serverselector.Write{}, nil)
-            srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
-            noerr(t, err)
-            resp <- srvs
-        }()
-
-        var srvs []description.Server
-        select {
-        case srvs = <-resp:
-        case <-time.After(100 * time.Millisecond):
-            t.Errorf("Timed out while trying to retrieve selected servers")
-        }
-
-        if len(srvs) != 1 {
-            t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1)
-        }
-        if srvs[0].Addr != desc.Servers[1].Addr {
-            t.Errorf("Incorrect sever selected. got %s; want %s", srvs[0].Addr, desc.Servers[1].Addr)
-        }
-    })
     t.Run("fast path does not subscribe or check timeouts", func(t *testing.T) {
         // Assert that the server selection fast path does not create a Subscription or check for timeout errors.
         topo, err := New(nil)
@@ -382,7 +242,7 @@ func TestServerSelection(t *testing.T) {
         }
         topo.desc.Store(desc)
         for _, srv := range desc.Servers {
-            s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id)
+            s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id, defaultConnectionTimeout)
             noerr(t, err)
             topo.servers[srv.Addr] = s
         }
@@ -983,6 +843,7 @@ func runInWindowTest(t *testing.T, directory string, filename string) {
         server := NewServer(
             address.Address(testDesc.Address),
             bson.NilObjectID,
+            defaultConnectionTimeout,
             withMonitoringDisabled(func(bool) bool { return true }))

         servers[testDesc.Address] = server
@@ -1176,13 +1037,12 @@ func BenchmarkSelectServerFromDescription(b *testing.B) {
                 Servers: servers,
             }

-            timeout := make(chan time.Time)
             b.ResetTimer()
             b.RunParallel(func(p *testing.PB) {
                 b.ReportAllocs()
                 for p.Next() {
                     var c Topology
-                    _, _ = c.selectServerFromDescription(desc, newServerSelectionState(selectNone, timeout))
+                    _, _ = c.selectServerFromDescription(desc, selectNone)
                 }
             })
         })
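With serverSelectionState, ErrServerSelectionTimeout, and the internal timer removed, the subscription loop above only honors the caller's context, and the new GetServerSelectionTimeout accessor exposes the configured timeout so the caller can build that context itself. A sketch of that caller-side wiring; the helper name and the pass-through selector are illustrative, and the description.ServerSelector method set is inferred from the srvSelector.SelectServer call in the hunk above:

package example

import (
	"context"

	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)

// acceptAll is an illustrative selector that keeps every candidate server.
type acceptAll struct{}

func (acceptAll) SelectServer(_ description.Topology, candidates []description.Server) ([]description.Server, error) {
	return candidates, nil
}

// selectWithConfiguredTimeout shows how a caller might reapply the server
// selection timeout now that SelectServer no longer arms its own timer: derive
// a bounded context from GetServerSelectionTimeout and pass it to SelectServer.
func selectWithConfiguredTimeout(ctx context.Context, topo *topology.Topology) (driver.Server, error) {
	if sst := topo.GetServerSelectionTimeout(); sst > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, sst)
		defer cancel()
	}

	return topo.SelectServer(ctx, acceptAll{})
}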