diff --git a/dtlstransport.go b/dtlstransport.go index ec08a0846d3..bfc38d2b724 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -51,7 +51,7 @@ type DTLSTransport struct { srtpSession, srtcpSession atomic.Value srtpEndpoint, srtcpEndpoint *mux.Endpoint - simulcastStreams []*srtp.ReadStreamSRTP + simulcastStreams []simulcastStreamPair srtpReady chan struct{} dtlsMatcher mux.MatchFunc @@ -60,6 +60,11 @@ type DTLSTransport struct { log logging.LeveledLogger } +type simulcastStreamPair struct { + srtp *srtp.ReadStreamSRTP + srtcp *srtp.ReadStreamSRTCP +} + // NewDTLSTransport creates a new DTLSTransport. // This constructor is part of the ORTC API. It is not // meant to be used together with the basic WebRTC API. @@ -436,7 +441,8 @@ func (t *DTLSTransport) Stop() error { } for i := range t.simulcastStreams { - closeErrs = append(closeErrs, t.simulcastStreams[i].Close()) + closeErrs = append(closeErrs, t.simulcastStreams[i].srtp.Close()) + closeErrs = append(closeErrs, t.simulcastStreams[i].srtcp.Close()) } if t.conn != nil { @@ -477,11 +483,11 @@ func (t *DTLSTransport) ensureICEConn() error { return nil } -func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) { +func (t *DTLSTransport) storeSimulcastStream(srtpReadStream *srtp.ReadStreamSRTP, srtcpReadStream *srtp.ReadStreamSRTCP) { t.lock.Lock() defer t.lock.Unlock() - t.simulcastStreams = append(t.simulcastStreams, s) + t.simulcastStreams = append(t.simulcastStreams, simulcastStreamPair{srtpReadStream, srtcpReadStream}) } func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) { diff --git a/peerconnection.go b/peerconnection.go index 5876e5f76cd..4bbe1a19945 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1472,11 +1472,23 @@ func (pc *PeerConnection) startSCTP() { } func (pc *PeerConnection) handleUndeclaredSSRC(ssrc SSRC, remoteDescription *SessionDescription) (handled bool, err error) { - if len(remoteDescription.parsed.MediaDescriptions) != 1 { + mediaIdx := -1 + // find first and only audio/video media description; otherwise fail. + // DataChannels do not count. + for idx, mediaDesc := range remoteDescription.parsed.MediaDescriptions { + switch mediaDesc.MediaName.Media { + case RTPCodecTypeVideo.String(), RTPCodecTypeAudio.String(): + if mediaIdx != -1 { // more than one media + return false, nil + } + mediaIdx = idx + } + } + if mediaIdx == -1 { return false, nil } - onlyMediaSection := remoteDescription.parsed.MediaDescriptions[0] + onlyMediaSection := remoteDescription.parsed.MediaDescriptions[mediaIdx] streamID := "" id := "" hasRidAttribute := false @@ -1569,7 +1581,8 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err } } - // If the remote SDP was only one media section the ssrc doesn't have to be explicitly declared + // If the remote SDP was only one media (non-datachannel) section the ssrc doesn't + // have to be explicitly declared if handled, err := pc.handleUndeclaredSSRC(ssrc, remoteDescription); handled || err != nil { return err } @@ -1670,26 +1683,42 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { return } - stream, ssrc, err := srtpSession.AcceptStream() + srtcpSession, err := pc.dtlsTransport.getSRTCPSession() + if err != nil { + pc.log.Warnf("undeclaredMediaProcessor failed to open SrtcpSession: %v", err) + return + } + + srtpReadStream, ssrc, err := srtpSession.AcceptStream() if err != nil { pc.log.Warnf("Failed to accept RTP %v", err) return } + // open accompanying srtcp stream + srtcpReadStream, err := srtcpSession.OpenReadStream(ssrc) + if err != nil { + pc.log.Warnf("Failed to open RTCP stream for %d: %v", ssrc, err) + return + } + if pc.isClosed.get() { - if err = stream.Close(); err != nil { + if err = srtpReadStream.Close(); err != nil { pc.log.Warnf("Failed to close RTP stream %v", err) } + if err = srtcpReadStream.Close(); err != nil { + pc.log.Warnf("Failed to close RTCP stream %v", err) + } continue } + pc.dtlsTransport.storeSimulcastStream(srtpReadStream, srtcpReadStream) + if ssrc == 0 { go pc.handleNonMediaBandwidthProbe() continue } - pc.dtlsTransport.storeSimulcastStream(stream) - if atomic.AddUint64(&simulcastRoutineCount, 1) >= simulcastMaxProbeRoutines { atomic.AddUint64(&simulcastRoutineCount, ^uint64(0)) pc.log.Warn(ErrSimulcastProbeOverflow.Error()) @@ -1701,7 +1730,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err) } atomic.AddUint64(&simulcastRoutineCount, ^uint64(0)) - }(stream, SSRC(ssrc)) + }(srtpReadStream, SSRC(ssrc)) } }