From 41d31c3e9e8576bfab8660df6a4505e48faa86a3 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 9 Feb 2023 23:55:13 +0100 Subject: [PATCH 1/7] publisher: introduce a publisher A new generic publisher. It can't get more simple than that. --- pkg/conference/publisher/publisher.go | 113 ++++++++++++++++++++++++++ pkg/conference/publisher/track.go | 18 ++++ 2 files changed, 131 insertions(+) create mode 100644 pkg/conference/publisher/publisher.go create mode 100644 pkg/conference/publisher/track.go diff --git a/pkg/conference/publisher/publisher.go b/pkg/conference/publisher/publisher.go new file mode 100644 index 0000000..b2d637a --- /dev/null +++ b/pkg/conference/publisher/publisher.go @@ -0,0 +1,113 @@ +package publisher + +import ( + "errors" + "fmt" + "sync" + + "github.com/pion/rtp" +) + +var ErrSubscriptionExists = errors.New("subscription already exists") + +type Subscription interface { + // WriteRTP **must not** block (wait on I/O). + WriteRTP(packet rtp.Packet) error +} + +type Track interface { + // ReadPacket **may** block (wait on I/O). + ReadPacket() (*rtp.Packet, error) +} + +// An abstract publisher that reads the packets from the track and forwards them to all subscribers. +type Publisher struct { + mu sync.Mutex + track Track + subscriptions map[Subscription]struct{} +} + +func NewPublisher( + track Track, + stop <-chan struct{}, +) (*Publisher, <-chan struct{}) { + // Create a done channel, so that we can signal the caller when we're done. + done := make(chan struct{}) + + publisher := &Publisher{ + track: track, + subscriptions: make(map[Subscription]struct{}), + } + + // Start a goroutine that will read RTP packets from the remote track. + // We run the publisher until we receive a stop signal or an error occurs. + go func() { + defer close(done) + for { + // Check if we were signaled to stop. + select { + case <-stop: + return + default: + if err := publisher.forwardPacket(); err != nil { + fmt.Println("failed to write to subscribers: ", err) + return + } + } + } + }() + + return publisher, done +} + +func (p *Publisher) AddSubscription(subscription Subscription) { + p.mu.Lock() + defer p.mu.Unlock() + + if _, ok := p.subscriptions[subscription]; ok { + return + } + + p.subscriptions[subscription] = struct{}{} +} + +func (p *Publisher) RemoveSubscription(subscription Subscription) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.subscriptions, subscription) +} + +func (p *Publisher) GetTrack() Track { + p.mu.Lock() + defer p.mu.Unlock() + return p.track +} + +func (p *Publisher) ReplaceTrack(track Track) { + p.mu.Lock() + defer p.mu.Unlock() + p.track = track +} + +// Reads a single packet from the remote track and forwards it to all subscribers. +func (p *Publisher) forwardPacket() error { + track := p.GetTrack() + + packet, err := track.ReadPacket() + if err != nil { + return err + } + + p.mu.Lock() + defer p.mu.Unlock() + + // Write the packet to all subscribers. + for subscription := range p.subscriptions { + if err := subscription.WriteRTP(*packet); err != nil { + fmt.Println("failed to write to subscriber: ", err) + delete(p.subscriptions, subscription) + } + } + + return nil +} diff --git a/pkg/conference/publisher/track.go b/pkg/conference/publisher/track.go new file mode 100644 index 0000000..658da43 --- /dev/null +++ b/pkg/conference/publisher/track.go @@ -0,0 +1,18 @@ +package publisher + +import ( + "github.com/pion/rtp" + "github.com/pion/webrtc/v3" +) + +// Wrapper for the `webrtc.TrackRemote`. +type RemoteTrack struct { + // The underlying `webrtc.TrackRemote`. + Track *webrtc.TrackRemote +} + +// Implement the `Track` interface for the `webrtc.TrackRemote`. +func (t *RemoteTrack) ReadPacket() (*rtp.Packet, error) { + packet, _, err := t.Track.ReadRTP() + return packet, err +} From bb49d2ef6493f6bcb9cc6a270ac657ba4b39b295 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 10 Feb 2023 00:14:46 +0100 Subject: [PATCH 2/7] conference: no routing for RTP over conference --- pkg/conference/participant/tracker.go | 15 --------------- pkg/conference/peer_message_processing.go | 4 ---- pkg/conference/processing.go | 2 -- pkg/peer/messages.go | 7 ------- pkg/peer/remote_track.go | 9 +++++---- 5 files changed, 5 insertions(+), 32 deletions(-) diff --git a/pkg/conference/participant/tracker.go b/pkg/conference/participant/tracker.go index efb550a..555842d 100644 --- a/pkg/conference/participant/tracker.go +++ b/pkg/conference/participant/tracker.go @@ -5,9 +5,7 @@ import ( "github.com/matrix-org/waterfall/pkg/conference/subscription" "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/rtp" "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" "golang.org/x/exp/slices" ) @@ -220,16 +218,3 @@ func (t *Tracker) Unsubscribe(participantID ID, trackID TrackID) { } } } - -// Processes an RTP packet received on a given track. -func (t *Tracker) ProcessRTP(info webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer, packet *rtp.Packet) { - if published := t.publishedTracks[info.TrackID]; published != nil { - for _, sub := range published.Subscriptions { - if sub.Simulcast() == simulcast { - if err := sub.WriteRTP(*packet); err != nil { - logrus.Errorf("Dropping an RTP packet on %s (%s): %s", info.TrackID, simulcast, err) - } - } - } - } -} diff --git a/pkg/conference/peer_message_processing.go b/pkg/conference/peer_message_processing.go index 264ee43..72494bd 100644 --- a/pkg/conference/peer_message_processing.go +++ b/pkg/conference/peer_message_processing.go @@ -33,10 +33,6 @@ func (c *Conference) processNewTrackPublishedMessage(sender participant.ID, msg c.resendMetadataToAllExcept(sender) } -func (c *Conference) processRTPPacketReceivedMessage(msg peer.RTPPacketReceived) { - c.tracker.ProcessRTP(msg.TrackInfo, msg.SimulcastLayer, msg.Packet) -} - func (c *Conference) processPublishedTrackFailedMessage(sender participant.ID, msg peer.PublishedTrackFailed) { c.newLogger(sender).Infof("Failed published track: %s", msg.TrackID) c.tracker.RemovePublishedTrack(msg.TrackID) diff --git a/pkg/conference/processing.go b/pkg/conference/processing.go index eecfc26..a8da1f9 100644 --- a/pkg/conference/processing.go +++ b/pkg/conference/processing.go @@ -42,8 +42,6 @@ func (c *Conference) processPeerMessage(message channel.Message[participant.ID, c.processLeftTheCallMessage(message.Sender, msg) case peer.NewTrackPublished: c.processNewTrackPublishedMessage(message.Sender, msg) - case peer.RTPPacketReceived: - c.processRTPPacketReceivedMessage(msg) case peer.PublishedTrackFailed: c.processPublishedTrackFailedMessage(message.Sender, msg) case peer.NewICECandidate: diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 4f67a6d..83037cc 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -2,7 +2,6 @@ package peer import ( "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/rtp" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -33,12 +32,6 @@ type PublishedTrackFailed struct { SimulcastLayer webrtc_ext.SimulcastLayer } -type RTPPacketReceived struct { - webrtc_ext.TrackInfo - SimulcastLayer webrtc_ext.SimulcastLayer - Packet *rtp.Packet -} - type NewICECandidate struct { Candidate *webrtc.ICECandidate } diff --git a/pkg/peer/remote_track.go b/pkg/peer/remote_track.go index 97a82ae..a6e9ade 100644 --- a/pkg/peer/remote_track.go +++ b/pkg/peer/remote_track.go @@ -16,10 +16,7 @@ func (p *Peer[ID]) handleNewVideoTrack( ) { simulcast := webrtc_ext.RIDToSimulcastLayer(remoteTrack.RID()) - p.handleRemoteTrack(remoteTrack, trackInfo, simulcast, nil, func(packet *rtp.Packet) error { - p.sink.Send(RTPPacketReceived{trackInfo, simulcast, packet}) - return nil - }) + p.handleRemoteTrack(remoteTrack, trackInfo, simulcast, nil, nil) } func (p *Peer[ID]) handleNewAudioTrack( @@ -66,6 +63,10 @@ func (p *Peer[ID]) handleRemoteTrack( p.sink.Send(PublishedTrackFailed{trackInfo, simulcast}) }() + if handleRtpFn == nil { + return + } + for { // Read the data from the remote track. packet, _, readErr := remoteTrack.ReadRTP() From 4dc344411486e5b2b78d0a6257818d9c0a84305a Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 13 Feb 2023 19:50:15 +0100 Subject: [PATCH 3/7] peer: don't handle tracks inside peer anymore Now, the peer will just send the remote track to the conference and let the conference to create publishers and start go-routines for processing publishers and managing the lifetime of the tracks. --- pkg/peer/messages.go | 16 +------ pkg/peer/peer.go | 10 +--- pkg/peer/remote_track.go | 89 ------------------------------------ pkg/peer/state/peer_state.go | 36 ++------------- pkg/peer/webrtc_callbacks.go | 12 +---- 5 files changed, 8 insertions(+), 155 deletions(-) delete mode 100644 pkg/peer/remote_track.go diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 83037cc..896c01b 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -1,7 +1,6 @@ package peer import ( - "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -17,19 +16,8 @@ type LeftTheCall struct { } type NewTrackPublished struct { - // Information about the track (ID etc). - webrtc_ext.TrackInfo - // SimulcastLayer configuration (can be `None` for non-simulcast tracks and for audio tracks). - SimulcastLayer webrtc_ext.SimulcastLayer - // Output track (if any) that could be used to send data to the peer. Will be `nil` if such - // track does not exist, in which case the caller is expected to listen to `RtpPacketReceived` - // messages. - OutputTrack *webrtc.TrackLocalStaticRTP -} - -type PublishedTrackFailed struct { - webrtc_ext.TrackInfo - SimulcastLayer webrtc_ext.SimulcastLayer + // Remote track that has been published. + RemoteTrack *webrtc.TrackRemote } type NewICECandidate struct { diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 0b20650..7211acf 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -21,7 +21,6 @@ var ( ErrDataChannelNotAvailable = errors.New("data channel is not available") ErrDataChannelNotReady = errors.New("data channel is not ready") ErrCantSubscribeToTrack = errors.New("can't subscribe to track") - ErrTrackNotFound = errors.New("track not found") ) // A wrapped representation of the peer connection (single peer in the call). @@ -84,14 +83,7 @@ func (p *Peer[ID]) Terminate() { } // Request a key frame from the peer connection. -func (p *Peer[ID]) RequestKeyFrame(info webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error { - // Find the right track. - track := p.state.GetRemoteTrack(info.TrackID, simulcast) - if track == nil { - return ErrTrackNotFound - } - - p.logger.Debugf("Keyframe request: %s (%s)", info.TrackID, simulcast) +func (p *Peer[ID]) RequestKeyFrame(track *webrtc.TrackRemote) error { rtcps := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}} return p.peerConnection.WriteRTCP(rtcps) } diff --git a/pkg/peer/remote_track.go b/pkg/peer/remote_track.go deleted file mode 100644 index a6e9ade..0000000 --- a/pkg/peer/remote_track.go +++ /dev/null @@ -1,89 +0,0 @@ -package peer - -import ( - "errors" - "io" - - "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/rtp" - "github.com/pion/webrtc/v3" -) - -func (p *Peer[ID]) handleNewVideoTrack( - trackInfo webrtc_ext.TrackInfo, - remoteTrack *webrtc.TrackRemote, - receiver *webrtc.RTPReceiver, -) { - simulcast := webrtc_ext.RIDToSimulcastLayer(remoteTrack.RID()) - - p.handleRemoteTrack(remoteTrack, trackInfo, simulcast, nil, nil) -} - -func (p *Peer[ID]) handleNewAudioTrack( - trackInfo webrtc_ext.TrackInfo, - remoteTrack *webrtc.TrackRemote, - receiver *webrtc.RTPReceiver, -) { - // Create a local track, all our SFU clients that are subscribed to this - // peer (publisher) wil be fed via this track. - localTrack, err := webrtc.NewTrackLocalStaticRTP( - remoteTrack.Codec().RTPCodecCapability, - remoteTrack.ID(), - remoteTrack.StreamID(), - ) - if err != nil { - p.logger.WithError(err).Error("failed to create local track") - return - } - - p.handleRemoteTrack(remoteTrack, trackInfo, webrtc_ext.SimulcastLayerNone, localTrack, func(packet *rtp.Packet) error { - if err = localTrack.WriteRTP(packet); err != nil && !errors.Is(err, io.ErrClosedPipe) { - return err - } - return nil - }) -} - -func (p *Peer[ID]) handleRemoteTrack( - remoteTrack *webrtc.TrackRemote, - trackInfo webrtc_ext.TrackInfo, - simulcast webrtc_ext.SimulcastLayer, - outputTrack *webrtc.TrackLocalStaticRTP, - handleRtpFn func(*rtp.Packet) error, -) { - // Notify others that our track has just been published. - p.state.AddRemoteTrack(remoteTrack) - p.sink.Send(NewTrackPublished{trackInfo, simulcast, outputTrack}) - - // Start a go-routine that reads the data from the remote track. - go func() { - // Call this when this goroutine ends. - defer func() { - p.state.RemoveRemoteTrack(remoteTrack) - p.sink.Send(PublishedTrackFailed{trackInfo, simulcast}) - }() - - if handleRtpFn == nil { - return - } - - for { - // Read the data from the remote track. - packet, _, readErr := remoteTrack.ReadRTP() - if readErr != nil { - if readErr == io.EOF { // finished, no more data, no error, inform others - p.logger.Info("remote track closed") - } else { // finished, no more data, but with error, inform others - p.logger.WithError(readErr).Error("failed to read from remote track") - } - return - } - - // Handle the RTP packet. - if err := handleRtpFn(packet); err != nil { - p.logger.WithError(err).Error("failed to handle RTP packet") - return - } - } - }() -} diff --git a/pkg/peer/state/peer_state.go b/pkg/peer/state/peer_state.go index b277139..819dd8d 100644 --- a/pkg/peer/state/peer_state.go +++ b/pkg/peer/state/peer_state.go @@ -3,46 +3,16 @@ package state import ( "sync" - "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" ) -type RemoteTrackId struct { - id string - simulcast webrtc_ext.SimulcastLayer -} - type PeerState struct { - mutex sync.Mutex - dataChannel *webrtc.DataChannel - remoteTracks map[RemoteTrackId]*webrtc.TrackRemote + mutex sync.Mutex + dataChannel *webrtc.DataChannel } func NewPeerState() *PeerState { - return &PeerState{ - remoteTracks: make(map[RemoteTrackId]*webrtc.TrackRemote), - } -} - -func (p *PeerState) AddRemoteTrack(track *webrtc.TrackRemote) { - p.mutex.Lock() - defer p.mutex.Unlock() - - p.remoteTracks[RemoteTrackId{track.ID(), webrtc_ext.RIDToSimulcastLayer(track.RID())}] = track -} - -func (p *PeerState) RemoveRemoteTrack(track *webrtc.TrackRemote) { - p.mutex.Lock() - defer p.mutex.Unlock() - - delete(p.remoteTracks, RemoteTrackId{track.ID(), webrtc_ext.RIDToSimulcastLayer(track.RID())}) -} - -func (p *PeerState) GetRemoteTrack(id string, simulcast webrtc_ext.SimulcastLayer) *webrtc.TrackRemote { - p.mutex.Lock() - defer p.mutex.Unlock() - - return p.remoteTracks[RemoteTrackId{id, simulcast}] + return &PeerState{} } func (p *PeerState) SetDataChannel(dc *webrtc.DataChannel) { diff --git a/pkg/peer/webrtc_callbacks.go b/pkg/peer/webrtc_callbacks.go index 31288b8..7bac093 100644 --- a/pkg/peer/webrtc_callbacks.go +++ b/pkg/peer/webrtc_callbacks.go @@ -1,7 +1,6 @@ package peer import ( - "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -9,15 +8,8 @@ import ( // A callback that is called once we receive first RTP packets from a track, i.e. // we call this function each time a new track is received. func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - // Construct a new track info assuming that there is no simulcast. - trackInfo := webrtc_ext.TrackInfoFromTrack(remoteTrack) - - switch trackInfo.Kind { - case webrtc.RTPCodecTypeVideo: - p.handleNewVideoTrack(trackInfo, remoteTrack, receiver) - case webrtc.RTPCodecTypeAudio: - p.handleNewAudioTrack(trackInfo, remoteTrack, receiver) - } + p.logger.WithField("track", remoteTrack).Debug("RTP track received") + p.sink.Send(NewTrackPublished{remoteTrack}) } // A callback that is called once we receive an ICE candidate for this peer connection. From bb087d9a2fd843a2919095fcdfdd35e7366a63e6 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 13 Feb 2023 21:34:26 +0100 Subject: [PATCH 4/7] track: implement new handling of published tracks Now published track has its own package that encapsulated the logic related to the published tracks and manages the lifetime of the track and its subscriptions. This means that from now on, each published track (i.e. each Pion's `TrackRemote`) is handled separately from other tracks and by its own go-routine, which means that handling of packets that belong to separate tracks or even separate simulcast layers on a single track, are processed in parallel. --- pkg/conference/participant/tracker.go | 167 +++++------ pkg/conference/peer_message_processing.go | 16 +- pkg/conference/processing.go | 4 +- pkg/conference/publisher/publisher.go | 11 +- pkg/conference/start.go | 23 +- pkg/conference/state.go | 21 +- pkg/conference/subscription/video.go | 17 +- pkg/conference/track/internal.go | 103 +++++++ pkg/conference/track/keyframe.go | 42 +++ .../track.go => track/simulcast.go} | 53 +--- pkg/conference/track/track.go | 267 ++++++++++++++++++ .../{participant => track}/track_test.go | 36 +-- 12 files changed, 559 insertions(+), 201 deletions(-) create mode 100644 pkg/conference/track/internal.go create mode 100644 pkg/conference/track/keyframe.go rename pkg/conference/{participant/track.go => track/simulcast.go} (60%) create mode 100644 pkg/conference/track/track.go rename pkg/conference/{participant => track}/track_test.go (73%) diff --git a/pkg/conference/participant/tracker.go b/pkg/conference/participant/tracker.go index 555842d..4467bf4 100644 --- a/pkg/conference/participant/tracker.go +++ b/pkg/conference/participant/tracker.go @@ -3,24 +3,34 @@ package participant import ( "fmt" - "github.com/matrix-org/waterfall/pkg/conference/subscription" + pub "github.com/matrix-org/waterfall/pkg/conference/track" "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" - "golang.org/x/exp/slices" ) +type TrackStoppedMessage struct { + TrackID pub.TrackID + OwnerID ID +} + // Tracks participants and their corresponding tracks. // These are grouped together as the field in this structure must be kept synchronized. type Tracker struct { participants map[ID]*Participant - publishedTracks map[TrackID]*PublishedTrack + publishedTracks map[pub.TrackID]*pub.PublishedTrack[ID] + + publishedTrackStopped chan<- TrackStoppedMessage + conferenceEnded <-chan struct{} } -func NewParticipantTracker() *Tracker { +func NewParticipantTracker(conferenceEnded <-chan struct{}) (*Tracker, <-chan TrackStoppedMessage) { + publishedTrackStopped := make(chan TrackStoppedMessage) return &Tracker{ - participants: make(map[ID]*Participant), - publishedTracks: make(map[TrackID]*PublishedTrack), - } + participants: make(map[ID]*Participant), + publishedTracks: make(map[pub.TrackID]*pub.PublishedTrack[ID]), + publishedTrackStopped: publishedTrackStopped, + conferenceEnded: conferenceEnded, + }, publishedTrackStopped } // Adds a new participant in the list. @@ -60,9 +70,9 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // Remove the participant's tracks from all participants who might have subscribed to them. streamIdentifiers := make(map[string]bool) for trackID, track := range t.publishedTracks { - if track.Owner == participantID { + if track.Owner() == participantID { // Odd way to add to a set in Go. - streamIdentifiers[track.Info.StreamID] = true + streamIdentifiers[track.Info().StreamID] = true t.RemovePublishedTrack(trackID) } } @@ -70,10 +80,7 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // Go over all subscriptions and remove the participant from them. // TODO: Perhaps we could simply react to the subscrpitions dying and remove them from the list. for _, publishedTrack := range t.publishedTracks { - if subscription, found := publishedTrack.Subscriptions[participantID]; found { - subscription.Unsubscribe() - delete(publishedTrack.Subscriptions, participantID) - } + publishedTrack.Unsubscribe(participantID) } return streamIdentifiers @@ -83,138 +90,98 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // that has been published and that we must take into account from now on. func (t *Tracker) AddPublishedTrack( participantID ID, - info webrtc_ext.TrackInfo, - simulcast webrtc_ext.SimulcastLayer, - metadata TrackMetadata, - outputTrack *webrtc.TrackLocalStaticRTP, -) { - // If this is a new track, let's add it to the list of published and inform participants. - track, found := t.publishedTracks[info.TrackID] - if !found { - layers := []webrtc_ext.SimulcastLayer{} - if simulcast != webrtc_ext.SimulcastLayerNone { - layers = append(layers, simulcast) - } + track *webrtc.TrackRemote, + metadata pub.TrackMetadata, +) error { + participant := t.participants[participantID] + if participant == nil { + return fmt.Errorf("participant %s does not exist", participantID) + } - t.publishedTracks[info.TrackID] = &PublishedTrack{ - Owner: participantID, - Info: info, - Layers: layers, - Metadata: metadata, - OutputTrack: outputTrack, - Subscriptions: make(map[ID]subscription.Subscription), + // If this is a new track, let's add it to the list of published and inform participants. + if published, found := t.publishedTracks[track.ID()]; found { + if err := published.AddPublisher(track); err != nil { + return err } - return + return nil } - // If it's just a new layer, let's add it to the list of layers of the existing published track. - fn := func(layer webrtc_ext.SimulcastLayer) bool { return layer == simulcast } - if simulcast != webrtc_ext.SimulcastLayerNone && slices.IndexFunc(track.Layers, fn) == -1 { - track.Layers = append(track.Layers, simulcast) - t.publishedTracks[info.TrackID] = track + published, err := pub.NewPublishedTrack( + participantID, + participant.Peer.RequestKeyFrame, + track, + metadata, + participant.Logger, + ) + if err != nil { + return err } + + // Wait for the track to complete and inform the conference about it. + go func() { + // Wait for the track to complete. + <-published.Done() + + // Inform the conference that the track is gone. Or stop the go-routine if the conference stopped. + select { + case t.publishedTrackStopped <- TrackStoppedMessage{track.ID(), participantID}: + case <-t.conferenceEnded: + } + }() + + t.publishedTracks[track.ID()] = published + return nil } // Iterates over published tracks and calls a closure upon each track info. func (t *Tracker) ForEachPublishedTrackInfo(fn func(ID, webrtc_ext.TrackInfo)) { for _, track := range t.publishedTracks { - fn(track.Owner, track.Info) + fn(track.Owner(), track.Info()) } } // Updates metadata associated with a given track. -func (t *Tracker) UpdatePublishedTrackMetadata(id TrackID, metadata TrackMetadata) { +func (t *Tracker) UpdatePublishedTrackMetadata(id pub.TrackID, metadata pub.TrackMetadata) { if track, found := t.publishedTracks[id]; found { - track.Metadata = metadata + track.SetMetadata(metadata) t.publishedTracks[id] = track } } // Informs the tracker that one of the previously published tracks is gone. -func (t *Tracker) RemovePublishedTrack(id TrackID) { +func (t *Tracker) RemovePublishedTrack(id pub.TrackID) { if publishedTrack, found := t.publishedTracks[id]; found { - // Iterate over all subscriptions and end them. - for subscriberID, subscription := range publishedTrack.Subscriptions { - subscription.Unsubscribe() - delete(publishedTrack.Subscriptions, subscriberID) - } - + publishedTrack.Stop() delete(t.publishedTracks, id) } } // Subscribes a given participant to the track. -func (t *Tracker) Subscribe(participantID ID, trackID TrackID, requirements TrackMetadata) error { +func (t *Tracker) Subscribe(participantID ID, trackID pub.TrackID, requirements pub.TrackMetadata) error { // Check if the participant exists that wants to subscribe exists. participant := t.participants[participantID] if participant == nil { return fmt.Errorf("participant %s does not exist", participantID) } - // Check if the track that we want to subscribe to exists. + // Check if the track that we want to subscribe exists. published := t.publishedTracks[trackID] if published == nil { return fmt.Errorf("track %s does not exist", trackID) } - // Calculate the desired simulcast layer. - desiredLayer := published.GetOptimalLayer(requirements.MaxWidth, requirements.MaxHeight) - - // If the subscription exists, let's see if we need to update it. - if sub := published.Subscriptions[participantID]; sub != nil { - if sub.Simulcast() != desiredLayer { - sub.SwitchLayer(desiredLayer) - return nil - } - - return fmt.Errorf("subscription already exists and up-to-date") - } - - // Find the owner of the track that we're trying to subscribe to. - owner := t.participants[published.Owner] - if owner == nil { - return fmt.Errorf("owner of the track %s does not exist", published.Info.TrackID) - } - - var ( - sub subscription.Subscription - err error - ) - - // Subscription does not exist, so let's create it. - switch published.Info.Kind { - case webrtc.RTPCodecTypeVideo: - sub, err = subscription.NewVideoSubscription( - published.Info, - desiredLayer, - participant.Peer, - func(track webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error { - return owner.Peer.RequestKeyFrame(track, simulcast) - }, - participant.Logger, - ) - case webrtc.RTPCodecTypeAudio: - sub, err = subscription.NewAudioSubscription(published.OutputTrack, participant.Peer) - } - - // If there was an error, let's return it. - if err != nil { + // Subscribe to the track. + if err := published.Subscribe(participantID, participant.Peer, requirements, participant.Logger); err != nil { return err } - // Add the subscription to the list of subscriptions. - published.Subscriptions[participantID] = sub - return nil } // Unsubscribes a given `participantID` from the track. -func (t *Tracker) Unsubscribe(participantID ID, trackID TrackID) { +func (t *Tracker) Unsubscribe(participantID ID, trackID pub.TrackID) { if published := t.publishedTracks[trackID]; published != nil { - if sub := published.Subscriptions[participantID]; sub != nil { - sub.Unsubscribe() - delete(published.Subscriptions, participantID) - } + published.Unsubscribe(participantID) } } diff --git a/pkg/conference/peer_message_processing.go b/pkg/conference/peer_message_processing.go index 72494bd..d30d888 100644 --- a/pkg/conference/peer_message_processing.go +++ b/pkg/conference/peer_message_processing.go @@ -2,6 +2,7 @@ package conference import ( "github.com/matrix-org/waterfall/pkg/conference/participant" + published "github.com/matrix-org/waterfall/pkg/conference/track" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" "maunium.net/go/mautrix/event" @@ -23,19 +24,20 @@ func (c *Conference) processLeftTheCallMessage(sender participant.ID, msg peer.L } func (c *Conference) processNewTrackPublishedMessage(sender participant.ID, msg peer.NewTrackPublished) { - c.newLogger(sender).Infof("Published new track: %s (%v)", msg.TrackID, msg.SimulcastLayer) + id := msg.RemoteTrack.ID() + c.newLogger(sender).Infof("Published new track: %s (%v)", id, msg.RemoteTrack.RID()) // Find metadata for a given track. - trackMetadata := streamIntoTrackMetadata(c.streamsMetadata)[msg.TrackID] + trackMetadata := streamIntoTrackMetadata(c.streamsMetadata)[id] // If a new track has been published, we inform everyone about new track available. - c.tracker.AddPublishedTrack(sender, msg.TrackInfo, msg.SimulcastLayer, trackMetadata, msg.OutputTrack) + c.tracker.AddPublishedTrack(sender, msg.RemoteTrack, trackMetadata) c.resendMetadataToAllExcept(sender) } -func (c *Conference) processPublishedTrackFailedMessage(sender participant.ID, msg peer.PublishedTrackFailed) { - c.newLogger(sender).Infof("Failed published track: %s", msg.TrackID) - c.tracker.RemovePublishedTrack(msg.TrackID) +func (c *Conference) processPublishedTrackFailedMessage(sender participant.ID, trackID published.TrackID) { + c.newLogger(sender).Infof("Failed published track: %s", trackID) + c.tracker.RemovePublishedTrack(trackID) c.resendMetadataToAllExcept(sender) } @@ -159,7 +161,7 @@ func (c *Conference) processTrackSubscriptionMessage( for _, track := range msg.Subscribe { p.Logger.Debugf("Subscribing to track %s", track.TrackID) - requirements := participant.TrackMetadata{track.Width, track.Height} + requirements := published.TrackMetadata{track.Width, track.Height} if err := c.tracker.Subscribe(p.ID, track.TrackID, requirements); err != nil { p.Logger.Errorf("Failed to subscribe to track %s: %v", track.TrackID, err) continue diff --git a/pkg/conference/processing.go b/pkg/conference/processing.go index a8da1f9..06208b1 100644 --- a/pkg/conference/processing.go +++ b/pkg/conference/processing.go @@ -21,6 +21,8 @@ func (c *Conference) processMessages(signalDone chan struct{}) { c.processPeerMessage(msg) case msg := <-c.matrixEvents: c.processMatrixMessage(msg) + case msg := <-c.publishedTrackStopped: + c.processPublishedTrackFailedMessage(msg.OwnerID, msg.TrackID) } // If there are no more participants, stop the conference. @@ -42,8 +44,6 @@ func (c *Conference) processPeerMessage(message channel.Message[participant.ID, c.processLeftTheCallMessage(message.Sender, msg) case peer.NewTrackPublished: c.processNewTrackPublishedMessage(message.Sender, msg) - case peer.PublishedTrackFailed: - c.processPublishedTrackFailedMessage(message.Sender, msg) case peer.NewICECandidate: c.processNewICECandidateMessage(message.Sender, msg) case peer.ICEGatheringComplete: diff --git a/pkg/conference/publisher/publisher.go b/pkg/conference/publisher/publisher.go index b2d637a..fdd102f 100644 --- a/pkg/conference/publisher/publisher.go +++ b/pkg/conference/publisher/publisher.go @@ -2,10 +2,10 @@ package publisher import ( "errors" - "fmt" "sync" "github.com/pion/rtp" + "github.com/sirupsen/logrus" ) var ErrSubscriptionExists = errors.New("subscription already exists") @@ -22,6 +22,8 @@ type Track interface { // An abstract publisher that reads the packets from the track and forwards them to all subscribers. type Publisher struct { + logger *logrus.Entry + mu sync.Mutex track Track subscriptions map[Subscription]struct{} @@ -30,11 +32,13 @@ type Publisher struct { func NewPublisher( track Track, stop <-chan struct{}, + log *logrus.Entry, ) (*Publisher, <-chan struct{}) { // Create a done channel, so that we can signal the caller when we're done. done := make(chan struct{}) publisher := &Publisher{ + logger: log, track: track, subscriptions: make(map[Subscription]struct{}), } @@ -50,7 +54,7 @@ func NewPublisher( return default: if err := publisher.forwardPacket(); err != nil { - fmt.Println("failed to write to subscribers: ", err) + log.Errorf("failed to read the frame from the track %s", err) return } } @@ -104,8 +108,7 @@ func (p *Publisher) forwardPacket() error { // Write the packet to all subscribers. for subscription := range p.subscriptions { if err := subscription.WriteRTP(*packet); err != nil { - fmt.Println("failed to write to subscriber: ", err) - delete(p.subscriptions, subscription) + p.logger.Warnf("packet dropped on the subscription: %s", err) } } diff --git a/pkg/conference/start.go b/pkg/conference/start.go index cc749c3..521a97d 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -38,16 +38,20 @@ func StartConference( userID id.UserID, inviteEvent *event.CallInviteEventContent, ) (<-chan struct{}, error) { + signalDone := make(chan struct{}) + + tracker, publishedTrackStopped := participant.NewParticipantTracker(signalDone) conference := &Conference{ - id: confID, - config: config, - connectionFactory: peerConnectionFactory, - logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), - matrixWorker: newMatrixWorker(signaling), - tracker: *participant.NewParticipantTracker(), - streamsMetadata: make(event.CallSDPStreamMetadata), - peerMessages: make(chan channel.Message[participant.ID, peer.MessageContent], 100), - matrixEvents: matrixEvents, + id: confID, + config: config, + connectionFactory: peerConnectionFactory, + logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), + matrixWorker: newMatrixWorker(signaling), + tracker: tracker, + streamsMetadata: make(event.CallSDPStreamMetadata), + peerMessages: make(chan channel.Message[participant.ID, peer.MessageContent], 100), + matrixEvents: matrixEvents, + publishedTrackStopped: publishedTrackStopped, } participantID := participant.ID{UserID: userID, DeviceID: inviteEvent.DeviceID, CallID: inviteEvent.CallID} @@ -56,7 +60,6 @@ func StartConference( } // Start conference "main loop". - signalDone := make(chan struct{}) go conference.processMessages(signalDone) return signalDone, nil diff --git a/pkg/conference/state.go b/pkg/conference/state.go index bf62f18..8c58f05 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -3,6 +3,7 @@ package conference import ( "github.com/matrix-org/waterfall/pkg/channel" "github.com/matrix-org/waterfall/pkg/conference/participant" + published "github.com/matrix-org/waterfall/pkg/conference/track" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/sirupsen/logrus" @@ -11,19 +12,19 @@ import ( // A single conference. Call and conference mean the same in context of Matrix. type Conference struct { - id string - config Config - logger *logrus.Entry - conferenceDone chan<- struct{} + id string + config Config + logger *logrus.Entry connectionFactory *webrtc_ext.PeerConnectionFactory matrixWorker *matrixWorker - tracker participant.Tracker + tracker *participant.Tracker streamsMetadata event.CallSDPStreamMetadata - peerMessages chan channel.Message[participant.ID, peer.MessageContent] - matrixEvents <-chan MatrixMessage + peerMessages chan channel.Message[participant.ID, peer.MessageContent] + matrixEvents <-chan MatrixMessage + publishedTrackStopped <-chan participant.TrackStoppedMessage } func (c *Conference) getParticipant(id participant.ID) *participant.Participant { @@ -114,11 +115,11 @@ func (c *Conference) updateMetadata(metadata event.CallSDPStreamMetadata) { func streamIntoTrackMetadata( streamMetadata event.CallSDPStreamMetadata, -) map[participant.TrackID]participant.TrackMetadata { - tracksMetadata := make(map[participant.TrackID]participant.TrackMetadata) +) map[published.TrackID]published.TrackMetadata { + tracksMetadata := make(map[published.TrackID]published.TrackMetadata) for _, metadata := range streamMetadata { for id, track := range metadata.Tracks { - tracksMetadata[id] = participant.TrackMetadata{ + tracksMetadata[id] = published.TrackMetadata{ MaxWidth: track.Width, MaxHeight: track.Height, } diff --git a/pkg/conference/subscription/video.go b/pkg/conference/subscription/video.go index 9a21892..f8d58f1 100644 --- a/pkg/conference/subscription/video.go +++ b/pkg/conference/subscription/video.go @@ -16,7 +16,7 @@ import ( "github.com/sirupsen/logrus" ) -type RequestKeyFrameFn = func(track webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error +type RequestKeyFrameFn = func(simulcast webrtc_ext.SimulcastLayer) error type VideoSubscription struct { rtpSender *webrtc.RTPSender @@ -71,12 +71,14 @@ func NewVideoSubscription( // Configure the worker for the subscription. workerConfig := worker.Config[rtp.Packet]{ - ChannelSize: 32, - Timeout: 3 * time.Second, + ChannelSize: 16, // We really don't need a large buffer here, just to account for spikes. + Timeout: 3 * time.Second, // When do we assume the subscription is stalled. OnTimeout: func() { layer := webrtc_ext.SimulcastLayer(subscription.currentLayer.Load()) + // TODO: At this point we probably need to send some message back + // to the conference and switch the quality of remove the + // subscription. This must not happen under normal circumstances. logger.Warnf("No RTP on subscription %s (%s)", subscription.info.TrackID, layer) - subscription.requestKeyFrame() }, OnTask: workerState.handlePacket, } @@ -107,7 +109,7 @@ func (s *VideoSubscription) WriteRTP(packet rtp.Packet) error { func (s *VideoSubscription) SwitchLayer(simulcast webrtc_ext.SimulcastLayer) { s.logger.Infof("Switching layer on %s to %s", s.info.TrackID, simulcast) s.currentLayer.Store(int32(simulcast)) - s.requestKeyFrame() + s.requestKeyFrameFn(simulcast) } func (s *VideoSubscription) TrackInfo() webrtc_ext.TrackInfo { @@ -143,10 +145,7 @@ func (s *VideoSubscription) readRTCP() { } func (s *VideoSubscription) requestKeyFrame() { - layer := webrtc_ext.SimulcastLayer(s.currentLayer.Load()) - if err := s.requestKeyFrameFn(s.info, layer); err != nil { - s.logger.Errorf("Failed to request key frame: %s", err) - } + s.requestKeyFrameFn(webrtc_ext.SimulcastLayer(s.currentLayer.Load())) } // Internal state of a worker that runs in its own goroutine. diff --git a/pkg/conference/track/internal.go b/pkg/conference/track/internal.go new file mode 100644 index 0000000..e9ee51f --- /dev/null +++ b/pkg/conference/track/internal.go @@ -0,0 +1,103 @@ +package track + +import ( + "github.com/matrix-org/waterfall/pkg/conference/publisher" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/matrix-org/waterfall/pkg/worker" + "github.com/pion/webrtc/v3" +) + +type trackOwner[SubscriberID comparable] struct { + owner SubscriberID + requestKeyFrame func(track *webrtc.TrackRemote) error +} + +type audioTrack struct { + // The sink of this audio track packets. + outputTrack *webrtc.TrackLocalStaticRTP +} + +type videoTrack struct { + // Publisher's of each video layer. + publishers map[webrtc_ext.SimulcastLayer]*publisher.Publisher + // Key frame request handler. + keyframeHandler *worker.Worker[webrtc_ext.SimulcastLayer] +} + +// Forward audio packets from the source track to the destination track. +func forward(sender *webrtc.TrackRemote, receiver *webrtc.TrackLocalStaticRTP, stop <-chan struct{}) error { + for { + // Read the data from the remote track. + packet, _, readErr := sender.ReadRTP() + if readErr != nil { + return readErr + } + + // Write the data to the local track. + if writeErr := receiver.WriteRTP(packet); writeErr != nil { + return writeErr + } + + // Check if we need to stop processing packets. + select { + case <-stop: + return nil + default: + } + } +} + +func (p *PublishedTrack[SubscriberID]) addVideoPublisher(track *webrtc.TrackRemote) { + pub, done := publisher.NewPublisher(&publisher.RemoteTrack{track}, p.stopPublishers, p.logger) + simulcast := webrtc_ext.RIDToSimulcastLayer(track.RID()) + p.video.publishers[simulcast] = pub + + // Listen on `done` and remove the track once it's done. + p.activePublishers.Add(1) + go func() { + defer p.activePublishers.Done() + <-done + + p.mutex.Lock() + defer p.mutex.Unlock() + + // Remove the publisher once it's gone. + delete(p.video.publishers, simulcast) + + // Find any other available layer, so that we can switch subscriptions that lost their publisher + // to a new publisher (at least they'll get some data). + var ( + availableLayer webrtc_ext.SimulcastLayer + availablePublisher *publisher.Publisher + ) + for layer, pub := range p.video.publishers { + availableLayer = layer + availablePublisher = pub + break + } + + // Now iterate over all subscriptions and find those that are now lost due to the publisher being away. + for subID, sub := range p.subscriptions { + if sub.Simulcast() == simulcast { + // If there is some other publisher on a different layer, let's switch to it + if availablePublisher != nil { + sub.SwitchLayer(availableLayer) + pub.AddSubscription(sub) + } else { + // Otherwise, let's just remove the subscription. + sub.Unsubscribe() + delete(p.subscriptions, subID) + } + } + } + }() +} + +func (p *PublishedTrack[SubscriberID]) isClosed() bool { + select { + case <-p.done: + return true + default: + return false + } +} diff --git a/pkg/conference/track/keyframe.go b/pkg/conference/track/keyframe.go new file mode 100644 index 0000000..37fdaca --- /dev/null +++ b/pkg/conference/track/keyframe.go @@ -0,0 +1,42 @@ +package track + +import ( + "fmt" + + "github.com/matrix-org/waterfall/pkg/conference/publisher" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/pion/webrtc/v3" +) + +func (p *PublishedTrack[SubscriberID]) handleKeyFrameRequest(simulcast webrtc_ext.SimulcastLayer) error { + publisher := p.getPublisher(simulcast) + if publisher == nil { + return fmt.Errorf("publisher with simulcast %s not found", simulcast) + } + + track, err := extractRemoteTrack(publisher) + if err != nil { + return err + } + + return p.owner.requestKeyFrame(track) +} + +func (p *PublishedTrack[SubscriberID]) getPublisher(simulcast webrtc_ext.SimulcastLayer) *publisher.Publisher { + p.mutex.Lock() + defer p.mutex.Unlock() + + // Get the track that we need to request a key frame for. + return p.video.publishers[simulcast] +} + +func extractRemoteTrack(pub *publisher.Publisher) (*webrtc.TrackRemote, error) { + // Get the track that we need to request a key frame for. + track := pub.GetTrack() + remoteTrack, ok := track.(*publisher.RemoteTrack) + if !ok { + return nil, fmt.Errorf("not a remote track in publisher") + } + + return remoteTrack.Track, nil +} diff --git a/pkg/conference/participant/track.go b/pkg/conference/track/simulcast.go similarity index 60% rename from pkg/conference/participant/track.go rename to pkg/conference/track/simulcast.go index fd68cc6..b481f9e 100644 --- a/pkg/conference/participant/track.go +++ b/pkg/conference/track/simulcast.go @@ -1,41 +1,28 @@ -package participant +package track import ( - "github.com/matrix-org/waterfall/pkg/conference/subscription" "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/webrtc/v3" - "golang.org/x/exp/slices" ) -type TrackID = string - -// Represents a track that a peer has published (has already started sending to the SFU). -type PublishedTrack struct { - // Owner of a published track. - Owner ID - // Info about the track. - Info webrtc_ext.TrackInfo - // Available simulcast Layers. - Layers []webrtc_ext.SimulcastLayer - // Track metadata. - Metadata TrackMetadata - // Output track (if any). I.e. a track that would contain all RTP packets - // of the given published track. Currently only audio tracks will have it. - OutputTrack *webrtc.TrackLocalStaticRTP - // All available subscriptions for this particular track. - Subscriptions map[ID]subscription.Subscription +// Metadata that we have received about this track from a user. +// This metadata is only set for video tracks at the moment. +type TrackMetadata struct { + MaxWidth, MaxHeight int } // Calculate the layer that we can use based on the requirements passed as parameters and available layers. -func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) webrtc_ext.SimulcastLayer { - // Audio track. For them we don't have any simulcast. We also don't have any simulcast for video - // if there was no simulcast enabled at all. - if p.Info.Kind == webrtc.RTPCodecTypeAudio || len(p.Layers) == 0 { +func getOptimalLayer( + layers map[webrtc_ext.SimulcastLayer]struct{}, + metadata TrackMetadata, + requestedWidth, requestedHeight int, +) webrtc_ext.SimulcastLayer { + // If we don't have any layers available, then there is no simulcast. + if _, found := layers[webrtc_ext.SimulcastLayerNone]; found || len(layers) == 0 { return webrtc_ext.SimulcastLayerNone } // Video track. Calculate the optimal layer closest to the requested resolution. - desiredLayer := calculateDesiredLayer(p.Metadata.MaxWidth, p.Metadata.MaxHeight, requestedWidth, requestedHeight) + desiredLayer := calculateDesiredLayer(metadata.MaxWidth, metadata.MaxHeight, requestedWidth, requestedHeight) // Ideally, here we would need to send an error if the desired layer is not available, but we don't // have a way to do it. So we just return the closest available layer. @@ -48,12 +35,8 @@ func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) we // More Go boilerplate. for _, desiredLayer := range priority { - layerIndex := slices.IndexFunc(p.Layers, func(simulcast webrtc_ext.SimulcastLayer) bool { - return simulcast == desiredLayer - }) - - if layerIndex != -1 { - return p.Layers[layerIndex] + if _, found := layers[desiredLayer]; found { + return desiredLayer } } @@ -62,12 +45,6 @@ func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) we return webrtc_ext.SimulcastLayerLow } -// Metadata that we have received about this track from a user. -// This metadata is only set for video tracks at the moment. -type TrackMetadata struct { - MaxWidth, MaxHeight int -} - // Calculates the optimal layer closest to the requested resolution. We assume that the full resolution is the // maximum resolution that we can get from the user. We assume that a medium quality layer is half the size of // the video (**but not half of the resolution**). I.e. medium quality is high quality divided by 4. And low diff --git a/pkg/conference/track/track.go b/pkg/conference/track/track.go new file mode 100644 index 0000000..163b21e --- /dev/null +++ b/pkg/conference/track/track.go @@ -0,0 +1,267 @@ +package track + +import ( + "fmt" + "sync" + "time" + + "github.com/matrix-org/waterfall/pkg/conference/publisher" + "github.com/matrix-org/waterfall/pkg/conference/subscription" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/matrix-org/waterfall/pkg/worker" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" +) + +type TrackID = string + +// Represents a track that a peer has published (has already started sending to the SFU). +type PublishedTrack[SubscriberID comparable] struct { + // Logger. + logger *logrus.Entry + // Info about the track. + info webrtc_ext.TrackInfo + // Owner of a published track. + owner trackOwner[SubscriberID] + + // We must protect the data with a mutex since we want the `PublishedTrack` to remain thread-safe. + mutex sync.Mutex + // Currently active subscriptions for this track. + subscriptions map[SubscriberID]subscription.Subscription + // Audio track data. The content will be `nil` if it's not an audio track. + audio *audioTrack + // Video track. The content will be `nil` if it's not a video track. + video *videoTrack + // Track metadata. + metadata TrackMetadata + + // Wait group for all active publishers. + activePublishers *sync.WaitGroup + // A signal to publishers **to stop** them all. + stopPublishers chan struct{} + // A aignal to inform the caller that all publishers of this track **have been stopped**. + done chan struct{} +} + +func NewPublishedTrack[SubscriberID comparable]( + ownerID SubscriberID, + requestKeyFrame func(track *webrtc.TrackRemote) error, + track *webrtc.TrackRemote, + metadata TrackMetadata, + logger *logrus.Entry, +) (*PublishedTrack[SubscriberID], error) { + published := &PublishedTrack[SubscriberID]{ + logger: logger, + info: webrtc_ext.TrackInfoFromTrack(track), + owner: trackOwner[SubscriberID]{ownerID, requestKeyFrame}, + subscriptions: make(map[SubscriberID]subscription.Subscription), + audio: &audioTrack{outputTrack: nil}, + video: &videoTrack{publishers: make(map[webrtc_ext.SimulcastLayer]*publisher.Publisher)}, + metadata: metadata, + activePublishers: &sync.WaitGroup{}, + stopPublishers: make(chan struct{}), + done: make(chan struct{}), + } + + switch published.info.Kind { + case webrtc.RTPCodecTypeAudio: + // Create a local track, all our SFU clients that are subscribed to this + // peer (publisher) wil be fed via this track. + localTrack, err := webrtc.NewTrackLocalStaticRTP( + track.Codec().RTPCodecCapability, + track.ID(), + track.StreamID(), + ) + if err != nil { + return nil, err + } + + published.audio.outputTrack = localTrack + + // Start audio publisher in a separate goroutine. + published.activePublishers.Add(1) + go func() { + defer published.activePublishers.Done() + if err := forward(track, localTrack, published.stopPublishers); err != nil { + logger.Errorf("audio publisher stopped: %s", err) + } + }() + + case webrtc.RTPCodecTypeVideo: + // Configure and start a worker to process incoming key frame requests. + workerConfig := worker.Config[webrtc_ext.SimulcastLayer]{ + ChannelSize: 16, + Timeout: 1 * time.Hour, + OnTimeout: func() {}, + OnTask: func(simulcast webrtc_ext.SimulcastLayer) { + published.handleKeyFrameRequest(simulcast) + }, + } + + worker := worker.StartWorker[webrtc_ext.SimulcastLayer](workerConfig) + published.video.keyframeHandler = worker + + // Start video publisher. + published.addVideoPublisher(track) + } + + // Wait for all publishers to stop. + go func() { + defer close(published.done) + published.activePublishers.Wait() + }() + + return published, nil +} + +// Adds a new publisher to the existing `PublishedTrack`, this happens if we +// have multiple qualities (layers) on a single track. +func (p *PublishedTrack[SubscriberID]) AddPublisher(track *webrtc.TrackRemote) error { + if p.isClosed() { + return fmt.Errorf("track is already closed") + } + + info := webrtc_ext.TrackInfoFromTrack(track) + if info.TrackID != p.info.TrackID || p.info.Kind.String() != info.Kind.String() { + return fmt.Errorf("track mismatch") + } + + // Such publisher already exists. Let's replace the track that provides frames with a new one. + simulcast := webrtc_ext.RIDToSimulcastLayer(track.RID()) + + // Lock the mutex since we access the publishers from multiple threads. + p.mutex.Lock() + defer p.mutex.Unlock() + + // If the publisher for this track already exists, let's replace the track. This may happen during + // the negotiation when the SSRC changes and Pion fires a new track for the track that has already + // been published. + if pub := p.video.publishers[simulcast]; pub != nil { + pub.ReplaceTrack(&publisher.RemoteTrack{track}) + return nil + } + + // Add a publisher and start polling it. + p.addVideoPublisher(track) + return nil +} + +// Stops the published track and all related publishers. You should not use the +// `PublishedTrack` after calling this method. +func (p *PublishedTrack[SubscriberID]) Stop() { + // Command all publishers to stop, unless already stopped. + if !p.isClosed() { + close(p.stopPublishers) + } +} + +// Create a new subscription for a given subscriber or update the existing one if necessary. +func (p *PublishedTrack[SubscriberID]) Subscribe( + subscriberID SubscriberID, + controller subscription.SubscriptionController, + requirements TrackMetadata, + logger *logrus.Entry, +) error { + if p.isClosed() { + return fmt.Errorf("track is already closed") + } + + // Lock the mutex as we access subscriptions and publishers from multiple threads. + p.mutex.Lock() + defer p.mutex.Unlock() + + // Let's calculate the desired simulcast layer (if any). + var layer webrtc_ext.SimulcastLayer + if p.info.Kind == webrtc.RTPCodecTypeVideo { + layers := make(map[webrtc_ext.SimulcastLayer]struct{}, len(p.video.publishers)) + for key := range p.video.publishers { + layers[key] = struct{}{} + } + layer = getOptimalLayer(layers, p.metadata, requirements.MaxWidth, requirements.MaxHeight) + } + + // If the subscription exists, let's see if we need to update it. + if sub := p.subscriptions[subscriberID]; sub != nil { + currentLayer := sub.Simulcast() + + // If we do, let's switch the layer. + if currentLayer != layer { + p.video.publishers[currentLayer].RemoveSubscription(sub) + sub.SwitchLayer(layer) + p.video.publishers[layer].AddSubscription(sub) + } + + // Subsription is up-to-date, nothing to change. + return nil + } + + var ( + sub subscription.Subscription + err error + ) + + // Subscription does not exist, so let's create it. + switch p.info.Kind { + case webrtc.RTPCodecTypeVideo: + handler := func(simulcast webrtc_ext.SimulcastLayer) error { + return p.video.keyframeHandler.Send(simulcast) + } + sub, err = subscription.NewVideoSubscription(p.info, layer, controller, handler, logger) + case webrtc.RTPCodecTypeAudio: + sub, err = subscription.NewAudioSubscription(p.audio.outputTrack, controller) + } + + // If there was an error, let's return it. + if err != nil { + return err + } + + // Add the subscription to the list of subscriptions. + p.subscriptions[subscriberID] = sub + + // And if it's a video subscription, add it to the list of subscriptions that get the feed from the publisher. + if p.info.Kind == webrtc.RTPCodecTypeVideo { + p.video.publishers[layer].AddSubscription(sub) + } + + return nil +} + +// Remove subscriptions with a given subscriber id. +func (p *PublishedTrack[SubscriberID]) Unsubscribe(subscriberID SubscriberID) { + p.mutex.Lock() + defer p.mutex.Unlock() + + if sub := p.subscriptions[subscriberID]; sub != nil { + sub.Unsubscribe() + delete(p.subscriptions, subscriberID) + + if p.info.Kind == webrtc.RTPCodecTypeVideo { + p.video.publishers[sub.Simulcast()].RemoveSubscription(sub) + } + } +} + +func (p *PublishedTrack[SubscriberID]) Owner() SubscriberID { + return p.owner.owner +} + +func (p *PublishedTrack[SubscriberID]) Info() webrtc_ext.TrackInfo { + return p.info +} + +func (p *PublishedTrack[SubscriberID]) Done() <-chan struct{} { + return p.done +} + +func (p *PublishedTrack[SubscriberID]) Metadata() TrackMetadata { + p.mutex.Lock() + defer p.mutex.Unlock() + return p.metadata +} + +func (p *PublishedTrack[SubscriberID]) SetMetadata(metadata TrackMetadata) { + p.mutex.Lock() + defer p.mutex.Unlock() + p.metadata = metadata +} diff --git a/pkg/conference/participant/track_test.go b/pkg/conference/track/track_test.go similarity index 73% rename from pkg/conference/participant/track_test.go rename to pkg/conference/track/track_test.go index 3d8f9a1..9825660 100644 --- a/pkg/conference/participant/track_test.go +++ b/pkg/conference/track/track_test.go @@ -1,11 +1,9 @@ -package participant_test +package track //nolint:testpackage import ( "testing" - "github.com/matrix-org/waterfall/pkg/conference/participant" "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/webrtc/v3" ) func TestGetOptimalLayer(t *testing.T) { @@ -43,33 +41,29 @@ func TestGetOptimalLayer(t *testing.T) { {layers(high), 1280, 720, 200, 200, high}, } - mock := participant.PublishedTrack{ - Info: webrtc_ext.TrackInfo{ - Kind: webrtc.RTPCodecTypeVideo, - }, - } - for _, c := range cases { - mock.Layers = c.availableLayers - mock.Metadata.MaxWidth = c.fullWidth - mock.Metadata.MaxHeight = c.fullHeight + metadata := TrackMetadata{ + MaxWidth: c.fullWidth, + MaxHeight: c.fullHeight, + } - optimalLayer := mock.GetOptimalLayer(c.desiredWidth, c.desiredHeight) + layers := make(map[webrtc_ext.SimulcastLayer]struct{}, len(c.availableLayers)) + for _, layer := range c.availableLayers { + layers[layer] = struct{}{} + } + + optimalLayer := getOptimalLayer(layers, metadata, c.desiredWidth, c.desiredHeight) if optimalLayer != c.expectedOptimalLayer { t.Errorf("Expected optimal layer %s, got %s", c.expectedOptimalLayer, optimalLayer) } } } -func TestGetOptimalLayerAudio(t *testing.T) { - mock := participant.PublishedTrack{ - Info: webrtc_ext.TrackInfo{ - Kind: webrtc.RTPCodecTypeAudio, - }, - } +func TestGetOptimalLayerNone(t *testing.T) { + layers := make(map[webrtc_ext.SimulcastLayer]struct{}) + metadata := TrackMetadata{} - mock.Layers = []webrtc_ext.SimulcastLayer{webrtc_ext.SimulcastLayerLow} - if mock.GetOptimalLayer(100, 100) != webrtc_ext.SimulcastLayerNone { + if getOptimalLayer(layers, metadata, 100, 100) != webrtc_ext.SimulcastLayerNone { t.Fatal("Expected no simulcast layer for audio") } } From e74d3d821e5fbe1759caf3cd6c9c9a6d986dafff Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 15 Feb 2023 17:42:35 +0100 Subject: [PATCH 5/7] Update pkg/conference/track/internal.go Co-authored-by: David Baker --- pkg/conference/track/internal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/conference/track/internal.go b/pkg/conference/track/internal.go index e9ee51f..4bf959b 100644 --- a/pkg/conference/track/internal.go +++ b/pkg/conference/track/internal.go @@ -18,7 +18,7 @@ type audioTrack struct { } type videoTrack struct { - // Publisher's of each video layer. + // Publishers of each video layer. publishers map[webrtc_ext.SimulcastLayer]*publisher.Publisher // Key frame request handler. keyframeHandler *worker.Worker[webrtc_ext.SimulcastLayer] From 5e53c5331bc759960511af4398eef8f2a6f6ee98 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Sat, 18 Feb 2023 23:50:51 +0100 Subject: [PATCH 6/7] tracker: remove `pub` alias for `track` package --- pkg/conference/participant/tracker.go | 32 +++++++++++++-------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/pkg/conference/participant/tracker.go b/pkg/conference/participant/tracker.go index 4467bf4..7b0a402 100644 --- a/pkg/conference/participant/tracker.go +++ b/pkg/conference/participant/tracker.go @@ -3,13 +3,13 @@ package participant import ( "fmt" - pub "github.com/matrix-org/waterfall/pkg/conference/track" + "github.com/matrix-org/waterfall/pkg/conference/track" "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" ) type TrackStoppedMessage struct { - TrackID pub.TrackID + TrackID track.TrackID OwnerID ID } @@ -17,7 +17,7 @@ type TrackStoppedMessage struct { // These are grouped together as the field in this structure must be kept synchronized. type Tracker struct { participants map[ID]*Participant - publishedTracks map[pub.TrackID]*pub.PublishedTrack[ID] + publishedTracks map[track.TrackID]*track.PublishedTrack[ID] publishedTrackStopped chan<- TrackStoppedMessage conferenceEnded <-chan struct{} @@ -27,7 +27,7 @@ func NewParticipantTracker(conferenceEnded <-chan struct{}) (*Tracker, <-chan Tr publishedTrackStopped := make(chan TrackStoppedMessage) return &Tracker{ participants: make(map[ID]*Participant), - publishedTracks: make(map[pub.TrackID]*pub.PublishedTrack[ID]), + publishedTracks: make(map[track.TrackID]*track.PublishedTrack[ID]), publishedTrackStopped: publishedTrackStopped, conferenceEnded: conferenceEnded, }, publishedTrackStopped @@ -90,8 +90,8 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // that has been published and that we must take into account from now on. func (t *Tracker) AddPublishedTrack( participantID ID, - track *webrtc.TrackRemote, - metadata pub.TrackMetadata, + remoteTrack *webrtc.TrackRemote, + metadata track.TrackMetadata, ) error { participant := t.participants[participantID] if participant == nil { @@ -99,18 +99,18 @@ func (t *Tracker) AddPublishedTrack( } // If this is a new track, let's add it to the list of published and inform participants. - if published, found := t.publishedTracks[track.ID()]; found { - if err := published.AddPublisher(track); err != nil { + if published, found := t.publishedTracks[remoteTrack.ID()]; found { + if err := published.AddPublisher(remoteTrack); err != nil { return err } return nil } - published, err := pub.NewPublishedTrack( + published, err := track.NewPublishedTrack( participantID, participant.Peer.RequestKeyFrame, - track, + remoteTrack, metadata, participant.Logger, ) @@ -125,12 +125,12 @@ func (t *Tracker) AddPublishedTrack( // Inform the conference that the track is gone. Or stop the go-routine if the conference stopped. select { - case t.publishedTrackStopped <- TrackStoppedMessage{track.ID(), participantID}: + case t.publishedTrackStopped <- TrackStoppedMessage{remoteTrack.ID(), participantID}: case <-t.conferenceEnded: } }() - t.publishedTracks[track.ID()] = published + t.publishedTracks[remoteTrack.ID()] = published return nil } @@ -142,7 +142,7 @@ func (t *Tracker) ForEachPublishedTrackInfo(fn func(ID, webrtc_ext.TrackInfo)) { } // Updates metadata associated with a given track. -func (t *Tracker) UpdatePublishedTrackMetadata(id pub.TrackID, metadata pub.TrackMetadata) { +func (t *Tracker) UpdatePublishedTrackMetadata(id track.TrackID, metadata track.TrackMetadata) { if track, found := t.publishedTracks[id]; found { track.SetMetadata(metadata) t.publishedTracks[id] = track @@ -150,7 +150,7 @@ func (t *Tracker) UpdatePublishedTrackMetadata(id pub.TrackID, metadata pub.Trac } // Informs the tracker that one of the previously published tracks is gone. -func (t *Tracker) RemovePublishedTrack(id pub.TrackID) { +func (t *Tracker) RemovePublishedTrack(id track.TrackID) { if publishedTrack, found := t.publishedTracks[id]; found { publishedTrack.Stop() delete(t.publishedTracks, id) @@ -158,7 +158,7 @@ func (t *Tracker) RemovePublishedTrack(id pub.TrackID) { } // Subscribes a given participant to the track. -func (t *Tracker) Subscribe(participantID ID, trackID pub.TrackID, requirements pub.TrackMetadata) error { +func (t *Tracker) Subscribe(participantID ID, trackID track.TrackID, requirements track.TrackMetadata) error { // Check if the participant exists that wants to subscribe exists. participant := t.participants[participantID] if participant == nil { @@ -180,7 +180,7 @@ func (t *Tracker) Subscribe(participantID ID, trackID pub.TrackID, requirements } // Unsubscribes a given `participantID` from the track. -func (t *Tracker) Unsubscribe(participantID ID, trackID pub.TrackID) { +func (t *Tracker) Unsubscribe(participantID ID, trackID track.TrackID) { if published := t.publishedTracks[trackID]; published != nil { published.Unsubscribe(participantID) } From bbe96307884a28e27aaffa476dcdb37a508d1462 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 22 Feb 2023 21:35:48 +0100 Subject: [PATCH 7/7] minor: make subscribed log into info level --- pkg/conference/peer_message_processing.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/conference/peer_message_processing.go b/pkg/conference/peer_message_processing.go index d30d888..c5e259c 100644 --- a/pkg/conference/peer_message_processing.go +++ b/pkg/conference/peer_message_processing.go @@ -167,7 +167,7 @@ func (c *Conference) processTrackSubscriptionMessage( continue } - p.Logger.Debugf("Subscribed to track %s", track.TrackID) + p.Logger.Infof("Subscribed to track %s", track.TrackID) } }