diff --git a/docstore/mongodocstore/urls.go b/docstore/mongodocstore/urls.go index 3e7ce282e2..b6d2ed14d4 100644 --- a/docstore/mongodocstore/urls.go +++ b/docstore/mongodocstore/urls.go @@ -34,27 +34,31 @@ func init() { // defaultDialer dials a default Mongo server based on the environment variable // MONGO_SERVER_URL. type defaultDialer struct { - init sync.Once - opener *URLOpener - err error + mongoServerURL string + mu sync.Mutex + opener *URLOpener + err error } func (o *defaultDialer) OpenCollectionURL(ctx context.Context, u *url.URL) (*docstore.Collection, error) { - o.init.Do(func() { - serverURL := os.Getenv("MONGO_SERVER_URL") - if serverURL == "" { - o.err = errors.New("MONGO_SERVER_URL environment variable is not set") - return - } - client, err := Dial(ctx, serverURL) + o.mu.Lock() + defer o.mu.Unlock() + currentEnv := os.Getenv("MONGO_SERVER_URL") + + if currentEnv == "" { + o.err = errors.New("MONGO_SERVER_URL environment variable is not set") + return nil, fmt.Errorf("open collection %s: %v", u, o.err) + } + + // If MONGO_SERVER_URL has been updated, then update o.opener as well + if currentEnv != o.mongoServerURL { + client, err := Dial(ctx, currentEnv) if err != nil { - o.err = fmt.Errorf("failed to dial default Mongo server at %q: %v", serverURL, err) - return + o.err = fmt.Errorf("failed to dial default Mongo server at %q: %v", currentEnv, err) + return nil, fmt.Errorf("open collection %s: %v", u, o.err) } + o.mongoServerURL = currentEnv o.opener = &URLOpener{Client: client} - }) - if o.err != nil { - return nil, fmt.Errorf("open collection %s: %v", u, o.err) } return o.opener.OpenCollectionURL(ctx, u) } diff --git a/docstore/mongodocstore/urls_test.go b/docstore/mongodocstore/urls_test.go index 3c3bb55a7f..61d79490c5 100644 --- a/docstore/mongodocstore/urls_test.go +++ b/docstore/mongodocstore/urls_test.go @@ -16,6 +16,7 @@ package mongodocstore import ( "context" + "net/url" "os" "testing" @@ -63,3 +64,73 @@ func TestOpenCollectionURL(t *testing.T) { } } } + +func TestDefaultDialerOpenCollectionURL(t *testing.T) { + // Defer cleanup + oldURLVal := os.Getenv("MONGO_SERVER_URL") + defer os.Setenv("MONGO_SERVER_URL", oldURLVal) + + tests := []struct { + name string + currentMongoServerURL string + currentWantErr bool + newMongoServerURL string + newWantErr bool + }{ + { + name: "fail when MONGO_SERVER_URL is empty / unset", + currentMongoServerURL: "", + currentWantErr: true, + newMongoServerURL: "", + newWantErr: true, + }, + { + name: "fail when updated MONGO_SERVER_URL is empty / unset", + currentMongoServerURL: "mongodb://localhost", + currentWantErr: false, + newMongoServerURL: "", + newWantErr: true, + }, + { + name: "pass when MONGO_SERVER_URL is updated to new value", + currentMongoServerURL: "mongodb://localhost", + currentWantErr: false, + newMongoServerURL: "mongodb://localhost:27017", + newWantErr: false, + }, + } + + // Set starting conditions + d := new(defaultDialer) + ctx := context.Background() + mongoURLString := "mongo://mydb/mycollection" + u, err := url.Parse(mongoURLString) + if err != nil { + t.Error(err) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Set MONGO_SERVER_URL + os.Setenv("MONGO_SERVER_URL", test.currentMongoServerURL) + _, err = d.OpenCollectionURL(ctx, u) + if err != nil && !test.currentWantErr { + t.Error(err) + } + + // Update MONGO_SERVER_URL + os.Setenv("MONGO_SERVER_URL", test.newMongoServerURL) + _, err = d.OpenCollectionURL(ctx, u) + if err != nil && !test.newWantErr { + t.Error(err) + } + + // Check if the MONGO_SERVER_URL was updated after rotation + if !test.newWantErr { + if d.mongoServerURL != test.newMongoServerURL { + t.Errorf("expected updated MONGO_SERVER_URL to be set to: %s, but got: %s", test.newMongoServerURL, d.mongoServerURL) + } + } + }) + } +}