From 2364173cedda65030a8a5f2110416387deb0083e Mon Sep 17 00:00:00 2001
From: cnderrauber
Date: Fri, 26 Jul 2024 09:52:44 +0800
Subject: [PATCH 1/4] Fix out-of-order TWCC feedback caused by blocked RTX

Fix #2830. TrackRemote.Read can block in readRTP when the buffer is
empty, so RTX packets that arrive before the next media RTP packet are
only read after that media packet. This produces out-of-order TWCC
feedback and messes up the remote peer's bandwidth estimation.
---
 rtpreceiver.go | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/rtpreceiver.go b/rtpreceiver.go
index ea3dddc0e0b..28d72246da3 100644
--- a/rtpreceiver.go
+++ b/rtpreceiver.go
@@ -424,7 +424,7 @@ func (r *RTPReceiver) receiveForRtx(ssrc SSRC, rsid string, streamInfo *intercep
 	track.repairInterceptor = rtpInterceptor
 	track.repairRtcpReadStream = rtcpReadStream
 	track.repairRtcpInterceptor = rtcpInterceptor
-	track.repairStreamChannel = make(chan rtxPacketWithAttributes)
+	track.repairStreamChannel = make(chan rtxPacketWithAttributes, 50)
 
 	go func() {
 		for {
@@ -474,6 +474,8 @@ func (r *RTPReceiver) receiveForRtx(ssrc SSRC, rsid string, streamInfo *intercep
 				r.rtxPool.Put(b) // nolint:staticcheck
 				return
 			case track.repairStreamChannel <- rtxPacketWithAttributes{pkt: b[:i-2], attributes: attributes, pool: &r.rtxPool}:
+			default:
+				// skip the RTX packet if the repair stream channel is full; the application's read loop may be blocked
 			}
 		}
 	}()

From 9836d583519405737726aa6f0dcbfe08e3007cab Mon Sep 17 00:00:00 2001
From: Sean DuBois
Date: Thu, 1 Aug 2024 09:53:03 -0400
Subject: [PATCH 2/4] Enable tests for /v3 branch

Resolves #2841
---
 .github/workflows/browser-e2e.yaml    |  2 ++
 .github/workflows/codeql-analysis.yml |  1 +
 .github/workflows/examples-tests.yaml |  2 ++
 .github/workflows/lint.yaml           | 20 --------------------
 .github/workflows/test.yaml           | 11 ++++-------
 .github/workflows/tidy-check.yaml     |  1 +
 6 files changed, 10 insertions(+), 27 deletions(-)
 delete mode 100644 .github/workflows/lint.yaml

diff --git a/.github/workflows/browser-e2e.yaml b/.github/workflows/browser-e2e.yaml
index b9f78599a03..e73dabb55f8 100644
--- a/.github/workflows/browser-e2e.yaml
+++ b/.github/workflows/browser-e2e.yaml
@@ -5,9 +5,11 @@ on:
   pull_request:
     branches:
       - master
+      - v3
   push:
     branches:
       - master
+      - v3
 
 jobs:
   e2e-test:
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
index ea9b825e2a3..b6d25166697 100644
--- a/.github/workflows/codeql-analysis.yml
+++ b/.github/workflows/codeql-analysis.yml
@@ -20,6 +20,7 @@ on:
   pull_request:
     branches:
       - master
+      - v3
     paths:
       - '**.go'
 
diff --git a/.github/workflows/examples-tests.yaml b/.github/workflows/examples-tests.yaml
index 298e5c29f9d..9ff198ab261 100644
--- a/.github/workflows/examples-tests.yaml
+++ b/.github/workflows/examples-tests.yaml
@@ -5,9 +5,11 @@ on:
   pull_request:
     branches:
       - master
+      - v3
   push:
     branches:
       - master
+      - v3
 
 jobs:
   pion-to-pion-test:
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
deleted file mode 100644
index 5dd3a9939a3..00000000000
--- a/.github/workflows/lint.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-#
-# DO NOT EDIT THIS FILE
-#
-# It is automatically copied from https://github.com/pion/.goassets repository.
-# If this repository should have package specific CI config,
-# remove the repository name from .goassets/.github/workflows/assets-sync.yml.
-#
-# If you want to update the shared CI config, send a PR to
-# https://github.com/pion/.goassets instead of this repository.
-#
-# SPDX-FileCopyrightText: 2023 The Pion community
-# SPDX-License-Identifier: MIT
-
-name: Lint
-on:
-  pull_request:
-
-jobs:
-  lint:
-    uses: pion/.goassets/.github/workflows/lint.reusable.yml@master
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 31aada4afe4..9d8ef233515 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -16,6 +16,7 @@ on:
   push:
     branches:
       - master
+      - v3
   pull_request:
 
 jobs:
@@ -23,21 +24,17 @@ jobs:
     uses: pion/.goassets/.github/workflows/test.reusable.yml@master
     strategy:
       matrix:
-        go: ['1.20', '1.19'] # auto-update/supported-go-version-list
+        go: ["1.20", "1.19"] # auto-update/supported-go-version-list
      fail-fast: false
    with:
      go-version: ${{ matrix.go }}
+    secrets: inherit
 
  test-i386:
    uses: pion/.goassets/.github/workflows/test-i386.reusable.yml@master
    strategy:
      matrix:
-        go: ['1.20', '1.19'] # auto-update/supported-go-version-list
+        go: ["1.20", "1.19"] # auto-update/supported-go-version-list
      fail-fast: false
    with:
      go-version: ${{ matrix.go }}
-
-  test-wasm:
-    uses: pion/.goassets/.github/workflows/test-wasm.reusable.yml@master
-    with:
-      go-version: '1.20' # auto-update/latest-go-version
diff --git a/.github/workflows/tidy-check.yaml b/.github/workflows/tidy-check.yaml
index 4d346d4fd79..a1b4d43663b 100644
--- a/.github/workflows/tidy-check.yaml
+++ b/.github/workflows/tidy-check.yaml
@@ -17,6 +17,7 @@ on:
   push:
     branches:
       - master
+      - v3
 
 jobs:
   tidy:

From b8d3a7bba7331a78d33f64f1d11c6d9e3d70d15d Mon Sep 17 00:00:00 2001
From: Juan Navarro
Date: Thu, 1 Aug 2024 12:04:23 +0200
Subject: [PATCH 3/4] Fix disordered RIDs in SDP

Map iteration order is not guaranteed by Go, so it is an error to
iterate over a map in places where maintaining a stable ordering is
important. This change replaces the map of simulcastRid{} with a slice
of the same type. The simulcastRid{} type is extended to hold the
rid-id, which was previously used as the key in the map. Accesses to
the map are replaced with range loops that find the desired rid-id in
each case.
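
For illustration (hypothetical values, not part of this change), map
iteration order is randomized in Go, while a slice preserves insertion
order:

	package main

	import "fmt"

	func main() {
		// Iteration order over a map is unspecified and can differ
		// between runs, so RIDs could come out shuffled.
		byID := map[string]string{"f": "recv", "h": "recv", "q": "recv"}
		for id := range byID {
			fmt.Println(id) // order may vary on every run
		}

		// A slice keeps the order in which the RIDs were appended.
		ordered := []string{"f", "h", "q"}
		for _, id := range ordered {
			fmt.Println(id) // always f, h, q
		}
	}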
Fixes #2838 --- peerconnection.go | 2 +- sdp.go | 37 +++++++++++++++++++++---------------- sdp_test.go | 33 ++++++++++++++++++++++----------- 3 files changed, 44 insertions(+), 28 deletions(-) diff --git a/peerconnection.go b/peerconnection.go index c216f2dbe32..d00f3baf19c 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -2500,7 +2500,7 @@ func (pc *PeerConnection) generateMatchedSDP(transceivers []*RTPTransceiver, use mediaTransceivers := []*RTPTransceiver{t} extensions, _ := rtpExtensionsFromMediaDescription(media) - mediaSections = append(mediaSections, mediaSection{id: midValue, transceivers: mediaTransceivers, matchExtensions: extensions, ridMap: getRids(media)}) + mediaSections = append(mediaSections, mediaSection{id: midValue, transceivers: mediaTransceivers, matchExtensions: extensions, rids: getRids(media)}) } } diff --git a/sdp.go b/sdp.go index 2b2a09b8009..18f05e884e8 100644 --- a/sdp.go +++ b/sdp.go @@ -202,8 +202,8 @@ func trackDetailsFromSDP(log logging.LeveledLogger, s *sdp.SessionDescription) ( id: trackID, rids: []string{}, } - for rid := range rids { - simulcastTrack.rids = append(simulcastTrack.rids, rid) + for _, rid := range rids { + simulcastTrack.rids = append(simulcastTrack.rids, rid.id) } tracksInMediaSection = []trackDetails{simulcastTrack} @@ -238,13 +238,13 @@ func trackDetailsToRTPReceiveParameters(t *trackDetails) RTPReceiveParameters { return RTPReceiveParameters{Encodings: encodings} } -func getRids(media *sdp.MediaDescription) map[string]*simulcastRid { - rids := map[string]*simulcastRid{} +func getRids(media *sdp.MediaDescription) []*simulcastRid { + rids := []*simulcastRid{} var simulcastAttr string for _, attr := range media.Attributes { if attr.Key == sdpAttributeRid { split := strings.Split(attr.Value, " ") - rids[split[0]] = &simulcastRid{attrValue: attr.Value} + rids = append(rids, &simulcastRid{id: split[0], attrValue: attr.Value}) } else if attr.Key == sdpAttributeSimulcast { simulcastAttr = attr.Value } @@ -257,9 +257,12 @@ func getRids(media *sdp.MediaDescription) map[string]*simulcastRid { ridStates := strings.Split(simulcastAttr, ";") for _, ridState := range ridStates { if ridState[:1] == "~" { - rid := ridState[1:] - if r, ok := rids[rid]; ok { - r.paused = true + ridID := ridState[1:] + for _, rid := range rids { + if rid.id == ridID { + rid.paused = true + break + } } } } @@ -499,15 +502,16 @@ func addTransceiverSDP( media.WithExtMap(sdp.ExtMap{Value: rtpExtension.ID, URI: extURL}) } - if len(mediaSection.ridMap) > 0 { - recvRids := make([]string, 0, len(mediaSection.ridMap)) + if len(mediaSection.rids) > 0 { + recvRids := make([]string, 0, len(mediaSection.rids)) - for rid := range mediaSection.ridMap { - media.WithValueAttribute(sdpAttributeRid, rid+" recv") - if mediaSection.ridMap[rid].paused { - rid = "~" + rid + for _, rid := range mediaSection.rids { + ridID := rid.id + media.WithValueAttribute(sdpAttributeRid, ridID+" recv") + if rid.paused { + ridID = "~" + ridID } - recvRids = append(recvRids, rid) + recvRids = append(recvRids, ridID) } // Simulcast media.WithValueAttribute(sdpAttributeSimulcast, "recv "+strings.Join(recvRids, ";")) @@ -533,6 +537,7 @@ func addTransceiverSDP( } type simulcastRid struct { + id string attrValue string paused bool } @@ -542,7 +547,7 @@ type mediaSection struct { transceivers []*RTPTransceiver data bool matchExtensions map[string]int - ridMap map[string]*simulcastRid + rids []*simulcastRid } func bundleMatchFromRemote(matchBundleGroup *string) func(mid string) bool { diff --git 
a/sdp_test.go b/sdp_test.go
index d3f86307193..d5de76c381c 100644
--- a/sdp_test.go
+++ b/sdp_test.go
@@ -381,16 +381,18 @@ func TestPopulateSDP(t *testing.T) {
 		tr := &RTPTransceiver{kind: RTPCodecTypeVideo, api: api, codecs: me.videoCodecs}
 		tr.setDirection(RTPTransceiverDirectionRecvonly)
 
-		ridMap := map[string]*simulcastRid{
-			"ridkey": {
+		rids := []*simulcastRid{
+			{
+				id:        "ridkey",
 				attrValue: "some",
 			},
-			"ridPaused": {
+			{
+				id:        "ridPaused",
 				attrValue: "some2",
 				paused:    true,
 			},
 		}
-		mediaSections := []mediaSection{{id: "video", transceivers: []*RTPTransceiver{tr}, ridMap: ridMap}}
+		mediaSections := []mediaSection{{id: "video", transceivers: []*RTPTransceiver{tr}, rids: rids}}
 
 		d := &sdp.SessionDescription{}
 
@@ -403,12 +405,14 @@ func TestPopulateSDP(t *testing.T) {
 			if desc.MediaName.Media != "video" {
 				continue
 			}
-			ridInSDP := getRids(desc)
-			if ridKey, ok := ridInSDP["ridkey"]; ok && !ridKey.paused {
-				ridFound++
-			}
-			if ridPaused, ok := ridInSDP["ridPaused"]; ok && ridPaused.paused {
-				ridFound++
+			ridsInSDP := getRids(desc)
+			for _, rid := range ridsInSDP {
+				if rid.id == "ridkey" && !rid.paused {
+					ridFound++
+				}
+				if rid.id == "ridPaused" && rid.paused {
+					ridFound++
+				}
 			}
 		}
 		assert.Equal(t, 2, ridFound, "All rid keys should be present")
@@ -631,7 +635,14 @@ func TestGetRIDs(t *testing.T) {
 
 	rids := getRids(m[0])
 	assert.NotEmpty(t, rids, "Rid mapping should be present")
-	if _, ok := rids["f"]; !ok {
+	found := false
+	for _, rid := range rids {
+		if rid.id == "f" {
+			found = true
+			break
+		}
+	}
+	if !found {
 		assert.Fail(t, "rid values should contain 'f'")
 	}
 }

From 4e4a67d7a2b29f4842eb61908fa1046f9756e005 Mon Sep 17 00:00:00 2001
From: Eric Daniels
Date: Tue, 6 Aug 2024 09:47:49 -0400
Subject: [PATCH 4/4] Add PeerConnection.GracefulClose
---
 datachannel.go               | 62 +++++++++++++++++++++++++++++++
 go.mod                       |  4 +-
 go.sum                       |  8 ++--
 icegatherer.go               | 22 ++++++++++-
 icetransport.go              | 24 +++++++++++-
 internal/mux/mux.go          |  4 ++
 operations.go                | 71 +++++++++++++++++++++++++++++-------
 operations_test.go           | 32 ++++++++++++++++
 peerconnection.go            | 45 ++++++++++++++++++++++-
 peerconnection_close_test.go | 66 +++++++++++++++++++++++++++++++++
 10 files changed, 314 insertions(+), 24 deletions(-)

diff --git a/datachannel.go b/datachannel.go
index f7c9511b350..c3ce10b9e32 100644
--- a/datachannel.go
+++ b/datachannel.go
@@ -40,6 +40,8 @@ type DataChannel struct {
 	readyState                 atomic.Value // DataChannelState
 	bufferedAmountLowThreshold uint64
 	detachCalled               bool
+	readLoopActive             chan struct{}
+	isGracefulClosed           bool
 
 	// The binaryType represents attribute MUST, on getting, return the value to
 	// which it was last set. On setting, if the new value is either the string
@@ -225,6 +227,10 @@ func (d *DataChannel) OnOpen(f func()) {
 func (d *DataChannel) onOpen() {
 	d.mu.RLock()
 	handler := d.onOpenHandler
+	if d.isGracefulClosed {
+		d.mu.RUnlock()
+		return
+	}
 	d.mu.RUnlock()
 
 	if handler != nil {
@@ -252,6 +258,10 @@ func (d *DataChannel) OnDial(f func()) {
 func (d *DataChannel) onDial() {
 	d.mu.RLock()
 	handler := d.onDialHandler
+	if d.isGracefulClosed {
+		d.mu.RUnlock()
+		return
+	}
 	d.mu.RUnlock()
 
 	if handler != nil {
@@ -261,6 +271,10 @@ func (d *DataChannel) onDial() {
 
 // OnClose sets an event handler which is invoked when
 // the underlying data transport has been closed.
+// Note: due to backwards compatibility, there is a chance that
+// OnClose can be called even if GracefulClose is used.
+// If this is a problem for you, deregister OnClose
+// prior to calling GracefulClose.
 func (d *DataChannel) OnClose(f func()) {
 	d.mu.Lock()
 	defer d.mu.Unlock()
@@ -292,6 +306,10 @@ func (d *DataChannel) OnMessage(f func(msg DataChannelMessage)) {
 func (d *DataChannel) onMessage(msg DataChannelMessage) {
 	d.mu.RLock()
 	handler := d.onMessageHandler
+	if d.isGracefulClosed {
+		d.mu.RUnlock()
+		return
+	}
 	d.mu.RUnlock()
 
 	if handler == nil {
@@ -302,6 +320,10 @@ func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlreadyNegotiated bool) {
 	d.mu.Lock()
+	if d.isGracefulClosed {
+		d.mu.Unlock()
+		return
+	}
 	d.dataChannel = dc
 	bufferedAmountLowThreshold := d.bufferedAmountLowThreshold
 	onBufferedAmountLow := d.onBufferedAmountLow
@@ -326,7 +348,12 @@
 		d.mu.Lock()
 		defer d.mu.Unlock()
 
+		if d.isGracefulClosed {
+			return
+		}
+
 		if !d.api.settingEngine.detach.DataChannels {
+			d.readLoopActive = make(chan struct{})
 			go d.readLoop()
 		}
 	}
@@ -342,6 +369,10 @@ func (d *DataChannel) OnError(f func(err error)) {
 func (d *DataChannel) onError(err error) {
 	d.mu.RLock()
 	handler := d.onErrorHandler
+	if d.isGracefulClosed {
+		d.mu.RUnlock()
+		return
+	}
 	d.mu.RUnlock()
 
 	if handler != nil {
@@ -356,6 +387,12 @@ var rlBufPool = sync.Pool{New: func() interface{} {
 }}
 
 func (d *DataChannel) readLoop() {
+	defer func() {
+		d.mu.Lock()
+		readLoopActive := d.readLoopActive
+		d.mu.Unlock()
+		defer close(readLoopActive)
+	}()
 	for {
 		buffer := rlBufPool.Get().([]byte) //nolint:forcetypeassert
 		n, isString, err := d.dataChannel.ReadDataChannel(buffer)
@@ -438,7 +475,32 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
 // Close Closes the DataChannel. It may be called regardless of whether
 // the DataChannel object was created by this peer or the remote peer.
 func (d *DataChannel) Close() error {
+	return d.close(false)
+}
+
+// GracefulClose Closes the DataChannel. It may be called regardless of whether
+// the DataChannel object was created by this peer or the remote peer. It also waits
+// for any goroutines it started to complete. This is only safe to call outside of
+// DataChannel callbacks or if in a callback, in its own goroutine.
+func (d *DataChannel) GracefulClose() error {
+	return d.close(true)
+}
+
+// Normally, close only stops writes from happening, so graceful=true
+// will wait for reads to be finished based on underlying SCTP association
+// closure or an SCTP reset stream from the other side. This is safe to call
+// with graceful=true after tearing down a PeerConnection but not
+// necessarily before. For example, if you used a vnet and dropped all packets
+// right before closing the DataChannel, you might never see a reset stream.
+func (d *DataChannel) close(shouldGracefullyClose bool) error { d.mu.Lock() + d.isGracefulClosed = true + readLoopActive := d.readLoopActive + if shouldGracefullyClose && readLoopActive != nil { + defer func() { + <-readLoopActive + }() + } haveSctpTransport := d.dataChannel != nil d.mu.Unlock() diff --git a/go.mod b/go.mod index e4687e03aed..91dc415a7e2 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.17 require ( github.com/pion/datachannel v1.5.8 github.com/pion/dtls/v2 v2.2.12 - github.com/pion/ice/v2 v2.3.31 + github.com/pion/ice/v2 v2.3.34 github.com/pion/interceptor v0.1.29 github.com/pion/logging v0.2.2 github.com/pion/randutil v0.1.0 @@ -15,7 +15,7 @@ require ( github.com/pion/sdp/v3 v3.0.9 github.com/pion/srtp/v2 v2.0.20 github.com/pion/stun v0.6.1 - github.com/pion/transport/v2 v2.2.8 + github.com/pion/transport/v2 v2.2.10 github.com/sclevine/agouti v3.0.0+incompatible github.com/stretchr/testify v1.9.0 golang.org/x/net v0.22.0 diff --git a/go.sum b/go.sum index 22f23449e97..71271951dbf 100644 --- a/go.sum +++ b/go.sum @@ -45,8 +45,8 @@ github.com/pion/datachannel v1.5.8/go.mod h1:PgmdpoaNBLX9HNzNClmdki4DYW5JtI7Yibu github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk= github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/ice/v2 v2.3.31 h1:qag/YqiOn5qPi0kgeVdsytxjx8szuriWSIeXKu8dDQc= -github.com/pion/ice/v2 v2.3.31/go.mod h1:8fac0+qftclGy1tYd/nfwfHC729BLaxtVqMdMVCAVPU= +github.com/pion/ice/v2 v2.3.34 h1:Ic1ppYCj4tUOcPAp76U6F3fVrlSw8A9JtRXLqw6BbUM= +github.com/pion/ice/v2 v2.3.34/go.mod h1:mBF7lnigdqgtB+YHkaY/Y6s6tsyRyo4u4rPGRuOjUBQ= github.com/pion/interceptor v0.1.29 h1:39fsnlP1U8gw2JzOFWdfCU82vHvhW9o0rZnZF56wF+M= github.com/pion/interceptor v0.1.29/go.mod h1:ri+LGNjRUc5xUNtDEPzfdkmSqISixVTBF/z/Zms/6T4= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= @@ -74,8 +74,8 @@ github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/ github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.3/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= -github.com/pion/transport/v2 v2.2.8 h1:HzsqGBChgtF4Cj47gu51l5hONuK/NwgbZL17CMSuwS0= -github.com/pion/transport/v2 v2.2.8/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E= +github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q= +github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= github.com/pion/transport/v3 v3.0.2 h1:r+40RJR25S9w3jbA6/5uEPTzcdn7ncyU44RWCbHkLg4= github.com/pion/transport/v3 v3.0.2/go.mod h1:nIToODoOlb5If2jF9y2Igfx3PFYWfuXi37m0IlWa/D0= diff --git a/icegatherer.go b/icegatherer.go index d01ecc1344b..cd0d8672b82 100644 --- a/icegatherer.go +++ b/icegatherer.go @@ -188,13 +188,31 @@ func (g *ICEGatherer) Gather() error { // Close prunes all local candidates, and closes the ports. func (g *ICEGatherer) Close() error { + return g.close(false /* shouldGracefullyClose */) +} + +// GracefulClose prunes all local candidates, and closes the ports. It also waits +// for any goroutines it started to complete. 
This is only safe to call outside of +// ICEGatherer callbacks or if in a callback, in its own goroutine. +func (g *ICEGatherer) GracefulClose() error { + return g.close(true /* shouldGracefullyClose */) +} + +func (g *ICEGatherer) close(shouldGracefullyClose bool) error { g.lock.Lock() defer g.lock.Unlock() if g.agent == nil { return nil - } else if err := g.agent.Close(); err != nil { - return err + } + if shouldGracefullyClose { + if err := g.agent.GracefulClose(); err != nil { + return err + } + } else { + if err := g.agent.Close(); err != nil { + return err + } } g.agent = nil diff --git a/icetransport.go b/icetransport.go index 469aafbd43f..cb9aa22de88 100644 --- a/icetransport.go +++ b/icetransport.go @@ -16,6 +16,7 @@ import ( "github.com/pion/ice/v2" "github.com/pion/logging" "github.com/pion/webrtc/v3/internal/mux" + "github.com/pion/webrtc/v3/internal/util" ) // ICETransport allows an application access to information about the ICE @@ -187,6 +188,17 @@ func (t *ICETransport) restart() error { // Stop irreversibly stops the ICETransport. func (t *ICETransport) Stop() error { + return t.stop(false /* shouldGracefullyClose */) +} + +// GracefulStop irreversibly stops the ICETransport. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// ICETransport callbacks or if in a callback, in its own goroutine. +func (t *ICETransport) GracefulStop() error { + return t.stop(true /* shouldGracefullyClose */) +} + +func (t *ICETransport) stop(shouldGracefullyClose bool) error { t.lock.Lock() defer t.lock.Unlock() @@ -197,8 +209,18 @@ func (t *ICETransport) Stop() error { } if t.mux != nil { - return t.mux.Close() + var closeErrs []error + if shouldGracefullyClose && t.gatherer != nil { + // we can't access icegatherer/icetransport.Close via + // mux's net.Conn Close so we call it earlier here. + closeErrs = append(closeErrs, t.gatherer.GracefulClose()) + } + closeErrs = append(closeErrs, t.mux.Close()) + return util.FlattenErrs(closeErrs) } else if t.gatherer != nil { + if shouldGracefullyClose { + return t.gatherer.GracefulClose() + } return t.gatherer.Close() } return nil diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 1e167b89784..d57b9c3cf3b 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -120,6 +120,10 @@ func (m *Mux) readLoop() { } if err = m.dispatch(buf[:n]); err != nil { + if errors.Is(err, io.ErrClosedPipe) { + // if the buffer was closed, that's not an error we care to report + return + } m.log.Errorf("mux: ending readLoop dispatch error %s", err.Error()) return } diff --git a/operations.go b/operations.go index bc366ac34db..67d24eebecc 100644 --- a/operations.go +++ b/operations.go @@ -13,12 +13,13 @@ type operation func() // Operations is a task executor. type operations struct { - mu sync.Mutex - busy bool - ops *list.List + mu sync.Mutex + busyCh chan struct{} + ops *list.List updateNegotiationNeededFlagOnEmptyChain *atomicBool onNegotiationNeeded func() + isClosed bool } func newOperations( @@ -33,21 +34,34 @@ func newOperations( } // Enqueue adds a new action to be executed. If there are no actions scheduled, -// the execution will start immediately in a new goroutine. +// the execution will start immediately in a new goroutine. If the queue has been +// closed, the operation will be dropped. The queue is only deliberately closed +// by a user. 
 func (o *operations) Enqueue(op operation) {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	_ = o.tryEnqueue(op)
+}
+
+// tryEnqueue attempts to enqueue the given operation. It returns false
+// if the op is invalid or the queue is closed. mu must be locked by
+// tryEnqueue's caller.
+func (o *operations) tryEnqueue(op operation) bool {
 	if op == nil {
-		return
+		return false
 	}
 
-	o.mu.Lock()
-	running := o.busy
+	if o.isClosed {
+		return false
+	}
 	o.ops.PushBack(op)
-	o.busy = true
-	o.mu.Unlock()
 
-	if !running {
+	if o.busyCh == nil {
+		o.busyCh = make(chan struct{})
 		go o.start()
 	}
+
+	return true
 }
 
 // IsEmpty checks if there are tasks in the queue
@@ -62,12 +76,38 @@ func (o *operations) IsEmpty() bool {
 func (o *operations) Done() {
 	var wg sync.WaitGroup
 	wg.Add(1)
-	o.Enqueue(func() {
+	o.mu.Lock()
+	enqueued := o.tryEnqueue(func() {
 		wg.Done()
 	})
+	o.mu.Unlock()
+	if !enqueued {
+		return
+	}
 	wg.Wait()
 }
 
+// GracefulClose waits for the operations queue to be cleared and forbids
+// new operations from being enqueued.
+func (o *operations) GracefulClose() {
+	o.mu.Lock()
+	if o.isClosed {
+		o.mu.Unlock()
+		return
+	}
+	// do not enqueue any more ops from here on;
+	// o.isClosed=true will also not allow a new busyCh
+	// to be created.
+	o.isClosed = true
+
+	busyCh := o.busyCh
+	o.mu.Unlock()
+	if busyCh == nil {
+		return
+	}
+	<-busyCh
+}
+
 func (o *operations) pop() func() {
 	o.mu.Lock()
 	defer o.mu.Unlock()
@@ -87,12 +127,17 @@ func (o *operations) start() {
 	defer func() {
 		o.mu.Lock()
 		defer o.mu.Unlock()
-		if o.ops.Len() == 0 {
-			o.busy = false
+		// this will be the most recent busy chan
+		close(o.busyCh)
+
+		if o.ops.Len() == 0 || o.isClosed {
+			o.busyCh = nil
 			return
 		}
+
 		// either a new operation was enqueued while we
 		// were busy, or an operation panicked
+		o.busyCh = make(chan struct{})
 		go o.start()
 	}()
diff --git a/operations_test.go b/operations_test.go
index 428c2b4df97..3b84a1def5b 100644
--- a/operations_test.go
+++ b/operations_test.go
@@ -19,6 +19,8 @@ func TestOperations_Enqueue(t *testing.T) {
 		onNegotiationNeededCalledCount++
 		onNegotiationNeededCalledCountMu.Unlock()
 	})
+	defer ops.GracefulClose()
+
 	for resultSet := 0; resultSet < 100; resultSet++ {
 		results := make([]int, 16)
 		resultSetCopy := resultSet
@@ -46,5 +48,35 @@ func TestOperations_Enqueue(t *testing.T) {
 func TestOperations_Done(*testing.T) {
 	ops := newOperations(&atomicBool{}, func() {
 	})
+	defer ops.GracefulClose()
+
+	ops.Done()
+}
+
+func TestOperations_GracefulClose(t *testing.T) {
+	ops := newOperations(&atomicBool{}, func() {
+	})
+
+	counter := 0
+	var counterMu sync.Mutex
+	incFunc := func() {
+		counterMu.Lock()
+		counter++
+		counterMu.Unlock()
+	}
+	const times = 25
+	for i := 0; i < times; i++ {
+		ops.Enqueue(incFunc)
+	}
+	ops.Done()
+	counterMu.Lock()
+	counterCur := counter
+	counterMu.Unlock()
+	assert.Equal(t, counterCur, times)
+
+	ops.GracefulClose()
+	for i := 0; i < times; i++ {
+		ops.Enqueue(incFunc)
+	}
 	ops.Done()
+	assert.Equal(t, counterCur, times)
 }
diff --git a/peerconnection.go b/peerconnection.go
index d00f3baf19c..3b40dafecb6 100644
--- a/peerconnection.go
+++ b/peerconnection.go
@@ -56,6 +56,8 @@ type PeerConnection struct {
 	idpLoginURL *string
 
 	isClosed                                *atomicBool
+	isGracefulClosed                        *atomicBool
+	isGracefulClosedDone                    chan struct{}
 	isNegotiationNeeded                     *atomicBool
 	updateNegotiationNeededFlagOnEmptyChain *atomicBool
 
@@ -128,6 +130,8 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
 			ICECandidatePoolSize: 0,
 		},
 		isClosed:         &atomicBool{},
+		isGracefulClosed:
&atomicBool{}, + isGracefulClosedDone: make(chan struct{}), isNegotiationNeeded: &atomicBool{}, updateNegotiationNeededFlagOnEmptyChain: &atomicBool{}, lastOffer: "", @@ -2082,13 +2086,34 @@ func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes return pc.dtlsTransport.WriteRTCP(pkts) } -// Close ends the PeerConnection +// Close ends the PeerConnection. func (pc *PeerConnection) Close() error { + return pc.close(false /* shouldGracefullyClose */) +} + +// GracefulClose ends the PeerConnection. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// PeerConnection callbacks or if in a callback, in its own goroutine. +func (pc *PeerConnection) GracefulClose() error { + return pc.close(true /* shouldGracefullyClose */) +} + +func (pc *PeerConnection) close(shouldGracefullyClose bool) error { // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1) // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2) + alreadyGracefullyClosed := shouldGracefullyClose && pc.isGracefulClosed.swap(true) if pc.isClosed.swap(true) { + if alreadyGracefullyClosed { + // similar but distinct condition where we may be waiting for some + // other graceful close to finish. Incorrectly using isClosed may + // leak a goroutine. + <-pc.isGracefulClosedDone + } return nil } + if shouldGracefullyClose && !alreadyGracefullyClosed { + defer close(pc.isGracefulClosedDone) + } // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3) pc.signalingState.Set(SignalingStateClosed) @@ -2132,12 +2157,28 @@ func (pc *PeerConnection) Close() error { // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #8, #9, #10) if pc.iceTransport != nil { - closeErrs = append(closeErrs, pc.iceTransport.Stop()) + if shouldGracefullyClose { + // note that it isn't canon to stop gracefully + closeErrs = append(closeErrs, pc.iceTransport.GracefulStop()) + } else { + closeErrs = append(closeErrs, pc.iceTransport.Stop()) + } } // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11) pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State()) + if shouldGracefullyClose { + pc.ops.GracefulClose() + + // note that it isn't canon to stop gracefully + pc.sctpTransport.lock.Lock() + for _, d := range pc.sctpTransport.dataChannels { + closeErrs = append(closeErrs, d.GracefulClose()) + } + pc.sctpTransport.lock.Unlock() + } + return util.FlattenErrs(closeErrs) } diff --git a/peerconnection_close_test.go b/peerconnection_close_test.go index df7a6526b1e..9f90034a289 100644 --- a/peerconnection_close_test.go +++ b/peerconnection_close_test.go @@ -179,3 +179,69 @@ func TestPeerConnection_Close_DuringICE(t *testing.T) { t.Error("pcOffer.Close() Timeout") } } + +func TestPeerConnection_CloseWithIncomingMessages(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + report := test.CheckRoutinesStrict(t) + defer report() + + pcOffer, pcAnswer, err := newPair() + if err != nil { + t.Fatal(err) + } + + var dcAnswer *DataChannel + answerDataChannelOpened := make(chan struct{}) + pcAnswer.OnDataChannel(func(d *DataChannel) { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). 
+ if d.Label() != "data" { + return + } + dcAnswer = d + close(answerDataChannelOpened) + }) + + dcOffer, err := pcOffer.CreateDataChannel("data", nil) + if err != nil { + t.Fatal(err) + } + + offerDataChannelOpened := make(chan struct{}) + dcOffer.OnOpen(func() { + close(offerDataChannelOpened) + }) + + err = signalPair(pcOffer, pcAnswer) + if err != nil { + t.Fatal(err) + } + + <-offerDataChannelOpened + <-answerDataChannelOpened + + msgNum := 0 + dcOffer.OnMessage(func(_ DataChannelMessage) { + t.Log("msg", msgNum) + msgNum++ + }) + + // send 50 messages, then close pcOffer, and then send another 50 + for i := 0; i < 100; i++ { + if i == 50 { + err = pcOffer.GracefulClose() + if err != nil { + t.Fatal(err) + } + } + _ = dcAnswer.Send([]byte("hello!")) + } + + err = pcAnswer.GracefulClose() + if err != nil { + t.Fatal(err) + } +}
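
A minimal usage sketch of the new GracefulClose API (hypothetical
application code, not part of the patches; it assumes the v3 module
path used in go.mod above):

	package main

	import (
		"log"

		"github.com/pion/webrtc/v3"
	)

	func main() {
		pc, err := webrtc.NewPeerConnection(webrtc.Configuration{})
		if err != nil {
			log.Fatal(err)
		}

		pc.OnConnectionStateChange(func(s webrtc.PeerConnectionState) {
			// Per the API docs above, GracefulClose is not safe to call
			// directly from inside a callback; hand it off to its own
			// goroutine instead.
			if s == webrtc.PeerConnectionStateFailed {
				go func() { _ = pc.GracefulClose() }()
			}
		})

		// ... signaling and media setup would go here ...

		// Outside of callbacks it is safe to call directly; it blocks
		// until the goroutines the PeerConnection started (the operations
		// queue, data channel read loops) have finished.
		if err := pc.GracefulClose(); err != nil {
			log.Println(err)
		}
	}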