Skip to content

Commit

Permalink
use a map to accurately track dials in flight
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Sep 4, 2023
1 parent 2dea484 commit b1c30cf
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions p2p/net/swarm/dial_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ type dialWorker struct {
connected bool
// dq is used to pace dials to different addresses of the peer
dq *dialQueue
// dialsInFlight are the addresses with dials pending completion.
dialsInFlight int
// dialsInFlight are the addresses with dials pending completion. We use this to schedule new dials
// and to cleanup all pending dials when closing the loop
dialsInFlight map[string]bool
// totalDials is used to track number of dials made by this worker for metrics
totalDials int

Expand All @@ -107,6 +108,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest, cl Clock) *dia
pendingRequests: make(map[*pendRequest]struct{}),
trackedDials: make(map[string]*addrDial),
resch: make(chan dialResult),
dialsInFlight: make(map[string]bool),
cl: cl,
}
}
Expand All @@ -131,7 +133,7 @@ func (w *dialWorker) loop() {
}
timerRunning = false
if w.dq.Len() > 0 {
if w.dialsInFlight == 0 && !w.connected {
if len(w.dialsInFlight) == 0 && !w.connected {
// if there are no dials in flight, trigger the next dials immediately
dialTimer.Reset(startTime)
} else {
Expand All @@ -156,23 +158,14 @@ loop:
if !ok {
return
}
// We have received a new request. If we do not have a suitable connection,
// track this dialRequest with a pendRequest.
// Enqueue the peer's addresses relevant to this request in dq and
// track dials to the addresses relevant to this request.

// Check if we have a suitable connection already
c, err := w.s.bestAcceptableConnToPeer(req.ctx, w.peer)
if c != nil || err != nil {
req.resch <- dialResponse{conn: c, err: err}
continue loop
}

addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer)
if err != nil {
req.resch <- dialResponse{err: &DialError{Peer: w.peer, DialErrors: addrErrs, Cause: err}}
continue loop
}
w.addNewRequest(req, addrs, addrErrs)
// we don't have any suitable connection, add the request to the worker loop
w.addNewRequest(req)
scheduleNextDial()

case <-dialTimer.Ch():
Expand All @@ -196,7 +189,7 @@ loop:
// Errored without attempting a dial. This happens in case of backoff.
w.dispatchError(ad, err)
} else {
w.dialsInFlight++
w.dialsInFlight[string(ad.addr.Bytes())] = true
w.totalDials++
}
}
Expand All @@ -215,9 +208,7 @@ loop:
if res.Conn != nil {
res.Conn.Close()
}
// It is better to decrement the dials in flight and schedule one extra dial
// than risking not closing the worker loop on cleanup
w.dialsInFlight--
delete(w.dialsInFlight, string(res.Addr.Bytes()))
continue
}

Expand All @@ -227,7 +218,7 @@ loop:
continue
}

w.dialsInFlight--
delete(w.dialsInFlight, string(res.Addr.Bytes()))
// We're recording any error as a failure here.
// Notably, this also applies to cancelations (i.e. if another dial attempt was faster).
// This is ok since the black hole detector uses a very low threshold (5%).
Expand All @@ -243,24 +234,29 @@ loop:
}
}

// addNewRequest adds a new dial request to the worker loop. If the request has no pending dials, a response
// is sent immediately otherwise it is tracked in pendingRequests
func (w *dialWorker) addNewRequest(req dialRequest, addrs []ma.Multiaddr, addrErrs []TransportError) {
// check if a dial to any of the addrs has succeeded already
// addNewRequest adds a new dial request to the worker loop. If the request has a valid connection or all relevant
// dials have failed, the request is handled immediately, otherwise it is added to pendingRequests.
func (w *dialWorker) addNewRequest(req dialRequest) {
addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer)
if err != nil {
req.resch <- dialResponse{err: &DialError{Peer: w.peer, DialErrors: addrErrs, Cause: err}}
return
}

// check if a dial to any of the relevant address has succeeded already
for _, addr := range addrs {
if ad, ok := w.trackedDials[string(addr.Bytes())]; ok {
if ad.conn != nil {
// dial to this addr was successful, complete the request
req.resch <- dialResponse{conn: ad.conn}
return
}
}
}

// get the delays to dial these addrs from the swarms dialRanker
// no dial has succeeded, get the delays to dial the addrs
simConnect, _, _ := network.GetSimultaneousConnect(req.ctx)
addrRanking := w.rankAddrs(addrs, simConnect)

// create the pending request object
pr := &pendRequest{
req: req,
err: &DialError{Peer: w.peer, DialErrors: addrErrs},
Expand All @@ -273,7 +269,7 @@ func (w *dialWorker) addNewRequest(req dialRequest, addrs []ma.Multiaddr, addrEr
for _, adelay := range addrRanking {
ad, ok := w.trackedDials[string(adelay.Addr.Bytes())]
if !ok {
// new address, track and enqueue
// new address, track and enqueue for dialing
now := time.Now()
w.trackedDials[string(adelay.Addr.Bytes())] = &addrDial{
addr: adelay.Addr,
Expand All @@ -292,8 +288,9 @@ func (w *dialWorker) addNewRequest(req dialRequest, addrs []ma.Multiaddr, addrEr
}

if !ad.dialed {
// we haven't dialed this address. update the ad.ctx to have simultaneous connect values
// set correctly
// We are tracking a dial to this address but we haven't dialled it already.
// If the new request is a holepunching request, update the context and the element in the
// dial queue
if isSimConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); isSimConnect {
if wasSimConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !wasSimConnect {
ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason)
Expand Down Expand Up @@ -425,16 +422,19 @@ func (w *dialWorker) cleanup() {
if w.s.metricsTracer != nil {
w.s.metricsTracer.DialCompleted(w.connected, w.totalDials)
}
for w.dialsInFlight > 0 {
for len(w.dialsInFlight) > 0 {
res := <-w.resch
if res.Kind != DialFailed && res.Kind != DialSuccessful {
continue
}
// We're recording any error as a failure here.
// Notably, this also applies to cancelations (i.e. if another dial attempt was faster).
// This is ok since the black hole detector uses a very low threshold (5%).
w.s.bhd.RecordResult(res.Addr, res.Err == nil)
if res.Conn != nil {
res.Conn.Close()
}
w.dialsInFlight--
delete(w.dialsInFlight, string(res.Addr.Bytes()))
}
}

Expand Down

0 comments on commit b1c30cf

Please sign in to comment.