From 0ed2a1c1fe811ad36db62210bbb4b5045410ce07 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 3 Apr 2023 17:30:17 +0200 Subject: [PATCH] track: handle stalled published tracks So that when the layer on a published track stalled, we can change subscriptions to expect receiving packets from a different track layer (different simulcast layer) within the same published track. Solves https://github.com/matrix-org/waterfall/issues/131. --- pkg/conference/publisher/publisher.go | 16 +- pkg/conference/subscription/audio.go | 12 -- pkg/conference/subscription/subscription.go | 4 - pkg/conference/subscription/video.go | 149 ++++++---------- pkg/conference/track/internal.go | 152 ---------------- pkg/conference/track/keyframe.go | 42 ----- pkg/conference/track/publisher.go | 182 ++++++++++++++++++++ pkg/conference/track/subscription.go | 70 ++++++++ pkg/conference/track/track.go | 83 +++------ 9 files changed, 333 insertions(+), 377 deletions(-) delete mode 100644 pkg/conference/track/internal.go delete mode 100644 pkg/conference/track/keyframe.go create mode 100644 pkg/conference/track/publisher.go create mode 100644 pkg/conference/track/subscription.go diff --git a/pkg/conference/publisher/publisher.go b/pkg/conference/publisher/publisher.go index e30e0cf..5435485 100644 --- a/pkg/conference/publisher/publisher.go +++ b/pkg/conference/publisher/publisher.go @@ -8,7 +8,6 @@ import ( "github.com/pion/rtp" "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" ) var ErrSubscriptionExists = errors.New("subscription already exists") @@ -91,20 +90,13 @@ func (p *Publisher) AddSubscription(subscriptions ...Subscription) { } } -func (p *Publisher) RemoveSubscription(subscription Subscription) { +func (p *Publisher) RemoveSubscription(subscription ...Subscription) { p.mu.Lock() defer p.mu.Unlock() - delete(p.subscriptions, subscription) -} - -func (p *Publisher) DrainSubscriptions() []Subscription { - p.mu.Lock() - defer p.mu.Unlock() - - subscriptions := maps.Keys(p.subscriptions) - maps.Clear(p.subscriptions) - return subscriptions + for _, s := range subscription { + delete(p.subscriptions, s) + } } func (p *Publisher) GetTrack() Track { diff --git a/pkg/conference/subscription/audio.go b/pkg/conference/subscription/audio.go index ceaab15..9e05226 100644 --- a/pkg/conference/subscription/audio.go +++ b/pkg/conference/subscription/audio.go @@ -5,7 +5,6 @@ import ( "fmt" "io" - "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/rtp" "github.com/pion/webrtc/v3" ) @@ -42,17 +41,6 @@ func (s *AudioSubscription) WriteRTP(packet rtp.Packet) error { return fmt.Errorf("Bug: no write RTP logic for an audio subscription!") } -func (s *AudioSubscription) SwitchLayer(simulcast webrtc_ext.SimulcastLayer) { -} - -func (s *AudioSubscription) Simulcast() webrtc_ext.SimulcastLayer { - return webrtc_ext.SimulcastLayerNone -} - -func (s *AudioSubscription) UpdateMuteState(muted bool) { - // We don't have any business logic at the moment for audio subscriptions. -} - func (s *AudioSubscription) readRTCP() { // Read incoming RTCP packets. Before these packets are returned they are processed by interceptors. // For things like NACK this needs to be called. diff --git a/pkg/conference/subscription/subscription.go b/pkg/conference/subscription/subscription.go index 1664184..b53c5a8 100644 --- a/pkg/conference/subscription/subscription.go +++ b/pkg/conference/subscription/subscription.go @@ -1,7 +1,6 @@ package subscription import ( - "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/rtp" "github.com/pion/webrtc/v3" ) @@ -9,9 +8,6 @@ import ( type Subscription interface { Unsubscribe() error WriteRTP(packet rtp.Packet) error - SwitchLayer(simulcast webrtc_ext.SimulcastLayer) - Simulcast() webrtc_ext.SimulcastLayer - UpdateMuteState(muted bool) } type SubscriptionController interface { diff --git a/pkg/conference/subscription/video.go b/pkg/conference/subscription/video.go index ed9695e..e95178c 100644 --- a/pkg/conference/subscription/video.go +++ b/pkg/conference/subscription/video.go @@ -15,69 +15,50 @@ import ( "github.com/pion/rtp" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/attribute" ) -type RequestKeyFrameFn = func(simulcast webrtc_ext.SimulcastLayer) error - type VideoSubscription struct { rtpSender *webrtc.RTPSender info webrtc_ext.TrackInfo - currentLayer atomic.Int32 // atomic webrtc_ext.SimulcastLayer - muted atomic.Bool // we don't expect any RTP packets - stalled atomic.Bool // we do expect RTP packets, but haven't received for a while - - controller SubscriptionController - requestKeyFrameFn RequestKeyFrameFn - worker *worker.Worker[rtp.Packet] + controller SubscriptionController + worker *worker.Worker[rtp.Packet] + stopped atomic.Bool logger *logrus.Entry telemetry *telemetry.Telemetry } +type KeyFrameRequest struct{} + +// Creates a new video subscription. Returns a subscription along with a channel +// that informs the parent about key frame requests from the subscriptions. When the +// channel is closed, the subscription's go-routine is stopped. func NewVideoSubscription( info webrtc_ext.TrackInfo, - simulcast webrtc_ext.SimulcastLayer, - muted bool, controller SubscriptionController, - requestKeyFrameFn RequestKeyFrameFn, logger *logrus.Entry, telemetryBuilder *telemetry.ChildBuilder, -) (*VideoSubscription, error) { +) (*VideoSubscription, <-chan KeyFrameRequest, error) { // Create a new track. rtpTrack, err := webrtc.NewTrackLocalStaticRTP(info.Codec, info.TrackID, info.StreamID) if err != nil { - return nil, fmt.Errorf("Failed to create track: %s", err) + return nil, nil, fmt.Errorf("Failed to create track: %v", err) } rtpSender, err := controller.AddTrack(rtpTrack) if err != nil { - return nil, fmt.Errorf("Failed to add track: %s", err) + return nil, nil, fmt.Errorf("Failed to add track: %v", err) } - // Atomic version of the webrtc_ext.SimulcastLayer. - var currentLayer atomic.Int32 - currentLayer.Store(int32(simulcast)) - - // By default we assume that the track is not muted. - var mutedState atomic.Bool - mutedState.Store(muted) - - // Also, the track is not stalled by default. - var stalled atomic.Bool - // Create a subscription. subscription := &VideoSubscription{ rtpSender, info, - currentLayer, - mutedState, - stalled, controller, - requestKeyFrameFn, nil, + atomic.Bool{}, logger, telemetryBuilder.Create("VideoSubscription"), } @@ -90,100 +71,68 @@ func NewVideoSubscription( // Configure the worker for the subscription. workerConfig := worker.Config[rtp.Packet]{ - ChannelSize: 16, // We really don't need a large buffer here, just to account for spikes. - Timeout: 3 * time.Second, // When do we assume the subscription is stalled. - OnTimeout: func() { - // Not receiving RTP packets for 3 seconds can happen either if we're muted (not an error), - // or if the peer does not send any data (that's a problem that potentially means a freeze). - // Also, we don't want to execute this part if the subscription has already been marked as stalled. - if !subscription.muted.Load() && !subscription.stalled.Load() { - layer := webrtc_ext.SimulcastLayer(subscription.currentLayer.Load()) - logger.Warnf("No RTP on subscription to %s (%s) for 3 seconds", subscription.info.TrackID, layer) - subscription.telemetry.Fail(fmt.Errorf("No incoming RTP packets for 3 seconds on %s", layer)) - subscription.stalled.Store(true) - } - }, - OnTask: workerState.handlePacket, + ChannelSize: 16, // We really don't need a large buffer here, just to account for spikes. + Timeout: 1 * time.Hour, + OnTimeout: func() {}, + OnTask: workerState.handlePacket, } // Start a worker for the subscription and create a subsription. subscription.worker = worker.StartWorker(workerConfig) - // Start reading and forwarding RTCP packets. - go subscription.readRTCP() - - // Request a key frame, so that we can get it from the publisher right after subscription. - subscription.requestKeyFrame() - - subscription.telemetry.AddEvent("subscribed", attribute.String("layer", simulcast.String())) + // Start reading and forwarding RTCP packets goroutine. + ch := subscription.startReadRTCP() - return subscription, nil + return subscription, ch, nil } func (s *VideoSubscription) Unsubscribe() error { + if !s.stopped.CompareAndSwap(false, true) { + return fmt.Errorf("Already stopped") + } + s.worker.Stop() - s.logger.Infof("Unsubscribing from %s (%s)", s.info.TrackID, webrtc_ext.SimulcastLayer(s.currentLayer.Load())) + s.logger.Info("Unsubscribed") s.telemetry.End() return s.controller.RemoveTrack(s.rtpSender) } func (s *VideoSubscription) WriteRTP(packet rtp.Packet) error { - if s.stalled.CompareAndSwap(true, false) { - simulcast := webrtc_ext.SimulcastLayer(s.currentLayer.Load()) - s.logger.Infof("Recovered subscription to %s (%s)", s.info.TrackID, simulcast) - s.telemetry.AddEvent("subscription recovered") - } - // Send the packet to the worker. return s.worker.Send(packet) } -func (s *VideoSubscription) SwitchLayer(simulcast webrtc_ext.SimulcastLayer) { - s.logger.Infof("Switching layer on %s to %s", s.info.TrackID, simulcast) - s.telemetry.AddEvent("switching simulcast layer", attribute.String("layer", simulcast.String())) - s.currentLayer.Store(int32(simulcast)) - s.requestKeyFrameFn(simulcast) -} - -func (s *VideoSubscription) TrackInfo() webrtc_ext.TrackInfo { - return s.info -} - -func (s *VideoSubscription) Simulcast() webrtc_ext.SimulcastLayer { - return webrtc_ext.SimulcastLayer(s.currentLayer.Load()) -} - -func (s *VideoSubscription) UpdateMuteState(muted bool) { - s.muted.Store(muted) -} - // Read incoming RTCP packets. Before these packets are returned they are processed by interceptors. -func (s *VideoSubscription) readRTCP() { - for { - packets, _, err := s.rtpSender.ReadRTCP() - if err != nil { - if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { - layer := webrtc_ext.SimulcastLayer(s.currentLayer.Load()) - s.logger.Debugf("failed to read RTCP on track: %s (%s): %s", s.info.TrackID, layer, err) - s.telemetry.AddEvent("subscription stopped") - s.worker.Stop() - return +func (s *VideoSubscription) startReadRTCP() <-chan KeyFrameRequest { + ch := make(chan KeyFrameRequest) + + go func() { + defer close(ch) + defer s.Unsubscribe() + defer s.telemetry.AddEvent("Stopped") + defer s.logger.Info("Stopped") + + for { + packets, _, err := s.rtpSender.ReadRTCP() + if err != nil { + if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { + s.logger.Debugf("Failed to read RTCP: %v", err) + return + } } - } - // We only want to inform others about PLIs and FIRs. We skip the rest of the packets for now. - for _, packet := range packets { - switch packet.(type) { - // For simplicity we assume that any of the key frame requests is just a key frame request. - case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest: - s.requestKeyFrame() + // We only want to inform others about PLIs and FIRs. We skip the rest of the packets for now. + for _, packet := range packets { + switch packet.(type) { + // For simplicity we assume that any of the key frame requests is just a key frame request. + case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest: + ch <- KeyFrameRequest{} + } } } - } -} + }() -func (s *VideoSubscription) requestKeyFrame() { - s.requestKeyFrameFn(webrtc_ext.SimulcastLayer(s.currentLayer.Load())) + return ch } // Internal state of a worker that runs in its own goroutine. diff --git a/pkg/conference/track/internal.go b/pkg/conference/track/internal.go deleted file mode 100644 index 7de78b1..0000000 --- a/pkg/conference/track/internal.go +++ /dev/null @@ -1,152 +0,0 @@ -package track - -import ( - "fmt" - - "github.com/matrix-org/waterfall/pkg/conference/publisher" - "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/matrix-org/waterfall/pkg/worker" - "github.com/pion/webrtc/v3" - "go.opentelemetry.io/otel/attribute" -) - -type trackOwner[SubscriberID comparable] struct { - owner SubscriberID - requestKeyFrame func(track *webrtc.TrackRemote) error -} - -type audioTrack struct { - // The sink of this audio track packets. - outputTrack *webrtc.TrackLocalStaticRTP -} - -type videoTrack struct { - // Publishers of each video layer. - publishers map[webrtc_ext.SimulcastLayer]*publisher.Publisher - // Key frame request handler. - keyframeHandler *worker.Worker[webrtc_ext.SimulcastLayer] -} - -// Forward audio packets from the source track to the destination track. -func forward(sender *webrtc.TrackRemote, receiver *webrtc.TrackLocalStaticRTP, stop <-chan struct{}) error { - for { - // Read the data from the remote track. - packet, _, readErr := sender.ReadRTP() - if readErr != nil { - return readErr - } - - // Write the data to the local track. - if writeErr := receiver.WriteRTP(packet); writeErr != nil { - return writeErr - } - - // Check if we need to stop processing packets. - select { - case <-stop: - return nil - default: - } - } -} - -func (p *PublishedTrack[SubscriberID]) addVideoPublisher(track *webrtc.TrackRemote) { - // Detect simulcast layer of a publisher and create loggers and scoped telemetry. - simulcast := webrtc_ext.RIDToSimulcastLayer(track.RID()) - layerTelemetry := p.telemetry.CreateChild("layer", attribute.String("layer", simulcast.String())) - layerLogger := p.logger.WithField("layer", simulcast.String()) - - // Create a new publisher for the track. - pub, statusCh := publisher.NewPublisher(&publisher.RemoteTrack{track}, p.stopPublishers, layerLogger) - p.video.publishers[simulcast] = pub - - // Observe the status of the publisher. - p.activePublishers.Add(1) - go func() { - // Once this go-routine is done, inform that this publisher is stopped. - defer p.activePublishers.Done() - defer layerTelemetry.End() - - // Observe publisher's status events. - for status := range statusCh { - switch status { - // Publisher is not active (no packets received for a while). - case publisher.StatusStalled: - p.mutex.Lock() - defer p.mutex.Unlock() - - // Let's check if we're muted. If we are, it's ok to not receive packets. - if p.metadata.Muted { - layerLogger.Info("Publisher is stalled but we're muted, ignoring") - layerTelemetry.AddEvent("muted") - continue - } - - // Otherwise, remove all subscriptions and switch them to the lowest layer if available. - // We assume that the lowest layer is the latest to fail (normally, lowest layer always - // receive packets even if other layers are stalled). - subscriptions := pub.DrainSubscriptions() - lowLayer := p.video.publishers[webrtc_ext.SimulcastLayerLow] - if lowLayer != nil { - layerLogger.Info("Publisher is stalled, switching to the lowest layer") - layerTelemetry.AddEvent("stalled, switched to the low layer") - lowLayer.AddSubscription(subscriptions...) - continue - } - - // Otherwise, we have no other layer to switch to. Bummer. - layerLogger.Warn("Publisher is stalled and we have no other layer to switch to") - layerTelemetry.Fail(fmt.Errorf("stalled")) - continue - - // Publisher is active again (new packets received). - case publisher.StatusRecovered: - // Currently, we don't have any actions when the publisher is recovered, i.e. we - // do not switch subscriptions that **used to be subscribed to this layer** back. - // But we may want to do it once we have congestion control and bandwidth allocation. - } - } - - p.mutex.Lock() - defer p.mutex.Unlock() - - // Remove the publisher once it's gone. - delete(p.video.publishers, simulcast) - - // Find any other available layer, so that we can switch subscriptions that lost their publisher - // to a new publisher (at least they'll get some data). - var ( - availableLayer webrtc_ext.SimulcastLayer - availablePublisher *publisher.Publisher - ) - for layer, pub := range p.video.publishers { - availableLayer = layer - availablePublisher = pub - break - } - - // Now iterate over all subscriptions and find those that are now lost due to the publisher being away. - for subID, sub := range p.subscriptions { - if sub.Simulcast() == simulcast { - // If there is some other publisher on a different layer, let's switch to it - if availablePublisher != nil { - sub.SwitchLayer(availableLayer) - pub.AddSubscription(sub) - } else { - // Otherwise, let's just remove the subscription. - sub.Unsubscribe() - delete(p.subscriptions, subID) - } - } - } - }() -} - -func (p *PublishedTrack[SubscriberID]) isClosed() bool { - select { - case <-p.done: - return true - default: - return false - } -} diff --git a/pkg/conference/track/keyframe.go b/pkg/conference/track/keyframe.go deleted file mode 100644 index 37fdaca..0000000 --- a/pkg/conference/track/keyframe.go +++ /dev/null @@ -1,42 +0,0 @@ -package track - -import ( - "fmt" - - "github.com/matrix-org/waterfall/pkg/conference/publisher" - "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/webrtc/v3" -) - -func (p *PublishedTrack[SubscriberID]) handleKeyFrameRequest(simulcast webrtc_ext.SimulcastLayer) error { - publisher := p.getPublisher(simulcast) - if publisher == nil { - return fmt.Errorf("publisher with simulcast %s not found", simulcast) - } - - track, err := extractRemoteTrack(publisher) - if err != nil { - return err - } - - return p.owner.requestKeyFrame(track) -} - -func (p *PublishedTrack[SubscriberID]) getPublisher(simulcast webrtc_ext.SimulcastLayer) *publisher.Publisher { - p.mutex.Lock() - defer p.mutex.Unlock() - - // Get the track that we need to request a key frame for. - return p.video.publishers[simulcast] -} - -func extractRemoteTrack(pub *publisher.Publisher) (*webrtc.TrackRemote, error) { - // Get the track that we need to request a key frame for. - track := pub.GetTrack() - remoteTrack, ok := track.(*publisher.RemoteTrack) - if !ok { - return nil, fmt.Errorf("not a remote track in publisher") - } - - return remoteTrack.Track, nil -} diff --git a/pkg/conference/track/publisher.go b/pkg/conference/track/publisher.go new file mode 100644 index 0000000..d76317f --- /dev/null +++ b/pkg/conference/track/publisher.go @@ -0,0 +1,182 @@ +package track + +import ( + "fmt" + + "github.com/matrix-org/waterfall/pkg/conference/publisher" + "github.com/matrix-org/waterfall/pkg/telemetry" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" +) + +type trackOwner[SubscriberID comparable] struct { + owner SubscriberID + requestKeyFrame func(track *webrtc.TrackRemote) error +} + +type audioTrack struct { + // The sink of this audio track packets. + outputTrack *webrtc.TrackLocalStaticRTP +} + +type videoTrack struct { + // Publishers of each video layer. + publishers map[webrtc_ext.SimulcastLayer]*publisher.Publisher +} + +// Forward audio packets from the source track to the destination track. +func forward(sender *webrtc.TrackRemote, receiver *webrtc.TrackLocalStaticRTP, stop <-chan struct{}) error { + for { + // Read the data from the remote track. + packet, _, readErr := sender.ReadRTP() + if readErr != nil { + return readErr + } + + // Write the data to the local track. + if writeErr := receiver.WriteRTP(packet); writeErr != nil { + return writeErr + } + + // Check if we need to stop processing packets. + select { + case <-stop: + return nil + default: + } + } +} + +func (p *PublishedTrack[SubscriberID]) addVideoPublisher(track *webrtc.TrackRemote) { + // Detect simulcast layer of a publisher and create loggers and scoped telemetry. + simulcast := webrtc_ext.RIDToSimulcastLayer(track.RID()) + pubLogger := p.logger.WithField("layer", simulcast.String()) + pubTelemetry := p.telemetry.CreateChild("layer", attribute.String("layer", simulcast.String())) + + // Create a new publisher for the track. + pub, pubCh := publisher.NewPublisher(&publisher.RemoteTrack{track}, p.stopPublishers, pubLogger) + p.video.publishers[simulcast] = pub + + // Observe the status of the publisher. + p.activePublishers.Add(1) + go p.processPublisherEvents(pub, pubCh, simulcast, pubLogger, pubTelemetry) +} + +// Processes the events from a single publisher, i.e. a single track, i.e. a single layer. +func (p *PublishedTrack[SubscriberID]) processPublisherEvents( + pub *publisher.Publisher, + pubChannel <-chan publisher.Status, + pubLayer webrtc_ext.SimulcastLayer, + pubLogger *logrus.Entry, + pubTelemetry *telemetry.Telemetry, +) { + // Once this go-routine is done, inform that this publisher is stopped. + defer p.activePublishers.Done() + defer pubTelemetry.End() + + // Observe publisher's status events. + for status := range pubChannel { + switch status { + // Publisher is not active (no packets received for a while). + case publisher.StatusStalled: + p.mutex.Lock() + defer p.mutex.Unlock() + + // Let's check if we're muted. If we are, it's ok to not receive packets. + if p.metadata.Muted { + pubLogger.Info("Publisher is stalled but we're muted, ignoring") + pubTelemetry.AddEvent("muted") + continue + } + + // Otherwise, remove all subscriptions and switch them to the lowest layer if available. + // We assume that the lowest layer is the latest to fail (normally, lowest layer always + // receive packets even if other layers are stalled). + + subscriptionsMap := p.getSubscriptionByLayer(pubLayer) + subscriptions := []publisher.Subscription{} + for _, subscription := range subscriptionsMap { + subscriptions = append(subscriptions, subscription.subscription) + } + + pub.RemoveSubscription(subscriptions...) + + lowLayer := p.video.publishers[webrtc_ext.SimulcastLayerLow] + if lowLayer != nil { + pubLogger.Info("Publisher is stalled, switching to the lowest layer") + pubTelemetry.AddEvent("stalled, so subscriptions switched to the low layer") + lowLayer.AddSubscription(subscriptions...) + for _, subscription := range subscriptionsMap { + subscription.currentLayer = webrtc_ext.SimulcastLayerLow + } + continue + } + + // Otherwise, we have no other layer to switch to. Bummer. + pubLogger.Warn("Publisher is stalled and we have no other layer to switch to") + pubTelemetry.Fail(fmt.Errorf("stalled")) + for _, subscription := range subscriptionsMap { + subscription.currentLayer = webrtc_ext.SimulcastLayerNone + } + + // Publisher is active again (new packets received). + case publisher.StatusRecovered: + p.mutex.Lock() + defer p.mutex.Unlock() + + pubLogger.Info("Publisher is recovered") + pubTelemetry.AddEvent("recovered") + + // Iterate over active subscriptions that don't have any active publisher + // and assign them to this publisher. + for _, subscription := range p.subscriptions { + if subscription.currentLayer == webrtc_ext.SimulcastLayerNone { + subscription.currentLayer = pubLayer + pub.AddSubscription(subscription.subscription) + } + } + } + } + + pubTelemetry.AddEvent("stopped, removing dependent subscriptions") + + // If we got there, then the publisher is stopped. + p.mutex.Lock() + defer p.mutex.Unlock() + + // Remove the publisher once it's gone. + delete(p.video.publishers, pubLayer) + + // Now iterate over all subscriptions and find those that are now lost due to the publisher being away. + // It seems like normally when a single track or layer is gone, it's due to failure, so we don't switch + // to different layers here, but instead just remove the dependent subscriptions. + for subID, sub := range p.subscriptions { + if sub.currentLayer == pubLayer { + sub.subscription.Unsubscribe() + delete(p.subscriptions, subID) + } + } +} + +func (p *PublishedTrack[SubscriberID]) isClosed() bool { + select { + case <-p.done: + return true + default: + return false + } +} + +func (p *PublishedTrack[SubscriberID]) getSubscriptionByLayer( + layer webrtc_ext.SimulcastLayer, +) map[SubscriberID]*trackSubscription { + subscriptions := map[SubscriberID]*trackSubscription{} + for subID, sub := range p.subscriptions { + if sub.currentLayer == layer { + subscriptions[subID] = sub + } + } + return subscriptions +} diff --git a/pkg/conference/track/subscription.go b/pkg/conference/track/subscription.go new file mode 100644 index 0000000..2949ecd --- /dev/null +++ b/pkg/conference/track/subscription.go @@ -0,0 +1,70 @@ +package track + +import ( + "fmt" + + "github.com/matrix-org/waterfall/pkg/conference/publisher" + "github.com/matrix-org/waterfall/pkg/conference/subscription" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/pion/webrtc/v3" +) + +type trackSubscription struct { + subscription subscription.Subscription + currentLayer webrtc_ext.SimulcastLayer +} + +func (p *PublishedTrack[SubscriberID]) processSubscriptionEvents( + sub *trackSubscription, + events <-chan subscription.KeyFrameRequest, +) { + for range events { + if err := p.processKeyFrameRequest(sub); err != nil { + p.logger.WithError(err).Error("Failed to handle key frame request") + p.telemetry.AddError(err) + } + } + + // If we got there than the subscription has stoppped. Remove it from the list. + p.mutex.Lock() + defer p.mutex.Unlock() + + if publisher := p.video.publishers[sub.currentLayer]; publisher != nil { + publisher.RemoveSubscription(sub.subscription) + } + + for subscriberID, subscription := range p.subscriptions { + if subscription == sub { + delete(p.subscriptions, subscriberID) + break + } + } +} + +func (p *PublishedTrack[SubscriberID]) processKeyFrameRequest(sub *trackSubscription) error { + p.mutex.Lock() + defer p.mutex.Unlock() + + publisher := p.video.publishers[sub.currentLayer] + if publisher == nil { + return fmt.Errorf("publisher with simulcast %s not found", sub.currentLayer) + } + + track, err := extractRemoteTrack(publisher) + if err != nil { + return err + } + + return p.owner.requestKeyFrame(track) +} + +func extractRemoteTrack(pub *publisher.Publisher) (*webrtc.TrackRemote, error) { + // Get the track that we need to request a key frame for. + track := pub.GetTrack() + remoteTrack, ok := track.(*publisher.RemoteTrack) + if !ok { + return nil, fmt.Errorf("not a remote track in publisher") + } + + return remoteTrack.Track, nil +} diff --git a/pkg/conference/track/track.go b/pkg/conference/track/track.go index c90d720..8e087bc 100644 --- a/pkg/conference/track/track.go +++ b/pkg/conference/track/track.go @@ -3,13 +3,11 @@ package track import ( "fmt" "sync" - "time" "github.com/matrix-org/waterfall/pkg/conference/publisher" "github.com/matrix-org/waterfall/pkg/conference/subscription" "github.com/matrix-org/waterfall/pkg/telemetry" "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/matrix-org/waterfall/pkg/worker" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" @@ -37,7 +35,7 @@ type PublishedTrack[SubscriberID SubscriberIdentifier] struct { // We must protect the data with a mutex since we want the `PublishedTrack` to remain thread-safe. mutex sync.Mutex // Currently active subscriptions for this track. - subscriptions map[SubscriberID]subscription.Subscription + subscriptions map[SubscriberID]*trackSubscription // Audio track data. The content will be `nil` if it's not an audio track. audio *audioTrack // Video track. The content will be `nil` if it's not a video track. @@ -72,7 +70,7 @@ func NewPublishedTrack[SubscriberID SubscriberIdentifier]( info: webrtc_ext.TrackInfoFromTrack(track), telemetry: telemetry, owner: trackOwner[SubscriberID]{ownerID, requestKeyFrame}, - subscriptions: make(map[SubscriberID]subscription.Subscription), + subscriptions: make(map[SubscriberID]*trackSubscription), audio: &audioTrack{outputTrack: nil}, video: &videoTrack{publishers: make(map[webrtc_ext.SimulcastLayer]*publisher.Publisher)}, metadata: metadata, @@ -108,19 +106,6 @@ func NewPublishedTrack[SubscriberID SubscriberIdentifier]( }() case webrtc.RTPCodecTypeVideo: - // Configure and start a worker to process incoming key frame requests. - workerConfig := worker.Config[webrtc_ext.SimulcastLayer]{ - ChannelSize: 16, - Timeout: 1 * time.Hour, - OnTimeout: func() {}, - OnTask: func(simulcast webrtc_ext.SimulcastLayer) { - published.handleKeyFrameRequest(simulcast) - }, - } - - worker := worker.StartWorker(workerConfig) - published.video.keyframeHandler = worker - // Start video publisher. published.addVideoPublisher(track) } @@ -206,59 +191,51 @@ func (p *PublishedTrack[SubscriberID]) Subscribe( // If the subscription exists, let's see if we need to update it. if sub := p.subscriptions[subscriberID]; sub != nil { - currentLayer := sub.Simulcast() - // If we do, let's switch the layer. - if currentLayer != layer { - p.video.publishers[currentLayer].RemoveSubscription(sub) - sub.SwitchLayer(layer) - p.video.publishers[layer].AddSubscription(sub) + if sub.currentLayer != layer { + p.video.publishers[sub.currentLayer].RemoveSubscription(sub.subscription) + p.video.publishers[layer].AddSubscription(sub.subscription) + sub.currentLayer = layer } // Subsription is up-to-date, nothing to change. return nil } - var ( - sub subscription.Subscription - err error - ) - - // Subscription does not exist, so let's create it. - switch p.info.Kind { - case webrtc.RTPCodecTypeVideo: - handler := func(simulcast webrtc_ext.SimulcastLayer) error { - return p.video.keyframeHandler.Send(simulcast) + sub, ch, err := func() (subscription.Subscription, <-chan subscription.KeyFrameRequest, error) { + // Subscription does not exist, so let's create it. + switch p.info.Kind { + case webrtc.RTPCodecTypeVideo: + sub, ch, err := subscription.NewVideoSubscription( + p.info, + controller, + logger.WithField("track", p.info.TrackID), + p.telemetry.ChildBuilder(attribute.String("id", subscriberID.String())), + ) + return sub, ch, err + case webrtc.RTPCodecTypeAudio: + sub, err := subscription.NewAudioSubscription(p.audio.outputTrack, controller) + return sub, nil, err + default: + return nil, nil, fmt.Errorf("unsupported track kind: %v", p.info.Kind) } - sub, err = subscription.NewVideoSubscription( - p.info, - layer, - p.metadata.Muted, - controller, - handler, - logger, - p.telemetry.ChildBuilder(attribute.String("id", subscriberID.String())), - ) - case webrtc.RTPCodecTypeAudio: - sub, err = subscription.NewAudioSubscription(p.audio.outputTrack, controller) - } - - // If there was an error, let's return it. + }() if err != nil { p.telemetry.AddError(fmt.Errorf("failed to create subscription: %w", err)) return err } // Add the subscription to the list of subscriptions. - p.subscriptions[subscriberID] = sub + subscription := &trackSubscription{sub, layer} + p.subscriptions[subscriberID] = subscription // And if it's a video subscription, add it to the list of subscriptions that get the feed from the publisher. if p.info.Kind == webrtc.RTPCodecTypeVideo { p.video.publishers[layer].AddSubscription(sub) + go p.processSubscriptionEvents(subscription, ch) } p.logger.WithField("subscriber", subscriberID).WithField("layer", layer).Info("New subscription") - return nil } @@ -268,11 +245,11 @@ func (p *PublishedTrack[SubscriberID]) Unsubscribe(subscriberID SubscriberID) { defer p.mutex.Unlock() if sub := p.subscriptions[subscriberID]; sub != nil { - sub.Unsubscribe() + sub.subscription.Unsubscribe() delete(p.subscriptions, subscriberID) if p.info.Kind == webrtc.RTPCodecTypeVideo { - p.video.publishers[sub.Simulcast()].RemoveSubscription(sub) + p.video.publishers[sub.currentLayer].RemoveSubscription(sub.subscription) } } } @@ -299,8 +276,4 @@ func (p *PublishedTrack[SubscriberID]) SetMetadata(metadata TrackMetadata) { p.mutex.Lock() defer p.mutex.Unlock() p.metadata = metadata - - for _, sub := range p.subscriptions { - sub.UpdateMuteState(metadata.Muted) - } }