37 changes: 35 additions & 2 deletions balancer_wrapper.go
@@ -282,6 +282,10 @@ type acBalancerWrapper struct {
// dropped or updated. This is required as closures can't be compared for
// equality.
healthData *healthData

shutdownMu sync.Mutex
shutdownCh chan struct{}
activeGofuncs sync.WaitGroup
}

// healthData holds data related to health state reporting.
@@ -347,16 +351,45 @@ func (acbw *acBalancerWrapper) String() string {
}

func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
acbw.ac.updateAddrs(addrs)
acbw.goFunc(func(shutdown <-chan struct{}) {
Member:

Is there a reason we changed this to run the acbw.ac.updateAddrs(addrs) function in a goroutine?

Author:

It's basically a "bubbled-up" goroutine. Previously, the goroutine was spawned in updateAddrs itself (line 1021). But since we now need to track those goroutines, I figured it would be most appropriate to do it here. Another option would be to push this down into updateAddrs itself, by passing the acBalancerWrapper pointer or a function pointer to acbw.goFunc, or something along those lines, and then using that to spawn the goroutine there:

Suggested change
acbw.goFunc(func(shutdown <-chan struct{}) {
acbw.ac.updateAddrs(acbw, addrs)

Then we could write line 1021 of updateAddrs like so:

	acbw.goFunc(ac.resetTransportAndUnlock)

acbw.ac.updateAddrs(shutdown, addrs)
})
}

func (acbw *acBalancerWrapper) Connect() {
go acbw.ac.connect()
acbw.goFunc(acbw.ac.connect)
}

func (acbw *acBalancerWrapper) goFunc(fn func(shutdown <-chan struct{})) {
acbw.shutdownMu.Lock()
defer acbw.shutdownMu.Unlock()

shutdown := acbw.shutdownCh
if shutdown == nil {
shutdown = make(chan struct{})
acbw.shutdownCh = shutdown
}

acbw.activeGofuncs.Add(1)
go func() {
defer acbw.activeGofuncs.Done()
fn(shutdown)
}()
}

func (acbw *acBalancerWrapper) Shutdown() {
acbw.closeProducers()
acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)

acbw.shutdownMu.Lock()
defer acbw.shutdownMu.Unlock()

shutdown := acbw.shutdownCh
acbw.shutdownCh = nil
if shutdown != nil {
close(shutdown)
acbw.activeGofuncs.Wait()
}
}
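
For readers skimming the diff, the goFunc/Shutdown pair above is a reusable lifecycle idiom: a lazily created shutdown channel shared by every spawned goroutine, plus a WaitGroup that Shutdown drains after closing the channel. Here is a minimal, self-contained sketch of the same idiom (tracker and its method names are illustrative, not part of this PR):

package main

import (
	"fmt"
	"sync"
	"time"
)

// tracker mirrors the goFunc/Shutdown idiom: each goroutine receives a
// shutdown channel to select on, and Shutdown closes that channel and
// then blocks until every tracked goroutine has returned.
type tracker struct {
	mu     sync.Mutex
	doneCh chan struct{}
	wg     sync.WaitGroup
}

func (t *tracker) Go(fn func(shutdown <-chan struct{})) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.doneCh == nil {
		t.doneCh = make(chan struct{}) // created lazily, as in goFunc
	}
	shutdown := t.doneCh
	t.wg.Add(1)
	go func() {
		defer t.wg.Done()
		fn(shutdown)
	}()
}

func (t *tracker) Shutdown() {
	t.mu.Lock()
	defer t.mu.Unlock() // held during Wait, so concurrent Go calls park until the drain finishes
	if t.doneCh == nil {
		return // never started, or already shut down
	}
	close(t.doneCh)
	t.doneCh = nil
	t.wg.Wait()
}

func main() {
	var t tracker
	t.Go(func(shutdown <-chan struct{}) {
		select {
		case <-shutdown:
			fmt.Println("aborted by Shutdown")
		case <-time.After(time.Second):
			fmt.Println("ran to completion")
		}
	})
	t.Shutdown()
}

One caveat the sketch shares with the real code: Shutdown holds the mutex while waiting, so a tracked goroutine that itself called Go during shutdown would deadlock.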

// NewStream begins a streaming RPC on the addrConn. If the addrConn is not
75 changes: 48 additions & 27 deletions clientconn.go
@@ -925,25 +925,24 @@ func (cc *ClientConn) incrCallsFailed() {
// connect starts creating a transport.
// It does nothing if the ac is not IDLE.
// TODO(bar) Move this to the addrConn section.
func (ac *addrConn) connect() error {
func (ac *addrConn) connect(abort <-chan struct{}) {
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
if logger.V(2) {
logger.Infof("connect called on shutdown addrConn; ignoring.")
}
ac.mu.Unlock()
return errConnClosing
return
}
if ac.state != connectivity.Idle {
if logger.V(2) {
logger.Infof("connect called on addrConn in non-idle state (%v); ignoring.", ac.state)
}
ac.mu.Unlock()
return nil
return
}

ac.resetTransportAndUnlock()
return nil
ac.resetTransportAndUnlock(abort)
}

// equalAddressIgnoringBalAttributes returns true if a and b are considered equal.
@@ -962,7 +961,7 @@ func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {

// updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
func (ac *addrConn) updateAddrs(abort <-chan struct{}, addrs []resolver.Address) {
addrs = copyAddresses(addrs)
limit := len(addrs)
if limit > 5 {
@@ -1018,7 +1017,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {

// Since we were connecting/connected, we should start a new connection
// attempt.
go ac.resetTransportAndUnlock()
ac.resetTransportAndUnlock(abort)
}

// getServerName determines the serverName to be used in the connection
@@ -1249,9 +1248,17 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
// resetTransportAndUnlock unconditionally connects the addrConn.
//
// ac.mu must be held by the caller, and this function will guarantee it is released.
func (ac *addrConn) resetTransportAndUnlock() {
acCtx := ac.ctx
if acCtx.Err() != nil {
func (ac *addrConn) resetTransportAndUnlock(abort <-chan struct{}) {
ctx, cancel := context.WithCancel(ac.ctx)
go func() {
select {
case <-abort:
cancel()
case <-ctx.Done():
}
}()

if ctx.Err() != nil {
ac.mu.Unlock()
return
}
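
The watcher goroutine added at the top of resetTransportAndUnlock bridges a plain abort channel, which cannot cancel a context by itself, to a derived context that it can cancel. A hedged sketch of that idiom in isolation (bridgeAbort is an invented helper name):

package main

import (
	"context"
	"fmt"
)

// bridgeAbort derives a context that ends either when the parent ends or
// when abort is closed. Selecting on ctx.Done() too guarantees the
// watcher goroutine exits even if abort is never closed, because
// cancel() fires ctx.Done() as well.
func bridgeAbort(parent context.Context, abort <-chan struct{}) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(parent)
	go func() {
		select {
		case <-abort:
			cancel()
		case <-ctx.Done():
		}
	}()
	return ctx, cancel
}

func main() {
	abort := make(chan struct{})
	ctx, cancel := bridgeAbort(context.Background(), abort)
	defer cancel()
	close(abort) // simulate acBalancerWrapper.Shutdown closing the channel
	<-ctx.Done()
	fmt.Println(ctx.Err()) // context canceled
}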
@@ -1279,12 +1286,12 @@
ac.updateConnectivityState(connectivity.Connecting, nil)
ac.mu.Unlock()

if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
if err := ac.tryAllAddrs(ctx, addrs, connectDeadline); err != nil {
// TODO: #7534 - Move re-resolution requests into the pick_first LB policy
// to ensure one resolution request per pass instead of per subconn failure.
ac.cc.resolveNow(resolver.ResolveNowOptions{})
ac.mu.Lock()
if acCtx.Err() != nil {
if ctx.Err() != nil {
// addrConn was torn down.
ac.mu.Unlock()
return
@@ -1305,13 +1312,13 @@
ac.mu.Unlock()
case <-b:
timer.Stop()
case <-acCtx.Done():
case <-ctx.Done():
timer.Stop()
return
}

ac.mu.Lock()
if acCtx.Err() == nil {
if ctx.Err() == nil {
ac.updateConnectivityState(connectivity.Idle, err)
}
ac.mu.Unlock()
@@ -1366,6 +1373,9 @@ func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, c
// new transport.
func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
addr.ServerName = ac.cc.getServerName(addr)

var healthCheckStarted atomic.Bool
healthCheckDone := make(chan struct{})
hctx, hcancel := context.WithCancel(ctx)

onClose := func(r transport.GoAwayReason) {
@@ -1394,6 +1404,9 @@
// Always go idle and wait for the LB policy to initiate a new
// connection attempt.
ac.updateConnectivityState(connectivity.Idle, nil)
if healthCheckStarted.Load() {
<-healthCheckDone
}
}

connectCtx, cancel := context.WithDeadline(ctx, connectDeadline)
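
The healthCheckStarted flag and healthCheckDone channel added here form a handshake with startHealthCheck, further down in this diff: the latter closes the channel when its goroutine exits and reports whether it started one, so onClose only blocks on a goroutine that actually exists. A condensed sketch of that handshake, with invented names:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// startOptionalTask may or may not launch a background goroutine. It
// returns true only if it did, and closes done when the goroutine exits,
// mirroring the new startHealthCheck contract.
func startOptionalTask(enabled bool, done chan<- struct{}) bool {
	if !enabled {
		return false // nothing started; done will never be closed
	}
	go func() {
		defer close(done)
		time.Sleep(10 * time.Millisecond) // stand-in for the health-check stream
	}()
	return true
}

func main() {
	var started atomic.Bool
	done := make(chan struct{})
	started.Store(startOptionalTask(true, done))

	// Teardown path, as in onClose: wait only if the task was started;
	// an unconditional <-done would block forever when it wasn't.
	if started.Load() {
		<-done
	}
	fmt.Println("clean teardown")
}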
@@ -1406,29 +1419,35 @@
logger.Infof("Creating new client transport to %q: %v", addr, err)
}
// newTr is either nil, or closed.
hcancel()
channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
return err
}

ac.mu.Lock()
defer ac.mu.Unlock()
acMu := &ac.mu
acMu.Lock()
Member:

Is there a reason the original code did not work? This seems unnecessarily complicated and does the same thing. Or am I missing something?

Author:

It's a way to make the defer do the unlock conditionally, as the code might need to unlock before returning (see lines 1441 and 1442). We could achieve the same with, e.g., a boolean variable, if you prefer.

I think using a pointer for this is good because, if you forget the nil check, it will panic with an easy-to-understand stack trace, whereas if you forget to check a boolean, you'd do a double unlock, and there's a chance you'd get weirder problems.
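
A stripped-down illustration of the idiom under discussion (process and its lock usage are hypothetical, not code from this PR):

package main

import (
	"fmt"
	"sync"
)

var mu sync.Mutex

// process shows the conditional-unlock idiom: the deferred closure
// unlocks through a pointer, and nilling the pointer after an early
// manual unlock disarms the defer, preventing a double unlock.
func process(unlockEarly bool) {
	m := &mu
	m.Lock()
	defer func() {
		if m != nil {
			m.Unlock()
		}
	}()

	if unlockEarly {
		// Some call on this path needs the lock released before returning.
		m.Unlock()
		m = nil // disarm the deferred unlock
		fmt.Println("unlocked early")
		return
	}
	fmt.Println("unlocked by the defer")
}

func main() {
	process(true)
	process(false)
}

As the reply notes, omitting the nil check would fail loudly with a nil-pointer panic, while a forgotten boolean flag would instead reach a double unlock.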

defer func() {
if acMu != nil {
acMu.Unlock()
}
}()
if ctx.Err() != nil {
// This can happen if the subConn was removed while in `Connecting`
// state. tearDown() would have set the state to `Shutdown`, but
// would not have closed the transport since ac.transport would not
// have been set at that point.
//
// We run this in a goroutine because newTr.Close() calls onClose()

// We unlock ac.mu because newTr.Close() calls onClose()
// inline, which requires locking ac.mu.
//
acMu.Unlock()
acMu = nil

// The error we pass to Close() is immaterial since there are no open
// streams at this point, so no trailers with error details will be sent
// out. We just need to pass a non-nil error.
//
// This can also happen when updateAddrs is called during a connection
// attempt.
go newTr.Close(transport.ErrConnClosing)
newTr.Close(transport.ErrConnClosing)
return nil
}
if hctx.Err() != nil {
@@ -1440,7 +1459,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
}
ac.curAddr = addr
ac.transport = newTr
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
healthCheckStarted.Store(ac.startHealthCheck(hctx, healthCheckDone)) // Will set state to READY if appropriate.
return nil
}

@@ -1456,7 +1475,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
// It sets addrConn to READY if the health checking stream is not started.
//
// Caller must hold ac.mu.
func (ac *addrConn) startHealthCheck(ctx context.Context) {
func (ac *addrConn) startHealthCheck(ctx context.Context, done chan<- struct{}) bool {
var healthcheckManagingState bool
defer func() {
if !healthcheckManagingState {
@@ -1465,22 +1484,22 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
}()

if ac.cc.dopts.disableHealthCheck {
return
return false
}
healthCheckConfig := ac.cc.healthCheckConfig()
if healthCheckConfig == nil {
return
return false
}
if !ac.scopts.HealthCheckEnabled {
return
return false
}
healthCheckFunc := internal.HealthCheckFunc
if healthCheckFunc == nil {
// The health package is not imported to set health check function.
//
// TODO: add a link to the health check doc in the error message.
channelz.Error(logger, ac.channelz, "Health check is requested but health check function is not set.")
return
return false
}

healthcheckManagingState = true
@@ -1506,6 +1525,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
}
// Start the health checking stream.
go func() {
defer close(done)
err := healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
if err != nil {
if status.Code(err) == codes.Unimplemented {
@@ -1515,6 +1535,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
}
}
}()
return true
}

func (ac *addrConn) resetConnectBackoff() {
5 changes: 5 additions & 0 deletions internal/balancer/gracefulswitch/gracefulswitch.go
@@ -67,6 +67,8 @@ type Balancer struct {
// balancerCurrent before the UpdateSubConnState is called on the
// balancerCurrent.
currentMu sync.Mutex

pendingSwaps sync.WaitGroup
}

// swap swaps out the current lb with the pending lb and updates the ClientConn.
@@ -76,7 +78,9 @@ func (gsb *Balancer) swap() {
cur := gsb.balancerCurrent
gsb.balancerCurrent = gsb.balancerPending
gsb.balancerPending = nil
gsb.pendingSwaps.Add(1)
go func() {
defer gsb.pendingSwaps.Done()
gsb.currentMu.Lock()
defer gsb.currentMu.Unlock()
cur.Close()
@@ -274,6 +278,7 @@ func (gsb *Balancer) Close() {

currentBalancerToClose.Close()
pendingBalancerToClose.Close()
gsb.pendingSwaps.Wait()
}
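
The new pendingSwaps WaitGroup gives Close a synchronous guarantee: once it returns, no swapped-out balancer is still closing in the background. A tiny sketch of that guarantee, with invented names:

package main

import (
	"fmt"
	"sync"
)

type switcher struct{ pending sync.WaitGroup }

// swapOut closes old state asynchronously but records the goroutine,
// mirroring the pendingSwaps.Add/Done calls added in swap().
func (s *switcher) swapOut(name string) {
	s.pending.Add(1)
	go func() {
		defer s.pending.Done()
		fmt.Println("closed", name) // stand-in for cur.Close()
	}()
}

// Close drains the in-flight swaps, as gsb.Close now waits on
// pendingSwaps; without the Wait, the program could exit before the
// background closes ran.
func (s *switcher) Close() {
	s.pending.Wait()
}

func main() {
	var s switcher
	s.swapOut("balancer-A")
	s.swapOut("balancer-B")
	s.Close()
	fmt.Println("fully closed")
}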

// balancerWrapper wraps a balancer.Balancer, and overrides some Balancer
10 changes: 10 additions & 0 deletions internal/testutils/pipe_listener.go
@@ -20,6 +20,7 @@
package testutils

import (
"context"
"errors"
"net"
"time"
@@ -81,11 +82,20 @@ func (p *PipeListener) Addr() net.Addr {
// Dialer dials a connection.
func (p *PipeListener) Dialer() func(string, time.Duration) (net.Conn, error) {
return func(string, time.Duration) (net.Conn, error) {
return p.ContextDialer()(context.Background(), "")
}
}

// ContextDialer dials using a context.
func (p *PipeListener) ContextDialer() func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, _ string) (net.Conn, error) {
connChan := make(chan net.Conn)
select {
case p.c <- connChan:
case <-p.done:
return nil, errClosed
case <-ctx.Done():
return nil, context.Cause(ctx)
}
conn, ok := <-connChan
if !ok {
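The hunk above is cut off by the diff view, but the shape of the new ContextDialer is visible: the dial blocks until the listener side accepts the handoff channel, the listener is closed, or the caller's context ends. A rough sketch of that select-based handoff, assuming simplified channel plumbing rather than the real PipeListener internals:

package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"
)

var errClosed = errors.New("pipe listener closed")

// dial offers a fresh channel to the accepting side, then waits for the
// connection, giving up if the listener closes or the context ends.
func dial(ctx context.Context, accept chan chan net.Conn, done <-chan struct{}) (net.Conn, error) {
	connCh := make(chan net.Conn)
	select {
	case accept <- connCh:
	case <-done:
		return nil, errClosed
	case <-ctx.Done():
		return nil, context.Cause(ctx) // report why the context ended
	}
	conn, ok := <-connCh
	if !ok {
		return nil, errClosed
	}
	return conn, nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	// With no accepting side, the dial fails once the context expires.
	_, err := dial(ctx, make(chan chan net.Conn), make(chan struct{}))
	fmt.Println(err) // context deadline exceeded
}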
2 changes: 1 addition & 1 deletion test/clientconn_state_transition_test.go
@@ -166,7 +166,7 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s
client, err := grpc.NewClient("passthrough:///",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)),
grpc.WithDialer(pl.Dialer()),
grpc.WithContextDialer(pl.ContextDialer()),
grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{},
MinConnectTimeout: 100 * time.Millisecond,