Skip to content

Commit

Permalink
GODRIVER-2432 Improve panic handling in background processes (#1471)
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez authored Nov 16, 2023
1 parent 3f6e80a commit b12ee6d
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 11 deletions.
9 changes: 2 additions & 7 deletions x/mongo/driver/topology/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,9 +549,7 @@ func (s *Server) update() {
checkNow := s.checkNow
done := s.done

defer func() {
_ = recover()
}()
defer logUnexpectedFailure(s.cfg.logger, "Encountered unexpected failure updating server")

closeServer := func() {
s.subLock.Lock()
Expand Down Expand Up @@ -683,10 +681,7 @@ func (s *Server) updateDescription(desc description.Server) {
return
}

defer func() {
// ¯\_(ツ)_/¯
_ = recover()
}()
defer logUnexpectedFailure(s.cfg.logger, "Encountered unexpected failure updating server description")

// Anytime we update the server description to something other than "unknown", set the pool to
// "ready". Do this before updating the description so that connections can be checked out as
Expand Down
33 changes: 29 additions & 4 deletions x/mongo/driver/topology/topology.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,32 @@ func logServerSelectionFailed(
logger.KeyFailure, err.Error())
}

// logUnexpectedFailure is a defer-recover function for logging unexpected
// failures encountered while maintaining a topology.
//
// Most topology maintenance actions, such as updating a server, should not take
// down a client's application. This function provides a best-effort to log
// unexpected failures. If the logger passed to this function is nil, then the
// recovery will be silent.
func logUnexpectedFailure(log *logger.Logger, msg string, callbacks ...func()) {
r := recover()
if r == nil {
return
}

defer func() {
for _, clbk := range callbacks {
clbk()
}
}()

if log == nil {
return
}

log.Print(logger.LevelInfo, logger.ComponentTopology, fmt.Sprintf("%s: %v", msg, r))
}

// Connect initializes a Topology and starts the monitoring process. This function
// must be called to properly monitor the topology.
func (t *Topology) Connect() error {
Expand Down Expand Up @@ -768,12 +794,11 @@ func (t *Topology) pollSRVRecords(hosts string) {
defer pollTicker.Stop()
t.pollHeartbeatTime.Store(false)
var doneOnce bool
defer func() {
// ¯\_(ツ)_/¯
if r := recover(); r != nil && !doneOnce {
defer logUnexpectedFailure(t.cfg.logger, "Encountered unexpected failure polling SRV records", func() {
if !doneOnce {
<-t.pollingDone
}
}()
})

for {
select {
Expand Down
81 changes: 81 additions & 0 deletions x/mongo/driver/topology/topology_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package topology

import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
Expand All @@ -19,6 +21,7 @@ import (

"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/logger"
"go.mongodb.org/mongo-driver/internal/require"
"go.mongodb.org/mongo-driver/internal/spectest"
"go.mongodb.org/mongo-driver/mongo/address"
Expand Down Expand Up @@ -1195,3 +1198,81 @@ func BenchmarkSelectServerFromDescription(b *testing.B) {
})
}
}

func TestLogUnexpectedFailure(t *testing.T) {
t.Parallel()

// newIOLogger will log data using an io sink.
newIOLogger := func() (*logger.Logger, *bytes.Buffer, *bufio.Writer) {
buf := bytes.NewBuffer(nil)
w := bufio.NewWriter(buf)

ioSink := logger.NewIOSink(w)

ioLogger, err := logger.New(ioSink, logger.DefaultMaxDocumentLength, map[logger.Component]logger.Level{
logger.ComponentTopology: logger.LevelDebug,
})

assert.NoError(t, err)

return ioLogger, buf, w
}

// newNilLogger will return a nil logger with empty buffer and writer.
newNilLogger := func() (*logger.Logger, *bytes.Buffer, *bufio.Writer) {
return nil, &bytes.Buffer{}, &bufio.Writer{}
}

tests := []struct {
name string
msg string
newLogger func() (*logger.Logger, *bytes.Buffer, *bufio.Writer)
panicValue interface{}
want interface{} // Either a string or nil
}{
{
name: "nil logger",
msg: "",
newLogger: newNilLogger,
panicValue: 1,
want: nil,
},
{
name: "valid logger",
msg: "test",
newLogger: newIOLogger,
panicValue: 1,
want: "test: 1",
},
{
name: "valid logger with error panic",
msg: "test",
newLogger: newIOLogger,
panicValue: errors.New("err"),
want: "test: err",
},
}

for _, test := range tests {
test := test

t.Run(test.name, func(t *testing.T) {
t.Parallel()

log, buf, w := test.newLogger()

func() {
defer logUnexpectedFailure(log, test.msg)

panic(test.panicValue)
}()

assert.NoError(t, w.Flush())

got := map[string]interface{}{}
_ = json.Unmarshal(buf.Bytes(), &got)

assert.Equal(t, test.want, got[logger.KeyMessage])
})
}
}

0 comments on commit b12ee6d

Please sign in to comment.