diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 21b7395a6e2..4e9b92bdef3 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -350,6 +350,23 @@ functions: chmod +x $i done + assume-ec2-role: + - command: ec2.assume_role + params: + role_arn: ${aws_test_secrets_role} + + run-oidc-auth-test-with-test-credentials: + - command: shell.exec + type: test + params: + working_dir: src/go.mongodb.org/mongo-driver + shell: bash + include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + script: | + ${PREPARE_SHELL} + export OIDC="oidc" + bash ${PROJECT_DIRECTORY}/etc/run-oidc-test.sh + run-make: - command: shell.exec type: test @@ -1949,6 +1966,10 @@ tasks: popd ./.evergreen/run-deployed-lambda-aws-tests.sh + - name: "oidc-auth-test-latest" + commands: + - func: "run-oidc-auth-test-with-test-credentials" + - name: "test-search-index" commands: - func: "bootstrap-mongo-orchestration" @@ -2231,6 +2252,31 @@ task_groups: tasks: - testazurekms-task + - name: testoidc_task_group + setup_group: + - func: fetch-source + - func: prepare-resources + - func: fix-absolute-paths + - func: make-files-executable + - func: assume-ec2-role + - command: shell.exec + params: + shell: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + script: | + ${PREPARE_SHELL} + ${DRIVERS_TOOLS}/.evergreen/auth_oidc/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-latest + - name: test-aws-lambda-task-group setup_group: - func: fetch-source @@ -2564,3 +2610,13 @@ buildvariants: - name: testazurekms_task_group batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README - testazurekms-fail-task + + - name: testoidc-variant + display_name: "OIDC" + run_on: + - ubuntu2204-large + expansions: + GO_DIST: "/opt/golang/go1.20" + tasks: + - name: testoidc_task_group + batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README diff --git a/Makefile b/Makefile index 8ebcde8439f..33001650db1 100644 --- a/Makefile +++ b/Makefile @@ -127,6 +127,11 @@ evg-test-atlas-data-lake: evg-test-enterprise-auth: go run -tags gssapi ./internal/cmd/testentauth/main.go +.PHONY: evg-test-oidc-auth +evg-test-oidc-auth: + go run ./internal/cmd/testoidcauth/main.go + go run -race ./internal/cmd/testoidcauth/main.go + .PHONY: evg-test-kmip evg-test-kmip: go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH) DYLD_LIBRARY_PATH=$(MACOS_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./internal/integration -run TestClientSideEncryptionSpec/kmipKMS >> test.suite diff --git a/etc/run-oidc-test.sh b/etc/run-oidc-test.sh new file mode 100644 index 00000000000..bc5eb997587 --- /dev/null +++ b/etc/run-oidc-test.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# run-oidc-test +# Runs oidc auth tests. +set -eu + +echo "Running MONGODB-OIDC authentication tests" + +OIDC_ENV="${OIDC_ENV:-"test"}" + +if [ $OIDC_ENV == "test" ]; then + # Make sure DRIVERS_TOOLS is set. + if [ -z "$DRIVERS_TOOLS" ]; then + echo "Must specify DRIVERS_TOOLS" + exit 1 + fi + source ${DRIVERS_TOOLS}/.evergreen/auth_oidc/secrets-export.sh + +elif [ $OIDC_ENV == "azure" ]; then + source ./env.sh + +elif [ $OIDC_ENV == "gcp" ]; then + source ./secrets-export.sh + +else + echo "Unrecognized OIDC_ENV $OIDC_ENV" + exit 1 +fi + +export TEST_AUTH_OIDC=1 +export COVERAGE=1 +export AUTH="auth" + +make -s evg-test-oidc-auth diff --git a/internal/cmd/testoidcauth/main.go b/internal/cmd/testoidcauth/main.go new file mode 100644 index 00000000000..848ce06ca69 --- /dev/null +++ b/internal/cmd/testoidcauth/main.go @@ -0,0 +1,693 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "context" + "fmt" + "log" + "os" + "path" + "reflect" + "sync" + "time" + "unsafe" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth" +) + +var uriAdmin = os.Getenv("MONGODB_URI") +var uriSingle = os.Getenv("MONGODB_URI_SINGLE") + +// var uriMulti = os.Getenv("MONGODB_URI_MULTI") +var oidcTokenDir = os.Getenv("OIDC_TOKEN_DIR") + +//var oidcDomain = os.Getenv("OIDC_DOMAIN") + +//func explicitUser(user string) string { +// return fmt.Sprintf("%s@%s", user, oidcDomain) +//} + +func tokenFile(user string) string { + return path.Join(oidcTokenDir, user) +} + +func connectAdminClinet() (*mongo.Client, error) { + return mongo.Connect(options.Client().ApplyURI(uriAdmin)) +} + +func connectWithMachineCB(uri string, cb options.OIDCCallback) (*mongo.Client, error) { + cred := options.Credential{ + AuthMechanism: "MONGODB-OIDC", + OIDCMachineCallback: cb, + } + optsBuilder := options.Client().ApplyURI(uri).SetAuth(cred) + return mongo.Connect(optsBuilder) +} + +func connectWithMachineCBAndProperties(uri string, cb options.OIDCCallback, props map[string]string) (*mongo.Client, error) { + cred := options.Credential{ + AuthMechanism: "MONGODB-OIDC", + OIDCMachineCallback: cb, + AuthMechanismProperties: props, + } + optsBuilder := options.Client().ApplyURI(uri).SetAuth(cred) + return mongo.Connect(optsBuilder) +} + +func main() { + // be quiet linter + _ = tokenFile("test_user2") + + hasError := false + aux := func(test_name string, f func() error) { + fmt.Printf("%s...", test_name) + err := f() + if err != nil { + fmt.Println("Test Error: ", err) + fmt.Println("...Failed") + hasError = true + } else { + fmt.Println("...Ok") + } + } + aux("machine_1_1_callbackIsCalled", machine11callbackIsCalled) + aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine12callbackIsCalledOnlyOneForMultipleConnections) + aux("machine_2_1_validCallbackInputs", machine21validCallbackInputs) + aux("machine_2_3_oidcCallbackReturnMissingData", machine23oidcCallbackReturnMissingData) + aux("machine_2_4_invalidClientConfigurationWithCallback", machine24invalidClientConfigurationWithCallback) + aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine31failureWithCachedTokensFetchANewTokenAndRetryAuth) + aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine32authFailuresWithoutCachedTokensReturnsAnError) + aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) + aux("machine_4_1_reauthenticationSucceeds", machine41ReauthenticationSucceeds) + aux("machine_4_2_readCommandsFailIfReauthenticationFails", machine42ReadCommandsFailIfReauthenticationFails) + aux("machine_4_3_writeCommandsFailIfReauthenticationFails", machine43WriteCommandsFailIfReauthenticationFails) + if hasError { + log.Fatal("One or more tests failed") + } +} + +func machine11callbackIsCalled() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_1_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_1_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_1_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine12callbackIsCalledOnlyOneForMultipleConnections() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_1_2: failed connecting client: %v", err) + } + + var wg sync.WaitGroup + + var findFailed error + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + coll := client.Database("test").Collection("test") + _, err := coll.Find(context.Background(), bson.D{}) + if err != nil { + findFailed = fmt.Errorf("machine_1_2: failed executing Find: %v", err) + } + }() + } + + wg.Wait() + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_1_2: expected callback count to be 1, got %d", callbackCount) + } + if callbackFailed != nil { + return callbackFailed + } + return findFailed +} + +func machine21validCallbackInputs() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + if args.RefreshToken != nil { + callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v", args.RefreshToken) + } + timeout, ok := ctx.Deadline() + if !ok { + callbackFailed = fmt.Errorf("machine_2_1: expected context to have deadline, got %v", ctx) + } + if timeout.Before(time.Now()) { + callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v", timeout) + } + if args.Version < 1 { + callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d", args.Version) + } + if args.IDPInfo != nil { + callbackFailed = fmt.Errorf("machine_2_1: expected IdpID to be nil for Machine flow, got %v", args.IDPInfo) + } + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + fmt.Printf("machine_2_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_2_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_2_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_2_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine23oidcCallbackReturnMissingData() error { + callbackCount := 0 + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_2_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_2_3: should have failed to executed Find, but succeeded") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_2_3: expected callback count to be 1, got %d", callbackCount) + } + return nil +} + +func machine24invalidClientConfigurationWithCallback() error { + _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }, + map[string]string{"ENVIRONMENT": "test"}, + ) + if err == nil { + return fmt.Errorf("machine_2_4: succeeded building client when it should fail") + } + return nil +} + +func machine31failureWithCachedTokensFetchANewTokenAndRetryAuth() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_3_1: failed connecting client: %v", err) + } + + // Poison the cache with a random token + clientElem := reflect.ValueOf(client).Elem() + authenticatorField := clientElem.FieldByName("authenticator") + authenticatorField = reflect.NewAt( + authenticatorField.Type(), + unsafe.Pointer(authenticatorField.UnsafeAddr())).Elem() + // this is the only usage of the x packages in the test, showing the the public interface is + // correct. + authenticatorField.Interface().(*auth.OIDCAuthenticator).SetAccessToken("some random happy sunshine string") + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine32authFailuresWithoutCachedTokensReturnsAnError() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_3_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_2: Find ucceeded when it should fail") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_2: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + + defer func() { _ = adminClient.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_3_3: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_3: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_3_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "saslStart", + }}, + {Key: "errorCode", Value: 20}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_3_3: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_3: Find succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_3: failed executing Find: %v", err) + } + if callbackCount != 1 { + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine41ReauthenticationSucceeds() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer func() { _ = adminClient.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_4_1: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_4_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_1: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_1: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func machine42ReadCommandsFailIfReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + firstCall := true + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer func() { _ = adminClient.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_4_2: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if firstCall { + firstCall = false + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_4_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_2: failed executing Find: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_2: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_4_2: Find succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_2: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func machine43WriteCommandsFailIfReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + firstCall := true + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer func() { _ = adminClient.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_4_3: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if firstCall { + firstCall = false + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_3: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer func() { _ = client.Disconnect(context.Background()) }() + + if err != nil { + return fmt.Errorf("machine_4_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.InsertOne(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_3: failed executing Insert: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "insert", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_3: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.InsertOne(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_4_3: Insert succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_3: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} diff --git a/internal/integration/collection_test.go b/internal/integration/collection_test.go index eef39997934..79613bd4213 100644 --- a/internal/integration/collection_test.go +++ b/internal/integration/collection_test.go @@ -1582,8 +1582,7 @@ func TestCollection(t *testing.T) { _, err := mt.Coll.Indexes().CreateOne(context.TODO(), indexModel) assert.NoError(mt, err, "failed to create index") - err = mt.Coll.Indexes().DropOne(context.Background(), "username_1") - assert.NoError(mt, err) + _ = mt.Coll.Indexes().DropOne(context.Background(), "username_1") }) }) diff --git a/internal/integration/mtest/opmsg_deployment.go b/internal/integration/mtest/opmsg_deployment.go index 8b521b19f48..31220918f68 100644 --- a/internal/integration/mtest/opmsg_deployment.go +++ b/internal/integration/mtest/opmsg_deployment.go @@ -60,6 +60,13 @@ func (c *connection) Write(context.Context, []byte) error { return nil } +func (c *connection) OIDCTokenGenID() uint64 { + return 0 +} + +func (c *connection) SetOIDCTokenGenID(uint64) { +} + // Read returns the next response in the connection's list of responses. func (c *connection) Read(_ context.Context) ([]byte, error) { var dst []byte diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index b62675151a6..26ad74f0162 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -189,7 +189,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE). ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). - Logger(bw.collection.client.logger) + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -259,7 +259,7 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). - Logger(bw.collection.client.logger) + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -390,7 +390,8 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI). - Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger) + Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger). + Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { diff --git a/mongo/change_stream.go b/mongo/change_stream.go index ead95544bde..bde1ebc800e 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -173,7 +173,8 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in ReadPreference(config.readPreference).ReadConcern(config.readConcern). Deployment(cs.client.deployment).ClusterClock(cs.client.clock). CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone). - ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout) + ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout). + Authenticator(cs.client.authenticator) if cs.options.Collation != nil { cs.aggregate.Collation(bsoncore.Document(toDocument(cs.options.Collation))) diff --git a/mongo/client.go b/mongo/client.go index 405a5f2714f..04ebcb4eb27 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -27,6 +27,7 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt" mcopts "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt/options" @@ -81,6 +82,7 @@ type Client struct { metadataClientFLE *Client internalClientFLE *Client encryptedFieldsMap map[string]interface{} + authenticator driver.Authenticator } // Connect creates a new Client and then initializes it using the Connect method. @@ -212,7 +214,39 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) { args.MaxPoolSize = &defaultMaxPoolSize } - cfg, err := topology.NewConfigFromOptions(args, client.clock) + if args.Auth != nil { + var oidcMachineCallback auth.OIDCCallback + if args.Auth.OIDCMachineCallback != nil { + oidcMachineCallback = func(ctx context.Context, oargs *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := args.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(oargs)) + return (*driver.OIDCCredential)(cred), err + } + } + + var oidcHumanCallback auth.OIDCCallback + if args.Auth.OIDCHumanCallback != nil { + oidcHumanCallback = func(ctx context.Context, oargs *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := args.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(oargs)) + return (*driver.OIDCCredential)(cred), err + } + } + + // Create an authenticator for the client + client.authenticator, err = auth.CreateAuthenticator(args.Auth.AuthMechanism, &auth.Cred{ + Source: args.Auth.AuthSource, + Username: args.Auth.Username, + Password: args.Auth.Password, + PasswordSet: args.Auth.PasswordSet, + Props: args.Auth.AuthMechanismProperties, + OIDCMachineCallback: oidcMachineCallback, + OIDCHumanCallback: oidcHumanCallback, + }, args.HTTPClient) + if err != nil { + return nil, err + } + } + + cfg, err := topology.NewConfigFromOptionsWithAuthenticator(args, client.clock, client.authenticator) if err != nil { return nil, err } @@ -240,6 +274,19 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) { return client, nil } +// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent +// public type *options.OIDCArgs. +func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs { + if args == nil { + return nil + } + return &options.OIDCArgs{ + Version: args.Version, + IDPInfo: (*options.IDPInfo)(args.IDPInfo), + RefreshToken: args.RefreshToken, + } +} + // connect initializes the Client by starting background monitoring goroutines. // If the Client was created using the NewClient function, this method must be called before a Client can be used. // @@ -740,7 +787,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... op := operation.NewListDatabases(filterDoc). Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor). ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.cryptFLE). - ServerAPI(c.serverAPI).Timeout(c.timeout) + ServerAPI(c.serverAPI).Timeout(c.timeout).Authenticator(c.authenticator) if lda.NameOnly != nil { op = op.NameOnly(*lda.NameOnly) diff --git a/mongo/client_test.go b/mongo/client_test.go index 72e3ee09625..ee56449ce64 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -11,6 +11,7 @@ import ( "errors" "math" "os" + "reflect" "testing" "time" @@ -19,11 +20,13 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/integtest" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/tag" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" @@ -516,3 +519,76 @@ func TestClient(t *testing.T) { assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error()) }) } + +// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs +// into an options.OIDCArgs. +func TestConvertOIDCArgs(t *testing.T) { + refreshToken := "test refresh token" + + testCases := []struct { + desc string + args *driver.OIDCArgs + }{ + { + desc: "populated args", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: &driver.IDPInfo{ + Issuer: "test issuer", + ClientID: "test client ID", + RequestScopes: []string{"test scope 1", "test scope 2"}, + }, + RefreshToken: &refreshToken, + }, + }, + { + desc: "nil", + args: nil, + }, + { + desc: "nil IDPInfo and RefreshToken", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: nil, + RefreshToken: nil, + }, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + got := convertOIDCArgs(tc.args) + + if tc.args == nil { + assert.Nil(t, got, "expected nil when input is nil") + return + } + + require.Equal(t, + 3, + reflect.ValueOf(*tc.args).NumField(), + "expected the driver.OIDCArgs struct to have exactly 3 fields") + require.Equal(t, + 3, + reflect.ValueOf(*got).NumField(), + "expected the options.OIDCArgs struct to have exactly 3 fields") + + assert.Equal(t, + tc.args.Version, + got.Version, + "expected Version field to be equal") + assert.EqualValues(t, + tc.args.IDPInfo, + got.IDPInfo, + "expected IDPInfo field to be convertible to equal values") + assert.Equal(t, + tc.args.RefreshToken, + got.RefreshToken, + "expected RefreshToken field to be equal") + }) + } +} diff --git a/mongo/collection.go b/mongo/collection.go index e71bf0e7e7f..a73ff90760b 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -304,7 +304,7 @@ func (coll *Collection) insert( ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) args, err := mongoutil.NewOptions[options.InsertManyOptions](opts...) if err != nil { @@ -521,7 +521,7 @@ func (coll *Collection) delete( ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) if args.Comment != nil { comment, err := marshalValue(args.Comment, coll.bsonOpts, coll.registry) if err != nil { @@ -655,7 +655,7 @@ func (coll *Collection) updateOrReplace( Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Hint(args.Hint != nil). ArrayFilters(args.ArrayFilters != nil).Ordered(true).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).Logger(coll.client.logger) + Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) if args.Let != nil { let, err := marshal(args.Let, coll.bsonOpts, coll.registry) if err != nil { @@ -948,7 +948,8 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption Crypt(a.client.cryptFLE). ServerAPI(a.client.serverAPI). HasOutputStage(hasOutputStage). - Timeout(a.client.timeout) + Timeout(a.client.timeout). + Authenticator(a.client.authenticator) if args.AllowDiskUse != nil { op.AllowDiskUse(*args.AllowDiskUse) @@ -1071,7 +1072,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout) + Timeout(coll.client.timeout).Authenticator(coll.client.authenticator) if args.Collation != nil { op.Collation(bsoncore.Document(toDocument(args.Collation))) } @@ -1165,7 +1166,7 @@ func (coll *Collection) EstimatedDocumentCount( Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout) + Timeout(coll.client.timeout).Authenticator(coll.client.authenticator) if args.Comment != nil { comment, err := marshalValue(args.Comment, coll.bsonOpts, coll.registry) @@ -1241,7 +1242,7 @@ func (coll *Collection) Distinct( Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout) + Timeout(coll.client.timeout).Authenticator(coll.client.authenticator) if args.Collation != nil { op.Collation(bsoncore.Document(toDocument(args.Collation))) @@ -1334,7 +1335,7 @@ func (coll *Collection) find(ctx context.Context, filter interface{}, CommandMonitor(coll.client.monitor).ServerSelector(selector). ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).Logger(coll.client.logger) + Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) cursorOpts := coll.client.createBaseCursorOptions() @@ -1592,7 +1593,7 @@ func (coll *Collection) FindOneAndDelete( return &SingleResult{err: fmt.Errorf("failed to construct options from builder: %w", err)} } - op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Authenticator(coll.client.authenticator) if args.Collation != nil { op = op.Collation(bsoncore.Document(toDocument(args.Collation))) } @@ -1680,7 +1681,7 @@ func (coll *Collection) FindOneAndReplace( } op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsoncore.TypeEmbeddedDocument, Data: r}). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Authenticator(coll.client.authenticator) if args.BypassDocumentValidation != nil && *args.BypassDocumentValidation { op = op.BypassDocumentValidation(*args.BypassDocumentValidation) } @@ -1773,7 +1774,7 @@ func (coll *Collection) FindOneAndUpdate( return &SingleResult{err: fmt.Errorf("failed to construct options from builder: %w", err)} } - op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Authenticator(coll.client.authenticator) u, err := marshalUpdateValue(update, coll.bsonOpts, coll.registry, true) if err != nil { @@ -1982,7 +1983,8 @@ func (coll *Collection) drop(ctx context.Context) error { ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + Authenticator(coll.client.authenticator) err = op.Execute(ctx) // ignore namespace not found errors diff --git a/mongo/database.go b/mongo/database.go index 605fc28f984..2a80fbf238c 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -211,7 +211,7 @@ func (db *Database) processRunCommand( ServerSelector(readSelect).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment). Crypt(db.client.cryptFLE).ReadPreference(args.ReadPreference).ServerAPI(db.client.serverAPI). - Timeout(db.client.timeout).Logger(db.client.logger), sess, nil + Timeout(db.client.timeout).Logger(db.client.logger).Authenticator(db.client.authenticator), sess, nil } // RunCommand executes the given command against the database. @@ -339,7 +339,7 @@ func (db *Database) Drop(ctx context.Context) error { Session(sess).WriteConcern(wc).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) err = op.Execute(ctx) @@ -469,7 +469,7 @@ func (db *Database) ListCollections( Session(sess).ReadPreference(db.readPreference).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI).Timeout(db.client.timeout) + ServerAPI(db.client.serverAPI).Timeout(db.client.timeout).Authenticator(db.client.authenticator) cursorOpts := db.client.createBaseCursorOptions() @@ -759,7 +759,7 @@ func (db *Database) createCollectionOperation( return nil, fmt.Errorf("failed to construct options from builder: %w", err) } - op := operation.NewCreate(name).ServerAPI(db.client.serverAPI) + op := operation.NewCreate(name).ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) if args.Capped != nil { op.Capped(*args.Capped) @@ -896,7 +896,7 @@ func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pip op := operation.NewCreate(viewName). ViewOn(viewOn). Pipeline(pipelineArray). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) args, err := mongoutil.NewOptions(opts...) if err != nil { return fmt.Errorf("failed to construct options from builder: %w", err) diff --git a/mongo/index_view.go b/mongo/index_view.go index 48f88c60356..748957da1b0 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -94,7 +94,7 @@ func (iv IndexView) List(ctx context.Context, opts ...options.Lister[options.Lis ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE) + Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE).Authenticator(iv.coll.client.authenticator) cursorOpts := iv.coll.client.createBaseCursorOptions() @@ -277,7 +277,7 @@ func (iv IndexView) CreateMany( Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE) + Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE).Authenticator(iv.coll.client.authenticator) if args.CommitQuorum != nil { commitQuorum, err := marshalValue(args.CommitQuorum, iv.coll.bsonOpts, iv.coll.registry) if err != nil { @@ -413,7 +413,7 @@ func (iv IndexView) drop(ctx context.Context, index any, _ ...options.Lister[opt ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE) + Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE).Authenticator(iv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index cccdeca3faa..90b503104d5 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -111,6 +111,34 @@ type Credential struct { Username string Password string PasswordSet bool + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback +} + +// OIDCCallback is the type for both Human and Machine Callback flows. +// RefreshToken will always be nil in the OIDCArgs for the Machine flow. +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} + +// IDPInfo contains the information needed to perform OIDC authentication with +// an Identity Provider. +type IDPInfo struct { + Issuer string + ClientID string + RequestScopes []string } // BSONOptions are optional BSON marshaling and unmarshaling behaviors. diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index ce00cce7dc3..4b9b80d2e17 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -159,7 +159,7 @@ func (siv SearchIndexView) CreateMany( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { @@ -214,7 +214,7 @@ func (siv SearchIndexView) DropOne( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() { @@ -268,7 +268,7 @@ func (siv SearchIndexView) UpdateOne( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) return op.Execute(ctx) } diff --git a/mongo/session.go b/mongo/session.go index 95f21038a55..da8cd9c951d 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -239,7 +239,8 @@ func (s *Session) AbortTransaction(ctx context.Context) error { _ = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin"). Deployment(s.deployment).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector). Retry(driver.RetryOncePerCommand).CommandMonitor(s.client.monitor). - RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI).Execute(ctx) + RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI). + Authenticator(s.client.authenticator).Execute(ctx) s.clientSession.Aborting = false _ = s.clientSession.AbortTransaction() @@ -273,7 +274,7 @@ func (s *Session) CommitTransaction(ctx context.Context) error { Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment). WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)). - ServerAPI(s.client.serverAPI) + ServerAPI(s.client.serverAPI).Authenticator(s.client.authenticator) err = op.Execute(ctx) // Return error without updating transaction state if it is a timeout, as the transaction has not diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index 9694654b9ef..240fc22e3d3 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -21,7 +21,7 @@ import ( ) // AuthenticatorFactory constructs an authenticator. -type AuthenticatorFactory func(cred *Cred) (Authenticator, error) +type AuthenticatorFactory func(*Cred, *http.Client) (Authenticator, error) var authFactories = make(map[string]AuthenticatorFactory) @@ -34,12 +34,13 @@ func init() { RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator) RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator) RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator) + RegisterAuthenticatorFactory(MongoDBOIDC, newOIDCAuthenticator) } // CreateAuthenticator creates an authenticator. -func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) { +func CreateAuthenticator(name string, cred *Cred, httpClient *http.Client) (Authenticator, error) { if f, ok := authFactories[name]; ok { - return f(cred) + return f(cred, httpClient) } return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil) @@ -62,7 +63,6 @@ type HandshakeOptions struct { ClusterClock *session.ClusterClock ServerAPI *driver.ServerAPIOptions LoadBalanced bool - HTTPClient *http.Client } type authHandshaker struct { @@ -102,12 +102,17 @@ func (ah *authHandshaker) GetHandshakeInformation( return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err) } - firstMsg, err := ah.conversation.FirstMessage() - if err != nil { - return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) - } + // It is possible for the speculative conversation to be nil even without error if the authenticator + // cannot perform speculative authentication. An example of this is MONGODB-OIDC when there is + // no AccessToken in the cache. + if ah.conversation != nil { + firstMsg, err := ah.conversation.FirstMessage() + if err != nil { + return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) + } - op = op.SpeculativeAuthenticate(firstMsg) + op = op.SpeculativeAuthenticate(firstMsg) + } } } @@ -130,12 +135,11 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn *mnet.Connec } if performAuth(conn.Description()) && ah.options.Authenticator != nil { - cfg := &Config{ + cfg := &driver.AuthConfig{ Connection: conn, ClusterClock: ah.options.ClusterClock, HandshakeInfo: ah.handshakeInfo, ServerAPI: ah.options.ServerAPI, - HTTPClient: ah.options.HTTPClient, } if err := ah.authenticate(ctx, cfg); err != nil { @@ -149,7 +153,7 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn *mnet.Connec return ah.wrapped.FinishHandshake(ctx, conn) } -func (ah *authHandshaker) authenticate(ctx context.Context, cfg *Config) error { +func (ah *authHandshaker) authenticate(ctx context.Context, cfg *driver.AuthConfig) error { // If the initial hello reply included a response to the speculative authentication attempt, we only need to // conduct the remainder of the conversation. if speculativeResponse := ah.handshakeInfo.SpeculativeAuthenticate; speculativeResponse != nil { @@ -183,10 +187,7 @@ type Config struct { } // Authenticator handles authenticating a connection. -type Authenticator interface { - // Auth authenticates the connection. - Auth(context.Context, *Config) error -} +type Authenticator = driver.Authenticator func newAuthError(msg string, inner error) error { return &Error{ diff --git a/x/mongo/driver/auth/auth_test.go b/x/mongo/driver/auth/auth_test.go index 082401ed200..c4103cf3b77 100644 --- a/x/mongo/driver/auth/auth_test.go +++ b/x/mongo/driver/auth/auth_test.go @@ -7,6 +7,7 @@ package auth_test import ( + "net/http" "testing" "github.com/google/go-cmp/cmp" @@ -39,7 +40,7 @@ func TestCreateAuthenticator(t *testing.T) { PasswordSet: true, } - a, err := CreateAuthenticator(test.name, cred) + a, err := CreateAuthenticator(test.name, cred, &http.Client{}) require.NoError(t, err) require.IsType(t, test.auth, a) }) diff --git a/x/mongo/driver/auth/conversation.go b/x/mongo/driver/auth/conversation.go index 7159f4e2bea..13839d7f8d9 100644 --- a/x/mongo/driver/auth/conversation.go +++ b/x/mongo/driver/auth/conversation.go @@ -10,6 +10,7 @@ import ( "context" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" ) // SpeculativeConversation represents an authentication conversation that can be merged with the initial connection @@ -22,7 +23,7 @@ import ( // authenticate the provided connection. type SpeculativeConversation interface { FirstMessage() (bsoncore.Document, error) - Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error + Finish(ctx context.Context, cfg *driver.AuthConfig, firstResponse bsoncore.Document) error } // SpeculativeAuthenticator represents an authenticator that supports speculative authentication. diff --git a/x/mongo/driver/auth/cred.go b/x/mongo/driver/auth/cred.go index 7b2b8f17d00..ed337773854 100644 --- a/x/mongo/driver/auth/cred.go +++ b/x/mongo/driver/auth/cred.go @@ -6,11 +6,9 @@ package auth -// Cred is a user's credential. -type Cred struct { - Source string - Username string - Password string - PasswordSet bool - Props map[string]string -} +import ( + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" +) + +// Cred is the type of user credential +type Cred = driver.Cred diff --git a/x/mongo/driver/auth/default.go b/x/mongo/driver/auth/default.go index 6f2ca5224a6..6982ccc0122 100644 --- a/x/mongo/driver/auth/default.go +++ b/x/mongo/driver/auth/default.go @@ -9,10 +9,13 @@ package auth import ( "context" "fmt" + "net/http" + + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" ) -func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { - scram, err := newScramSHA256Authenticator(cred) +func newDefaultAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + scram, err := newScramSHA256Authenticator(cred, httpClient) if err != nil { return nil, newAuthError("failed to create internal authenticator", err) } @@ -25,6 +28,7 @@ func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { return &DefaultAuthenticator{ Cred: cred, speculativeAuthenticator: speculative, + httpClient: httpClient, }, nil } @@ -36,6 +40,8 @@ type DefaultAuthenticator struct { // The authenticator to use for speculative authentication. Because the correct auth mechanism is unknown when doing // the initial hello, SCRAM-SHA-256 is used for the speculative attempt. speculativeAuthenticator SpeculativeAuthenticator + + httpClient *http.Client } var _ SpeculativeAuthenticator = (*DefaultAuthenticator)(nil) @@ -46,17 +52,17 @@ func (a *DefaultAuthenticator) CreateSpeculativeConversation() (SpeculativeConve } // Auth authenticates the connection. -func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { +func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { var actual Authenticator var err error switch chooseAuthMechanism(cfg) { case SCRAMSHA256: - actual, err = newScramSHA256Authenticator(a.Cred) + actual, err = newScramSHA256Authenticator(a.Cred, a.httpClient) case SCRAMSHA1: - actual, err = newScramSHA1Authenticator(a.Cred) + actual, err = newScramSHA1Authenticator(a.Cred, a.httpClient) default: - actual, err = newMongoDBCRAuthenticator(a.Cred) + actual, err = newMongoDBCRAuthenticator(a.Cred, a.httpClient) } if err != nil { @@ -66,10 +72,15 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { return actual.Auth(ctx, cfg) } +// Reauth reauthenticates the connection. +func (a *DefaultAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("DefaultAuthenticator does not support reauthentication", nil) +} + // If a server provides a list of supported mechanisms, we choose // SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1. // Otherwise, we decide based on what is supported. -func chooseAuthMechanism(cfg *Config) string { +func chooseAuthMechanism(cfg *driver.AuthConfig) string { if saslSupportedMechs := cfg.HandshakeInfo.SaslSupportedMechs; saslSupportedMechs != nil { for _, v := range saslSupportedMechs { if v == SCRAMSHA256 { diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 2b68417efa5..9907eb4db4a 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -14,14 +14,16 @@ import ( "context" "fmt" "net" + "net/http" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth/internal/gssapi" ) // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { if cred.Source != "" && cred.Source != "$external" { return nil, newAuthError("GSSAPI source must be empty or $external", nil) } @@ -43,7 +45,7 @@ type GSSAPIAuthenticator struct { } // Auth authenticates the connection. -func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error { +func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { target := cfg.Connection.Description().Addr.String() hostname, _, err := net.SplitHostPort(target) if err != nil { @@ -57,3 +59,8 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error { } return ConductSaslConversation(ctx, cfg, "$external", client) } + +// Reauth reauthenticates the connection. +func (a *GSSAPIAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("GSSAPI does not support reauthentication", nil) +} diff --git a/x/mongo/driver/auth/gssapi_not_enabled.go b/x/mongo/driver/auth/gssapi_not_enabled.go index 7ba5fe860ce..e50553c7a1b 100644 --- a/x/mongo/driver/auth/gssapi_not_enabled.go +++ b/x/mongo/driver/auth/gssapi_not_enabled.go @@ -9,9 +9,11 @@ package auth +import "net/http" + // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(*Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) { return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil) } diff --git a/x/mongo/driver/auth/gssapi_not_supported.go b/x/mongo/driver/auth/gssapi_not_supported.go index 10312c228ee..12046ff67c2 100644 --- a/x/mongo/driver/auth/gssapi_not_supported.go +++ b/x/mongo/driver/auth/gssapi_not_supported.go @@ -11,12 +11,13 @@ package auth import ( "fmt" + "net/http" "runtime" ) // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) { return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil) } diff --git a/x/mongo/driver/auth/gssapi_test.go b/x/mongo/driver/auth/gssapi_test.go index 968b3ac732e..eedfe428e99 100644 --- a/x/mongo/driver/auth/gssapi_test.go +++ b/x/mongo/driver/auth/gssapi_test.go @@ -14,6 +14,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/v2/mongo/address" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet" @@ -44,7 +45,7 @@ func TestGSSAPIAuthenticator(t *testing.T) { mnetconn := mnet.NewConnection(chanconn) - err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) + err := authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: mnetconn}) if err == nil { t.Fatalf("expected err, got nil") } diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index 14fd637d970..cdb60223441 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -9,19 +9,24 @@ package auth import ( "context" "errors" + "net/http" "go.mongodb.org/mongo-driver/v2/internal/aws/credentials" "go.mongodb.org/mongo-driver/v2/internal/credproviders" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth/creds" ) // MongoDBAWS is the mechanism name for MongoDBAWS. const MongoDBAWS = "MONGODB-AWS" -func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { +func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { if cred.Source != "" && cred.Source != "$external" { return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil) } + if httpClient == nil { + return nil, errors.New("httpClient must not be nil") + } return &MongoDBAWSAuthenticator{ source: cred.Source, credentials: &credproviders.StaticProvider{ @@ -32,6 +37,7 @@ func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { SessionToken: cred.Props["AWS_SESSION_TOKEN"], }, }, + httpClient: httpClient, }, nil } @@ -39,15 +45,12 @@ func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { type MongoDBAWSAuthenticator struct { source string credentials *credproviders.StaticProvider + httpClient *http.Client } // Auth authenticates the connection. -func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { - httpClient := cfg.HTTPClient - if httpClient == nil { - return errors.New("cfg.HTTPClient must not be nil") - } - providers := creds.NewAWSCredentialProvider(httpClient, a.credentials) +func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { + providers := creds.NewAWSCredentialProvider(a.httpClient, a.credentials) adapter := &awsSaslAdapter{ conversation: &awsConversation{ credentials: providers.Cred, @@ -60,6 +63,11 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *MongoDBAWSAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("AWS authentication does not support reauthentication", nil) +} + type awsSaslAdapter struct { conversation *awsConversation } diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 643d3a4c2ca..55ec36fa7d1 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "io" + "net/http" // Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need // to use MD5 here to implement the MONGODB-CR specification. @@ -28,7 +29,7 @@ import ( // MongoDB 4.0. const MONGODBCR = "MONGODB-CR" -func newMongoDBCRAuthenticator(cred *Cred) (Authenticator, error) { +func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &MongoDBCRAuthenticator{ DB: cred.Source, Username: cred.Username, @@ -50,7 +51,7 @@ type MongoDBCRAuthenticator struct { // // The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in // MongoDB 4.0. -func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { +func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { db := a.DB if db == "" { @@ -97,6 +98,11 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *MongoDBCRAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("MONGODB-CR does not support reauthentication", nil) +} + func (a *MongoDBCRAuthenticator) createKey(nonce string) string { // Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to // implement the MONGODB-CR specification. diff --git a/x/mongo/driver/auth/mongodbcr_test.go b/x/mongo/driver/auth/mongodbcr_test.go index 70f4e6868c3..8bc956d499f 100644 --- a/x/mongo/driver/auth/mongodbcr_test.go +++ b/x/mongo/driver/auth/mongodbcr_test.go @@ -13,6 +13,7 @@ import ( "strings" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" . "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" @@ -49,7 +50,7 @@ func TestMongoDBCRAuthenticator_Fails(t *testing.T) { mnetconn := mnet.NewConnection(c) - err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) + err := authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: mnetconn}) if err == nil { t.Fatalf("expected an error but got none") } @@ -90,7 +91,7 @@ func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) { mnetconn := mnet.NewConnection(c) - err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) + err := authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: mnetconn}) if err != nil { t.Fatalf("expected no error but got \"%s\"", err) } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go new file mode 100644 index 00000000000..cd3a922f3f6 --- /dev/null +++ b/x/mongo/driver/auth/oidc.go @@ -0,0 +1,344 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet" +) + +// MongoDBOIDC is the string constant for the MONGODB-OIDC authentication mechanism. +const MongoDBOIDC = "MONGODB-OIDC" + +// TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider +// const tokenResourceProp = "TOKEN_RESOURCE" +const environmentProp = "ENVIRONMENT" + +const resourceProp = "TOKEN_RESOURCE" + +// GODRIVER-3249 OIDC: Handle all possible OIDC configuration errors +//const allowedHostsProp = "ALLOWED_HOSTS" + +const azureEnvironmentValue = "azure" +const gcpEnvironmentValue = "gcp" +const testEnvironmentValue = "test" + +const apiVersion = 1 +const invalidateSleepTimeout = 100 * time.Millisecond + +// The CSOT specification says to apply a 1-minute timeout if "CSOT is not applied". That's +// ambiguous for the v1.x Go Driver because it could mean either "no timeout provided" or "CSOT not +// enabled". Always use a maximum timeout duration of 1 minute, allowing us to ignore the ambiguity. +// Contexts with a shorter timeout are unaffected. +const machineCallbackTimeout = 60 * time.Second + +//GODRIVER-3246 OIDC: Implement Human Callback Mechanism +//var defaultAllowedHosts = []string{ +// "*.mongodb.net", +// "*.mongodb-qa.net", +// "*.mongodb-dev.net", +// "*.mongodbgov.net", +// "localhost", +// "127.0.0.1", +// "::1", +//} + +// OIDCCallback is a function that takes a context and OIDCArgs and returns an OIDCCredential. +type OIDCCallback = driver.OIDCCallback + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs = driver.OIDCArgs + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential = driver.OIDCCredential + +// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider. +type IDPInfo = driver.IDPInfo + +var _ driver.Authenticator = (*OIDCAuthenticator)(nil) +var _ SpeculativeAuthenticator = (*OIDCAuthenticator)(nil) +var _ SaslClient = (*oidcOneStep)(nil) + +// OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, +// and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality +// is only for the OIDC Human flow. +type OIDCAuthenticator struct { + mu sync.Mutex // Guards all of the info in the OIDCAuthenticator struct. + + AuthMechanismProperties map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback + + userName string + httpClient *http.Client + accessToken string + refreshToken *string + idpInfo *IDPInfo + tokenGenID uint64 +} + +// SetAccessToken allows for manually setting the access token for the OIDCAuthenticator, this is +// only for testing purposes. +func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) { + oa.mu.Lock() + defer oa.mu.Unlock() + oa.accessToken = accessToken +} + +func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + if cred.Password != "" { + return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) + } + if cred.Props != nil { + if env, ok := cred.Props[environmentProp]; ok { + switch strings.ToLower(env) { + case azureEnvironmentValue: + fallthrough + case gcpEnvironmentValue: + if _, ok := cred.Props[resourceProp]; !ok { + return nil, fmt.Errorf("%q must be specified for %q %q", resourceProp, env, environmentProp) + } + fallthrough + case testEnvironmentValue: + if cred.OIDCMachineCallback != nil || cred.OIDCHumanCallback != nil { + return nil, fmt.Errorf("OIDC callbacks are not allowed for %q %q", env, environmentProp) + } + } + } + } + oa := &OIDCAuthenticator{ + userName: cred.Username, + httpClient: httpClient, + AuthMechanismProperties: cred.Props, + OIDCMachineCallback: cred.OIDCMachineCallback, + OIDCHumanCallback: cred.OIDCHumanCallback, + } + return oa, nil +} + +type oidcOneStep struct { + userName string + accessToken string +} + +func jwtStepRequest(accessToken string) []byte { + return bsoncore.NewDocumentBuilder(). + AppendString("jwt", accessToken). + Build() +} + +// TODO GODRIVER-3246: Implement OIDC human flow +//func principalStepRequest(principal string) []byte { +// doc := bsoncore.NewDocumentBuilder() +// if principal != "" { +// doc.AppendString("n", principal) +// } +// return doc.Build() +//} + +func (oos *oidcOneStep) Start() (string, []byte, error) { + return MongoDBOIDC, jwtStepRequest(oos.accessToken), nil +} + +func (oos *oidcOneStep) Next([]byte) ([]byte, error) { + return nil, newAuthError("unexpected step in OIDC authentication", nil) +} + +func (*oidcOneStep) Completed() bool { + return true +} + +func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { + env, ok := oa.AuthMechanismProperties[environmentProp] + if !ok { + return nil, nil + } + + switch env { + // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider + // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider + // This is here just to pass the linter, it will be fixed in one of the above tickets. + case azureEnvironmentValue, gcpEnvironmentValue: + return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) { + return nil, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) + }, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) + } + + return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) +} + +func (oa *OIDCAuthenticator) getAccessToken( + ctx context.Context, + conn *mnet.Connection, + args *OIDCArgs, + callback OIDCCallback, +) (string, error) { + oa.mu.Lock() + defer oa.mu.Unlock() + + if oa.accessToken != "" { + return oa.accessToken, nil + } + + cred, err := callback(ctx, args) + if err != nil { + return "", err + } + + oa.accessToken = cred.AccessToken + oa.tokenGenID++ + conn.SetOIDCTokenGenID(oa.tokenGenID) + if cred.RefreshToken != nil { + oa.refreshToken = cred.RefreshToken + } + return cred.AccessToken, nil +} + +// TODO GODRIVER-3246: Implement OIDC human flow +// This should only be called with the Mutex held. +//func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( +// ctx context.Context, +// callback OIDCCallback, +// refreshToken string, +//) (string, error) { +// +// cred, err := callback(ctx, &OIDCArgs{ +// Version: apiVersion, +// IDPInfo: oa.idpInfo, +// RefreshToken: &refreshToken, +// }) +// if err != nil { +// return "", err +// } +// +// oa.accessToken = cred.AccessToken +// oa.tokenGenID++ +// oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) +// return cred.AccessToken, nil +//} + +// invalidateAccessToken invalidates the access token, if the force flag is set to true (which is +// only on a Reauth call) or if the tokenGenID of the connection is greater than or equal to the +// tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal, +// but this is a safety check, since extra invalidation is only a performance impact, not a +// correctness impact. +func (oa *OIDCAuthenticator) invalidateAccessToken(conn *mnet.Connection) { + oa.mu.Lock() + defer oa.mu.Unlock() + tokenGenID := conn.OIDCTokenGenID() + // If the connection used in a Reauth is a new connection it will not have a correct tokenGenID, + // it will instead be set to 0. In the absence of information, the only safe thing to do is to + // invalidate the cached accessToken. + if tokenGenID == 0 || tokenGenID >= oa.tokenGenID { + oa.accessToken = "" + conn.SetOIDCTokenGenID(0) + } +} + +// Reauth reauthenticates the connection when the server returns a 391 code. Reauth is part of the +// driver.Authenticator interface. +func (oa *OIDCAuthenticator) Reauth(ctx context.Context, cfg *driver.AuthConfig) error { + oa.invalidateAccessToken(cfg.Connection) + return oa.Auth(ctx, cfg) +} + +// Auth authenticates the connection. +func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { + var err error + + if cfg == nil { + return newAuthError(fmt.Sprintf("config must be set for %q authentication", MongoDBOIDC), nil) + } + conn := cfg.Connection + + oa.mu.Lock() + cachedAccessToken := oa.accessToken + oa.mu.Unlock() + + if cachedAccessToken != "" { + err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + userName: oa.userName, + accessToken: cachedAccessToken, + }) + if err == nil { + return nil + } + // this seems like it could be incorrect since we could be inavlidating an access token that + // has already been replaced by a different auth attempt, but the TokenGenID will prevernt + // that from happening. + oa.invalidateAccessToken(conn) + time.Sleep(invalidateSleepTimeout) + } + + if oa.OIDCHumanCallback != nil { + return oa.doAuthHuman(ctx, cfg, oa.OIDCHumanCallback) + } + + // Handle user provided or automatic provider machine callback. + var machineCallback OIDCCallback + if oa.OIDCMachineCallback != nil { + machineCallback = oa.OIDCMachineCallback + } else { + machineCallback, err = oa.providerCallback() + if err != nil { + return fmt.Errorf("error getting built-in OIDC provider: %w", err) + } + } + + if machineCallback != nil { + return oa.doAuthMachine(ctx, cfg, machineCallback) + } + return newAuthError("no OIDC callback provided", nil) +} + +func (oa *OIDCAuthenticator) doAuthHuman(_ context.Context, _ *driver.AuthConfig, _ OIDCCallback) error { + // TODO GODRIVER-3246: Implement OIDC human flow + return newAuthError("OIDC", fmt.Errorf("human flow not implemented yet, %v", oa.idpInfo)) +} + +func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *driver.AuthConfig, machineCallback OIDCCallback) error { + subCtx, cancel := context.WithTimeout(ctx, machineCallbackTimeout) + accessToken, err := oa.getAccessToken(subCtx, + cfg.Connection, + &OIDCArgs{ + Version: apiVersion, + // idpInfo is nil for machine callbacks in the current spec. + IDPInfo: nil, + RefreshToken: nil, + }, + machineCallback) + cancel() + if err != nil { + return err + } + return ConductSaslConversation( + ctx, + cfg, + "$external", + &oidcOneStep{accessToken: accessToken}, + ) +} + +// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. +func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { + oa.mu.Lock() + defer oa.mu.Unlock() + accessToken := oa.accessToken + if accessToken == "" { + return nil, nil // Skip speculative auth. + } + + return newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", true), nil +} diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 532d43e39f5..fc3e7b08bc8 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -8,12 +8,15 @@ package auth import ( "context" + "net/http" + + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" ) // PLAIN is the mechanism name for PLAIN. const PLAIN = "PLAIN" -func newPlainAuthenticator(cred *Cred) (Authenticator, error) { +func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &PlainAuthenticator{ Username: cred.Username, Password: cred.Password, @@ -27,13 +30,18 @@ type PlainAuthenticator struct { } // Auth authenticates the connection. -func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error { +func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{ username: a.Username, password: a.Password, }) } +// Reauth reauthenticates the connection. +func (a *PlainAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("Plain authentication does not support reauthentication", nil) +} + type plainSaslClient struct { username string password string diff --git a/x/mongo/driver/auth/plain_test.go b/x/mongo/driver/auth/plain_test.go index 5ee64748edb..d50e3f5b226 100644 --- a/x/mongo/driver/auth/plain_test.go +++ b/x/mongo/driver/auth/plain_test.go @@ -15,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" . "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" @@ -51,7 +52,7 @@ func TestPlainAuthenticator_Fails(t *testing.T) { mnetconn := mnet.NewConnection(c) - err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) + err := authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: mnetconn}) if err == nil { t.Fatalf("expected an error but got none") } @@ -96,7 +97,7 @@ func TestPlainAuthenticator_Extra_server_message(t *testing.T) { mnetconn := mnet.NewConnection(c) - err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) + err := authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: mnetconn}) if err == nil { t.Fatalf("expected an error but got none") } @@ -136,7 +137,7 @@ func TestPlainAuthenticator_Succeeds(t *testing.T) { mnetconn := mnet.NewConnection(c) - err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) + err := authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: mnetconn}) if err != nil { t.Fatalf("expected no error but got \"%s\"", err) } @@ -183,7 +184,7 @@ func TestPlainAuthenticator_SucceedsBoolean(t *testing.T) { mnetconn := mnet.NewConnection(c) - err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) + err := authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: mnetconn}) require.NoError(t, err, "Auth error") require.Len(t, c.Written, 1, "expected 1 messages to be sent") diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index 124aae137c0..8a98dac8cab 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -94,7 +94,7 @@ type saslResponse struct { } // Finish completes the conversation based on the first server response to authenticate the given connection. -func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error { +func (sc *saslConversation) Finish(ctx context.Context, cfg *driver.AuthConfig, firstResponse bsoncore.Document) error { if closer, ok := sc.client.(SaslClientCloser); ok { defer closer.Close() } @@ -153,10 +153,9 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon } // ConductSaslConversation runs a full SASL conversation to authenticate the given connection. -func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error { +func ConductSaslConversation(ctx context.Context, cfg *driver.AuthConfig, authSource string, client SaslClient) error { // Create a non-speculative SASL conversation. conversation := newSaslConversation(client, authSource, false) - saslStartDoc, err := conversation.FirstMessage() if err != nil { return newError(err, conversation.mechanism) diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index 963dd5f338b..98a2c28076c 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -14,10 +14,12 @@ package auth import ( "context" + "net/http" "github.com/xdg-go/scram" "github.com/xdg-go/stringprep" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" ) const ( @@ -35,7 +37,7 @@ var ( ) ) -func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) { +func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { passdigest := mongoPasswordDigest(cred.Username, cred.Password) client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "") if err != nil { @@ -49,7 +51,7 @@ func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) { }, nil } -func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) { +func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { passprep, err := stringprep.SASLprep.Prepare(cred.Password) if err != nil { return nil, newAuthError("error SASLprepping password", err) @@ -76,7 +78,7 @@ type ScramAuthenticator struct { var _ SpeculativeAuthenticator = (*ScramAuthenticator)(nil) // Auth authenticates the provided connection by conducting a full SASL conversation. -func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error { +func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { err := ConductSaslConversation(ctx, cfg, a.source, a.createSaslClient()) if err != nil { return newAuthError("sasl conversation error", err) @@ -84,6 +86,11 @@ func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *ScramAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("SCRAM does not support reauthentication", nil) +} + // CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. func (a *ScramAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { return newSaslConversation(a.createSaslClient(), a.source, true), nil diff --git a/x/mongo/driver/auth/scram_test.go b/x/mongo/driver/auth/scram_test.go index 9b793bc31b4..a5d8c80796f 100644 --- a/x/mongo/driver/auth/scram_test.go +++ b/x/mongo/driver/auth/scram_test.go @@ -8,10 +8,12 @@ package auth import ( "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet" @@ -39,7 +41,7 @@ func TestSCRAM(t *testing.T) { t.Run("conversation", func(t *testing.T) { testCases := []struct { name string - createAuthenticatorFn func(*Cred) (Authenticator, error) + createAuthenticatorFn func(*Cred, *http.Client) (Authenticator, error) payloads [][]byte nonce string }{ @@ -50,11 +52,13 @@ func TestSCRAM(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - authenticator, err := tc.createAuthenticatorFn(&Cred{ - Username: "user", - Password: "pencil", - Source: "admin", - }) + authenticator, err := tc.createAuthenticatorFn( + &Cred{ + Username: "user", + Password: "pencil", + Source: "admin", + }, + &http.Client{}) assert.Nil(t, err, "error creating authenticator: %v", err) sa, _ := authenticator.(*ScramAuthenticator) sa.client = sa.client.WithNonceGenerator(func() string { @@ -77,7 +81,7 @@ func TestSCRAM(t *testing.T) { conn := mnet.NewConnection(chanconn) - err = authenticator.Auth(context.Background(), &Config{Connection: conn}) + err = authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: conn}) assert.Nil(t, err, "Auth error: %v\n", err) // Verify that the first command sent is saslStart. diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index 1e6083c4646..e24cae4b454 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -9,6 +9,7 @@ package auth import ( "bytes" "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/v2/bson" @@ -64,7 +65,7 @@ func TestSpeculativeSCRAM(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Create a SCRAM authenticator and overwrite the nonce generator to make the conversation // deterministic. - authenticator, err := CreateAuthenticator(tc.mechanism, cred) + authenticator, err := CreateAuthenticator(tc.mechanism, cred, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) setNonce(t, authenticator, tc.nonce) @@ -151,7 +152,7 @@ func TestSpeculativeSCRAM(t *testing.T) { for _, tc := range testCases { t.Run(tc.mechanism, func(t *testing.T) { - authenticator, err := CreateAuthenticator(tc.mechanism, cred) + authenticator, err := CreateAuthenticator(tc.mechanism, cred, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) setNonce(t, authenticator, tc.nonce) diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index ac50aa641bd..ee8d02ab658 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -9,6 +9,7 @@ package auth import ( "bytes" "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/v2/bson" @@ -33,7 +34,7 @@ func TestSpeculativeX509(t *testing.T) { // Tests for X509 when the hello response contains a reply to the speculative authentication attempt. The // driver should not send any more commands after the hello. - authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}) + authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) handshaker := Handshaker(nil, &HandshakeOptions{ Authenticator: authenticator, @@ -79,7 +80,7 @@ func TestSpeculativeX509(t *testing.T) { // Tests for X509 when the hello response does not contain a reply to the speculative authentication attempt. // The driver should send an authenticate command after the hello. - authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}) + authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) handshaker := Handshaker(nil, &HandshakeOptions{ Authenticator: authenticator, diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index 7aa064aaf02..608b13dda87 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -8,6 +8,7 @@ package auth import ( "context" + "net/http" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" @@ -17,7 +18,7 @@ import ( // MongoDBX509 is the mechanism name for MongoDBX509. const MongoDBX509 = "MONGODB-X509" -func newMongoDBX509Authenticator(cred *Cred) (Authenticator, error) { +func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &MongoDBX509Authenticator{User: cred.Username}, nil } @@ -51,7 +52,7 @@ func createFirstX509Message() bsoncore.Document { // Finish implements the SpeculativeConversation interface and is a no-op because an X509 conversation only has one // step. -func (c *x509Conversation) Finish(context.Context, *Config, bsoncore.Document) error { +func (c *x509Conversation) Finish(context.Context, *driver.AuthConfig, bsoncore.Document) error { return nil } @@ -61,7 +62,7 @@ func (a *MongoDBX509Authenticator) CreateSpeculativeConversation() (SpeculativeC } // Auth authenticates the provided connection by conducting an X509 authentication conversation. -func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error { +func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { requestDoc := createFirstX509Message() authCmd := operation. NewCommand(requestDoc). @@ -76,3 +77,8 @@ func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error return nil } + +// Reauth reauthenticates the connection. +func (a *MongoDBX509Authenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("X509 does not support reauthentication", nil) +} diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index ec6bfc0a807..3dd1b1918d3 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -297,6 +297,13 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error { u.AuthSource = "admin" } } + case "mongodb-oidc": + if u.AuthSource == "" { + u.AuthSource = dbName + if u.AuthSource == "" { + u.AuthSource = "$external" + } + } case "": // Only set auth source if there is a request for authentication via non-empty credentials. if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) { @@ -758,6 +765,10 @@ func (u *ConnString) validateAuth() error { if u.AuthMechanismProperties != nil { return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties") } + case "mongodb-oidc": + if u.Password != "" { + return fmt.Errorf("password cannot be specified for MONGODB-OIDC") + } case "": if u.UsernameSet && u.Username == "" { return fmt.Errorf("username required if URI contains user info") diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index b22dabdc758..88995263e58 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -22,8 +22,66 @@ import ( "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" ) +// AuthConfig holds the information necessary to perform an authentication attempt. +// this was moved from the auth package to avoid a circular dependency. The auth package +// reexports this under the old name to avoid breaking the public api. +type AuthConfig struct { + Description description.Server + Connection *mnet.Connection + ClusterClock *session.ClusterClock + HandshakeInfo HandshakeInformation + ServerAPI *ServerAPIOptions +} + +// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be +// nil in the OIDCArgs for the Machine flow. +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} + +// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider. +type IDPInfo struct { + Issuer string `bson:"issuer"` + ClientID string `bson:"clientId"` + RequestScopes []string `bson:"requestScopes"` +} + +// Authenticator handles authenticating a connection. The implementers of this interface +// are all in the auth package. Most authentication mechanisms do not allow for Reauth, +// but this is included in the interface so that whenever a new mechanism is added, it +// must be explicitly considered. +type Authenticator interface { + // Auth authenticates the connection. + Auth(context.Context, *AuthConfig) error + Reauth(context.Context, *AuthConfig) error +} + +// Cred is a user's credential. +type Cred struct { + Source string + Username string + Password string + PasswordSet bool + Props map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback +} + // Deployment is implemented by types that can select a server from a deployment. type Deployment interface { SelectServer(context.Context, description.ServerSelector) (Server, error) diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index 1938122fd66..4e1f2c78c58 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -26,6 +26,16 @@ type ChannelConn struct { Desc description.Server } +// OIDCTokenGenID implements the driver.Connection interface by returning the OIDCToken generation +// (which is always 0) +func (c *ChannelConn) OIDCTokenGenID() uint64 { + return 0 +} + +// SetOIDCTokenGenID implements the driver.Connection interface by setting the OIDCToken generation +// (which is always 0) +func (c *ChannelConn) SetOIDCTokenGenID(uint64) {} + // WriteWireMessage implements the driver.Connection interface. func (c *ChannelConn) Write(ctx context.Context, wm []byte) error { // Copy wm in case it came from a buffer pool. diff --git a/x/mongo/driver/mnet/connection.go b/x/mongo/driver/mnet/connection.go index 9beff5b1ee2..e02ecceadb1 100644 --- a/x/mongo/driver/mnet/connection.go +++ b/x/mongo/driver/mnet/connection.go @@ -30,6 +30,8 @@ type Describer interface { DriverConnectionID() int64 Address() address.Address Stale() bool + OIDCTokenGenID() uint64 + SetOIDCTokenGenID(uint64) } // Streamer represents a Connection that supports streaming wire protocol diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 3fc1436a2bb..8c2e3e8e9a3 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -303,6 +303,10 @@ type Operation struct { // of the operation do not contain a maxTimeMS field. OmitMaxTimeMS bool + // Authenticator is the authenticator to use for this operation when a reauthentication is + // required. + Authenticator Authenticator + // omitReadPreference is a boolean that indicates whether to omit the // read preference from the command. This omition includes the case // where a default read preference is used when the operation @@ -897,6 +901,28 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Labels = tt.Labels operationErr.Raw = tt.Raw case Error: + // 391 is the reauthentication required error code, so we will attempt a reauth and + // retry the operation, if it is successful. + if tt.Code == 391 { + if op.Authenticator != nil { + cfg := AuthConfig{ + Description: conn.Description(), + Connection: conn, + ClusterClock: op.Clock, + ServerAPI: op.ServerAPI, + } + if err := op.Authenticator.Reauth(ctx, &cfg); err != nil { + return fmt.Errorf("error reauthenticating: %w", err) + } + if op.Client != nil && op.Client.Committing { + // Apply majority write concern for retries + op.Client.UpdateCommitTransactionWriteConcern() + op.WriteConcern = op.Client.CurrentWc + } + resetForRetry(tt) + continue + } + } if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) { if err := op.Client.ClearPinnedResources(); err != nil { return err diff --git a/x/mongo/driver/operation/abort_transaction.go b/x/mongo/driver/operation/abort_transaction.go index f851282efe5..f6fd1f1ada8 100644 --- a/x/mongo/driver/operation/abort_transaction.go +++ b/x/mongo/driver/operation/abort_transaction.go @@ -21,6 +21,7 @@ import ( // AbortTransaction performs an abortTransaction operation. type AbortTransaction struct { + authenticator driver.Authenticator recoveryToken bsoncore.Document session *session.Client clock *session.ClusterClock @@ -66,6 +67,7 @@ func (at *AbortTransaction) Execute(ctx context.Context) error { WriteConcern: at.writeConcern, ServerAPI: at.serverAPI, Name: driverutil.AbortTransactionOp, + Authenticator: at.authenticator, }.Execute(ctx) } @@ -199,3 +201,13 @@ func (at *AbortTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Abort at.serverAPI = serverAPI return at } + +// Authenticator sets the authenticator to use for this operation. +func (at *AbortTransaction) Authenticator(authenticator driver.Authenticator) *AbortTransaction { + if at == nil { + at = new(AbortTransaction) + } + + at.authenticator = authenticator + return at +} diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index 69ef6e09b01..a80e8b035ec 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -24,6 +24,7 @@ import ( // Aggregate represents an aggregate operation. type Aggregate struct { + authenticator driver.Authenticator allowDiskUse *bool batchSize *int32 bypassDocumentValidation *bool @@ -110,6 +111,7 @@ func (a *Aggregate) Execute(ctx context.Context) error { IsOutputAggregate: a.hasOutputStage, Timeout: a.timeout, Name: driverutil.AggregateOp, + Authenticator: a.authenticator, }.Execute(ctx) } @@ -404,3 +406,13 @@ func (a *Aggregate) Timeout(timeout *time.Duration) *Aggregate { a.timeout = timeout return a } + +// Authenticator sets the authenticator to use for this operation. +func (a *Aggregate) Authenticator(authenticator driver.Authenticator) *Aggregate { + if a == nil { + a = new(Aggregate) + } + + a.authenticator = authenticator + return a +} diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 0b952b07cd7..776da1a0965 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -22,6 +22,7 @@ import ( // Command is used to run a generic operation. type Command struct { + authenticator driver.Authenticator command bsoncore.Document database string deployment driver.Deployment @@ -107,6 +108,7 @@ func (c *Command) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Logger: c.logger, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -219,3 +221,13 @@ func (c *Command) Logger(logger *logger.Logger) *Command { c.logger = logger return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Command) Authenticator(authenticator driver.Authenticator) *Command { + if c == nil { + c = new(Command) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index e2e4e7b20bb..572fb5607b9 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -21,6 +21,7 @@ import ( // CommitTransaction attempts to commit a transaction. type CommitTransaction struct { + authenticator driver.Authenticator recoveryToken bsoncore.Document session *session.Client clock *session.ClusterClock @@ -65,6 +66,7 @@ func (ct *CommitTransaction) Execute(ctx context.Context) error { WriteConcern: ct.writeConcern, ServerAPI: ct.serverAPI, Name: driverutil.CommitTransactionOp, + Authenticator: ct.authenticator, }.Execute(ctx) } @@ -188,3 +190,13 @@ func (ct *CommitTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Comm ct.serverAPI = serverAPI return ct } + +// Authenticator sets the authenticator to use for this operation. +func (ct *CommitTransaction) Authenticator(authenticator driver.Authenticator) *CommitTransaction { + if ct == nil { + ct = new(CommitTransaction) + } + + ct.authenticator = authenticator + return ct +} diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index c8480475f45..b3d201612ae 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -24,6 +24,7 @@ import ( // Count represents a count operation. type Count struct { + authenticator driver.Authenticator query bsoncore.Document session *session.Client clock *session.ClusterClock @@ -125,6 +126,7 @@ func (c *Count) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Name: driverutil.CountOp, + Authenticator: c.authenticator, }.Execute(ctx) // Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace @@ -298,3 +300,13 @@ func (c *Count) Timeout(timeout *time.Duration) *Count { c.timeout = timeout return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Count) Authenticator(authenticator driver.Authenticator) *Count { + if c == nil { + c = new(Count) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/create.go b/x/mongo/driver/operation/create.go index 2896e02cbfd..911ddb9b4be 100644 --- a/x/mongo/driver/operation/create.go +++ b/x/mongo/driver/operation/create.go @@ -21,6 +21,7 @@ import ( // Create represents a create operation. type Create struct { + authenticator driver.Authenticator capped *bool collation bsoncore.Document changeStreamPreAndPostImages bsoncore.Document @@ -78,6 +79,7 @@ func (c *Create) Execute(ctx context.Context) error { Selector: c.selector, WriteConcern: c.writeConcern, ServerAPI: c.serverAPI, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -400,3 +402,13 @@ func (c *Create) ClusteredIndex(ci bsoncore.Document) *Create { c.clusteredIndex = ci return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Create) Authenticator(authenticator driver.Authenticator) *Create { + if c == nil { + c = new(Create) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index 5a49a7d1a64..e878ae9c29d 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -23,20 +23,21 @@ import ( // CreateIndexes performs a createIndexes operation. type CreateIndexes struct { - commitQuorum bsoncore.Value - indexes bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result CreateIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + commitQuorum bsoncore.Value + indexes bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result CreateIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // CreateIndexesResult represents a createIndexes result returned by the server. @@ -116,6 +117,7 @@ func (ci *CreateIndexes) Execute(ctx context.Context) error { ServerAPI: ci.serverAPI, Timeout: ci.timeout, Name: driverutil.CreateIndexesOp, + Authenticator: ci.authenticator, }.Execute(ctx) } @@ -265,3 +267,13 @@ func (ci *CreateIndexes) Timeout(timeout *time.Duration) *CreateIndexes { ci.timeout = timeout return ci } + +// Authenticator sets the authenticator to use for this operation. +func (ci *CreateIndexes) Authenticator(authenticator driver.Authenticator) *CreateIndexes { + if ci == nil { + ci = new(CreateIndexes) + } + + ci.authenticator = authenticator + return ci +} diff --git a/x/mongo/driver/operation/create_search_indexes.go b/x/mongo/driver/operation/create_search_indexes.go index dac561ff858..9b3a305ba94 100644 --- a/x/mongo/driver/operation/create_search_indexes.go +++ b/x/mongo/driver/operation/create_search_indexes.go @@ -22,18 +22,19 @@ import ( // CreateSearchIndexes performs a createSearchIndexes operation. type CreateSearchIndexes struct { - indexes bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result CreateSearchIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + indexes bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result CreateSearchIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // CreateSearchIndexResult represents a single search index result in CreateSearchIndexesResult. @@ -116,6 +117,7 @@ func (csi *CreateSearchIndexes) Execute(ctx context.Context) error { Selector: csi.selector, ServerAPI: csi.serverAPI, Timeout: csi.timeout, + Authenticator: csi.authenticator, }.Execute(ctx) } @@ -237,3 +239,13 @@ func (csi *CreateSearchIndexes) Timeout(timeout *time.Duration) *CreateSearchInd csi.timeout = timeout return csi } + +// Authenticator sets the authenticator to use for this operation. +func (csi *CreateSearchIndexes) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexes { + if csi == nil { + csi = new(CreateSearchIndexes) + } + + csi.authenticator = authenticator + return csi +} diff --git a/x/mongo/driver/operation/delete.go b/x/mongo/driver/operation/delete.go index b9e21465c1d..9b9348dae11 100644 --- a/x/mongo/driver/operation/delete.go +++ b/x/mongo/driver/operation/delete.go @@ -24,25 +24,26 @@ import ( // Delete performs a delete operation type Delete struct { - comment bsoncore.Value - deletes []bsoncore.Document - ordered *bool - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - retry *driver.RetryMode - hint *bool - result DeleteResult - serverAPI *driver.ServerAPIOptions - let bsoncore.Document - timeout *time.Duration - logger *logger.Logger + authenticator driver.Authenticator + comment bsoncore.Value + deletes []bsoncore.Document + ordered *bool + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + retry *driver.RetryMode + hint *bool + result DeleteResult + serverAPI *driver.ServerAPIOptions + let bsoncore.Document + timeout *time.Duration + logger *logger.Logger } // DeleteResult represents a delete result returned by the server. @@ -115,6 +116,7 @@ func (d *Delete) Execute(ctx context.Context) error { Timeout: d.timeout, Logger: d.logger, Name: driverutil.DeleteOp, + Authenticator: d.authenticator, }.Execute(ctx) } @@ -327,3 +329,13 @@ func (d *Delete) Logger(logger *logger.Logger) *Delete { return d } + +// Authenticator sets the authenticator to use for this operation. +func (d *Delete) Authenticator(authenticator driver.Authenticator) *Delete { + if d == nil { + d = new(Delete) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index 3ace8fc2edc..1f30b05248b 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -23,6 +23,7 @@ import ( // Distinct performs a distinct operation. type Distinct struct { + authenticator driver.Authenticator collation bsoncore.Document key *string query bsoncore.Document @@ -104,6 +105,7 @@ func (d *Distinct) Execute(ctx context.Context) error { ServerAPI: d.serverAPI, Timeout: d.timeout, Name: driverutil.DistinctOp, + Authenticator: d.authenticator, }.Execute(ctx) } @@ -298,3 +300,13 @@ func (d *Distinct) Timeout(timeout *time.Duration) *Distinct { d.timeout = timeout return d } + +// Authenticator sets the authenticator to use for this operation. +func (d *Distinct) Authenticator(authenticator driver.Authenticator) *Distinct { + if d == nil { + d = new(Distinct) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/drop_collection.go b/x/mongo/driver/operation/drop_collection.go index 98cbee7bb9b..e3cb059a504 100644 --- a/x/mongo/driver/operation/drop_collection.go +++ b/x/mongo/driver/operation/drop_collection.go @@ -23,18 +23,19 @@ import ( // DropCollection performs a drop operation. type DropCollection struct { - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result DropCollectionResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result DropCollectionResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropCollectionResult represents a dropCollection result returned by the server. @@ -104,6 +105,7 @@ func (dc *DropCollection) Execute(ctx context.Context) error { ServerAPI: dc.serverAPI, Timeout: dc.timeout, Name: driverutil.DropOp, + Authenticator: dc.authenticator, }.Execute(ctx) } @@ -222,3 +224,13 @@ func (dc *DropCollection) Timeout(timeout *time.Duration) *DropCollection { dc.timeout = timeout return dc } + +// Authenticator sets the authenticator to use for this operation. +func (dc *DropCollection) Authenticator(authenticator driver.Authenticator) *DropCollection { + if dc == nil { + dc = new(DropCollection) + } + + dc.authenticator = authenticator + return dc +} diff --git a/x/mongo/driver/operation/drop_database.go b/x/mongo/driver/operation/drop_database.go index a10c02c63bc..9a724caacba 100644 --- a/x/mongo/driver/operation/drop_database.go +++ b/x/mongo/driver/operation/drop_database.go @@ -21,15 +21,16 @@ import ( // DropDatabase performs a dropDatabase operation type DropDatabase struct { - session *session.Client - clock *session.ClusterClock - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - serverAPI *driver.ServerAPIOptions + authenticator driver.Authenticator + session *session.Client + clock *session.ClusterClock + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + serverAPI *driver.ServerAPIOptions } // NewDropDatabase constructs and returns a new DropDatabase. @@ -55,6 +56,7 @@ func (dd *DropDatabase) Execute(ctx context.Context) error { WriteConcern: dd.writeConcern, ServerAPI: dd.serverAPI, Name: driverutil.DropDatabaseOp, + Authenticator: dd.authenticator, }.Execute(ctx) } @@ -154,3 +156,13 @@ func (dd *DropDatabase) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropDatab dd.serverAPI = serverAPI return dd } + +// Authenticator sets the authenticator to use for this operation. +func (dd *DropDatabase) Authenticator(authenticator driver.Authenticator) *DropDatabase { + if dd == nil { + dd = new(DropDatabase) + } + + dd.authenticator = authenticator + return dd +} diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index ed068b45001..ce5bab8aa7b 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -23,19 +23,20 @@ import ( // DropIndexes performs an dropIndexes operation. type DropIndexes struct { - index any - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result DropIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index any + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result DropIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropIndexesResult represents a dropIndexes result returned by the server. @@ -99,6 +100,7 @@ func (di *DropIndexes) Execute(ctx context.Context) error { ServerAPI: di.serverAPI, Timeout: di.timeout, Name: driverutil.DropIndexesOp, + Authenticator: di.authenticator, }.Execute(ctx) } @@ -237,3 +239,13 @@ func (di *DropIndexes) Timeout(timeout *time.Duration) *DropIndexes { di.timeout = timeout return di } + +// Authenticator sets the authenticator to use for this operation. +func (di *DropIndexes) Authenticator(authenticator driver.Authenticator) *DropIndexes { + if di == nil { + di = new(DropIndexes) + } + + di.authenticator = authenticator + return di +} diff --git a/x/mongo/driver/operation/drop_search_index.go b/x/mongo/driver/operation/drop_search_index.go index 53dd4870292..5061c620522 100644 --- a/x/mongo/driver/operation/drop_search_index.go +++ b/x/mongo/driver/operation/drop_search_index.go @@ -21,18 +21,19 @@ import ( // DropSearchIndex performs an dropSearchIndex operation. type DropSearchIndex struct { - index string - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result DropSearchIndexResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index string + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result DropSearchIndexResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropSearchIndexResult represents a dropSearchIndex result returned by the server. @@ -93,6 +94,7 @@ func (dsi *DropSearchIndex) Execute(ctx context.Context) error { Selector: dsi.selector, ServerAPI: dsi.serverAPI, Timeout: dsi.timeout, + Authenticator: dsi.authenticator, }.Execute(ctx) } @@ -212,3 +214,13 @@ func (dsi *DropSearchIndex) Timeout(timeout *time.Duration) *DropSearchIndex { dsi.timeout = timeout return dsi } + +// Authenticator sets the authenticator to use for this operation. +func (dsi *DropSearchIndex) Authenticator(authenticator driver.Authenticator) *DropSearchIndex { + if dsi == nil { + dsi = new(DropSearchIndex) + } + + dsi.authenticator = authenticator + return dsi +} diff --git a/x/mongo/driver/operation/end_sessions.go b/x/mongo/driver/operation/end_sessions.go index 43b8a1201d4..df44fb44be8 100644 --- a/x/mongo/driver/operation/end_sessions.go +++ b/x/mongo/driver/operation/end_sessions.go @@ -20,15 +20,16 @@ import ( // EndSessions performs an endSessions operation. type EndSessions struct { - sessionIDs bsoncore.Document - session *session.Client - clock *session.ClusterClock - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - serverAPI *driver.ServerAPIOptions + authenticator driver.Authenticator + sessionIDs bsoncore.Document + session *session.Client + clock *session.ClusterClock + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + serverAPI *driver.ServerAPIOptions } // NewEndSessions constructs and returns a new EndSessions. @@ -61,6 +62,7 @@ func (es *EndSessions) Execute(ctx context.Context) error { Selector: es.selector, ServerAPI: es.serverAPI, Name: driverutil.EndSessionsOp, + Authenticator: es.authenticator, }.Execute(ctx) } @@ -161,3 +163,13 @@ func (es *EndSessions) ServerAPI(serverAPI *driver.ServerAPIOptions) *EndSession es.serverAPI = serverAPI return es } + +// Authenticator sets the authenticator to use for this operation. +func (es *EndSessions) Authenticator(authenticator driver.Authenticator) *EndSessions { + if es == nil { + es = new(EndSessions) + } + + es.authenticator = authenticator + return es +} diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 468dbb610dd..803e2768c2a 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -24,6 +24,7 @@ import ( // Find performs a find operation. type Find struct { + authenticator driver.Authenticator allowDiskUse *bool allowPartialResults *bool awaitData *bool @@ -107,6 +108,7 @@ func (f *Find) Execute(ctx context.Context) error { Timeout: f.timeout, Logger: f.logger, Name: driverutil.FindOp, + Authenticator: f.authenticator, }.Execute(ctx) } @@ -547,3 +549,13 @@ func (f *Find) Logger(logger *logger.Logger) *Find { f.logger = logger return f } + +// Authenticator sets the authenticator to use for this operation. +func (f *Find) Authenticator(authenticator driver.Authenticator) *Find { + if f == nil { + f = new(Find) + } + + f.authenticator = authenticator + return f +} diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 3b3c3ab89ff..9939939f920 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -24,6 +24,7 @@ import ( // FindAndModify performs a findAndModify operation. type FindAndModify struct { + authenticator driver.Authenticator arrayFilters bsoncore.Array bypassDocumentValidation *bool collation bsoncore.Document @@ -142,6 +143,7 @@ func (fam *FindAndModify) Execute(ctx context.Context) error { ServerAPI: fam.serverAPI, Timeout: fam.timeout, Name: driverutil.FindAndModifyOp, + Authenticator: fam.authenticator, }.Execute(ctx) } @@ -464,3 +466,13 @@ func (fam *FindAndModify) Timeout(timeout *time.Duration) *FindAndModify { fam.timeout = timeout return fam } + +// Authenticator sets the authenticator to use for this operation. +func (fam *FindAndModify) Authenticator(authenticator driver.Authenticator) *FindAndModify { + if fam == nil { + fam = new(FindAndModify) + } + + fam.authenticator = authenticator + return fam +} diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index bf4bc914b0c..4e3749aef4b 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -37,6 +37,7 @@ const driverName = "mongo-go-driver" // Hello is used to run the handshake operation. type Hello struct { + authenticator driver.Authenticator appname string compressors []string saslSupportedMechs string @@ -659,8 +660,16 @@ func (h *Hello) OmitMaxTimeMS(val bool) *Hello { if h == nil { h = new(Hello) } - h.omitMaxTimeMS = val + return h +} + +// Authenticator sets the authenticator to use for this operation. +func (h *Hello) Authenticator(authenticator driver.Authenticator) *Hello { + if h == nil { + h = new(Hello) + } + h.authenticator = authenticator return h } diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go index 5909767f130..8cd338c03b4 100644 --- a/x/mongo/driver/operation/insert.go +++ b/x/mongo/driver/operation/insert.go @@ -24,6 +24,7 @@ import ( // Insert performs an insert operation. type Insert struct { + authenticator driver.Authenticator bypassDocumentValidation *bool comment bsoncore.Value documents []bsoncore.Document @@ -114,6 +115,7 @@ func (i *Insert) Execute(ctx context.Context) error { Timeout: i.timeout, Logger: i.logger, Name: driverutil.InsertOp, + Authenticator: i.authenticator, }.Execute(ctx) } @@ -307,3 +309,13 @@ func (i *Insert) Logger(logger *logger.Logger) *Insert { i.logger = logger return i } + +// Authenticator sets the authenticator to use for this operation. +func (i *Insert) Authenticator(authenticator driver.Authenticator) *Insert { + if i == nil { + i = new(Insert) + } + + i.authenticator = authenticator + return i +} diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/listDatabases.go index b6b070510fa..db1ba560bd5 100644 --- a/x/mongo/driver/operation/listDatabases.go +++ b/x/mongo/driver/operation/listDatabases.go @@ -24,6 +24,7 @@ import ( // ListDatabases performs a listDatabases operation. type ListDatabases struct { + authenticator driver.Authenticator filter bsoncore.Document authorizedDatabases *bool nameOnly *bool @@ -165,6 +166,7 @@ func (ld *ListDatabases) Execute(ctx context.Context) error { ServerAPI: ld.serverAPI, Timeout: ld.timeout, Name: driverutil.ListDatabasesOp, + Authenticator: ld.authenticator, }.Execute(ctx) } @@ -327,3 +329,13 @@ func (ld *ListDatabases) Timeout(timeout *time.Duration) *ListDatabases { ld.timeout = timeout return ld } + +// Authenticator sets the authenticator to use for this operation. +func (ld *ListDatabases) Authenticator(authenticator driver.Authenticator) *ListDatabases { + if ld == nil { + ld = new(ListDatabases) + } + + ld.authenticator = authenticator + return ld +} diff --git a/x/mongo/driver/operation/list_collections.go b/x/mongo/driver/operation/list_collections.go index 9bf8db86302..b746251dbc5 100644 --- a/x/mongo/driver/operation/list_collections.go +++ b/x/mongo/driver/operation/list_collections.go @@ -22,6 +22,7 @@ import ( // ListCollections performs a listCollections operation. type ListCollections struct { + authenticator driver.Authenticator filter bsoncore.Document nameOnly *bool authorizedCollections *bool @@ -83,6 +84,7 @@ func (lc *ListCollections) Execute(ctx context.Context) error { ServerAPI: lc.serverAPI, Timeout: lc.timeout, Name: driverutil.ListCollectionsOp, + Authenticator: lc.authenticator, }.Execute(ctx) } @@ -259,3 +261,13 @@ func (lc *ListCollections) Timeout(timeout *time.Duration) *ListCollections { lc.timeout = timeout return lc } + +// Authenticator sets the authenticator to use for this operation. +func (lc *ListCollections) Authenticator(authenticator driver.Authenticator) *ListCollections { + if lc == nil { + lc = new(ListCollections) + } + + lc.authenticator = authenticator + return lc +} diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index 57d90e46f9f..df5a90acf8d 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -21,18 +21,19 @@ import ( // ListIndexes performs a listIndexes operation. type ListIndexes struct { - batchSize *int32 - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - database string - deployment driver.Deployment - selector description.ServerSelector - retry *driver.RetryMode - crypt driver.Crypt - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + batchSize *int32 + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + database string + deployment driver.Deployment + selector description.ServerSelector + retry *driver.RetryMode + crypt driver.Crypt + serverAPI *driver.ServerAPIOptions + timeout *time.Duration result driver.CursorResponse } @@ -83,6 +84,7 @@ func (li *ListIndexes) Execute(ctx context.Context) error { ServerAPI: li.serverAPI, Timeout: li.timeout, Name: driverutil.ListIndexesOp, + Authenticator: li.authenticator, }.Execute(ctx) } @@ -221,3 +223,13 @@ func (li *ListIndexes) Timeout(timeout *time.Duration) *ListIndexes { li.timeout = timeout return li } + +// Authenticator sets the authenticator to use for this operation. +func (li *ListIndexes) Authenticator(authenticator driver.Authenticator) *ListIndexes { + if li == nil { + li = new(ListIndexes) + } + + li.authenticator = authenticator + return li +} diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go index a47de018173..612a5c1f9d7 100644 --- a/x/mongo/driver/operation/update.go +++ b/x/mongo/driver/operation/update.go @@ -25,6 +25,7 @@ import ( // Update performs an update operation. type Update struct { + authenticator driver.Authenticator bypassDocumentValidation *bool comment bsoncore.Value ordered *bool @@ -166,6 +167,7 @@ func (u *Update) Execute(ctx context.Context) error { Timeout: u.timeout, Logger: u.logger, Name: driverutil.UpdateOp, + Authenticator: u.authenticator, }.Execute(ctx) } @@ -413,3 +415,13 @@ func (u *Update) Logger(logger *logger.Logger) *Update { u.logger = logger return u } + +// Authenticator sets the authenticator to use for this operation. +func (u *Update) Authenticator(authenticator driver.Authenticator) *Update { + if u == nil { + u = new(Update) + } + + u.authenticator = authenticator + return u +} diff --git a/x/mongo/driver/operation/update_search_index.go b/x/mongo/driver/operation/update_search_index.go index a1539a6e483..979bb3bc795 100644 --- a/x/mongo/driver/operation/update_search_index.go +++ b/x/mongo/driver/operation/update_search_index.go @@ -21,19 +21,20 @@ import ( // UpdateSearchIndex performs a updateSearchIndex operation. type UpdateSearchIndex struct { - index string - definition bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result UpdateSearchIndexResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index string + definition bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result UpdateSearchIndexResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // UpdateSearchIndexResult represents a single index in the updateSearchIndexResult result. @@ -95,6 +96,7 @@ func (usi *UpdateSearchIndex) Execute(ctx context.Context) error { Selector: usi.selector, ServerAPI: usi.serverAPI, Timeout: usi.timeout, + Authenticator: usi.authenticator, }.Execute(ctx) } @@ -225,3 +227,13 @@ func (usi *UpdateSearchIndex) Timeout(timeout *time.Duration) *UpdateSearchIndex usi.timeout = timeout return usi } + +// Authenticator sets the authenticator to use for this operation. +func (usi *UpdateSearchIndex) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndex { + if usi == nil { + usi = new(UpdateSearchIndex) + } + + usi.authenticator = authenticator + return usi +} diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 553d6e2fdf6..1b4a89b80ad 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -735,6 +735,8 @@ func (m *mockConnection) SupportsStreaming() bool { return m.rCanStream func (m *mockConnection) CurrentlyStreaming() bool { return m.rStreaming } func (m *mockConnection) SetStreaming(streaming bool) { m.rStreaming = streaming } func (m *mockConnection) Stale() bool { return false } +func (m *mockConnection) OIDCTokenGenID() uint64 { return 0 } +func (m *mockConnection) SetOIDCTokenGenID(uint64) {} func (m *mockConnection) DriverConnectionID() int64 { return 0 } diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index a7583a5c155..4455516cae0 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -76,6 +76,9 @@ type connection struct { driverConnectionID int64 generation uint64 + // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate + // accessTokens in the OIDC authenticator cache. + oidcTokenGenID uint64 } // newConnection handles the creation of a connection. It does not connect the connection. @@ -558,6 +561,8 @@ type Connection struct { refCount int cleanupPoolFn func() + oidcTokenGenID uint64 + // cleanupServerFn resets the server state when a connection is returned to the connection pool // via Close() or expired via Expire(). cleanupServerFn func() @@ -812,3 +817,21 @@ func configureTLS(ctx context.Context, } return client, nil } + +// OIDCTokenGenID returns the OIDC token generation ID. +func (c *Connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +// SetOIDCTokenGenID sets the OIDC token generation ID. +func (c *Connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID +} + +func (c *connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +func (c *connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID +} diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index ede32601fc9..5be0731a74d 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -96,6 +96,30 @@ func NewConfig(opts *options.ClientOptionsBuilder, clock *session.ClusterClock) // config for building non-default deployments. Server and topology options are // not honored if a custom deployment is used. func NewConfigFromOptions(opts *options.ClientOptions, clock *session.ClusterClock) (*Config, error) { + // Auth & Database & Password & Username + if opts.Auth != nil { + cred := &auth.Cred{ + Username: opts.Auth.Username, + Password: opts.Auth.Password, + PasswordSet: opts.Auth.PasswordSet, + Props: opts.Auth.AuthMechanismProperties, + Source: opts.Auth.AuthSource, + } + mechanism := opts.Auth.AuthMechanism + authenticator, err := auth.CreateAuthenticator(mechanism, cred, opts.HTTPClient) + if err != nil { + return nil, err + } + return NewConfigFromOptionsWithAuthenticator(opts, clock, authenticator) + } + return NewConfigFromOptionsWithAuthenticator(opts, clock, nil) +} + +// NewConfigFromOptionsWithAuthenticator will translate data from client options into a topology config for building non-default deployments. +// Server and topology options are not honored if a custom deployment is used. It uses a passed in +// authenticator to authenticate the connection. +func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *session.ClusterClock, authenticator driver.Authenticator) (*Config, error) { + var serverAPI *driver.ServerAPIOptions clientOptsBldr := options.ClientOptionsBuilder{ @@ -217,11 +241,6 @@ func NewConfigFromOptions(opts *options.ClientOptions, clock *session.ClusterClo } } - authenticator, err := auth.CreateAuthenticator(mechanism, cred) - if err != nil { - return nil, err - } - handshakeOpts := &auth.HandshakeOptions{ AppName: appName, Authenticator: authenticator, @@ -229,7 +248,6 @@ func NewConfigFromOptions(opts *options.ClientOptions, clock *session.ClusterClo ServerAPI: serverAPI, LoadBalanced: loadBalanced, ClusterClock: clock, - HTTPClient: opts.HTTPClient, } if mechanism == "" {