diff --git a/cmd/sfu/main.go b/cmd/sfu/main.go index 2a5d41f..823ab15 100644 --- a/cmd/sfu/main.go +++ b/cmd/sfu/main.go @@ -100,11 +100,18 @@ func main() { return } - // Create a router to route incoming To-Device messages to the right conference. - routerChannel := routing.NewRouter(matrixClient, connectionFactory, config.Conference) + // Create a channel which we'll use to send events to the router. + matrixEvents := make(chan *event.Event) + defer close(matrixEvents) + + // Start a router that will receive events from the matrix client and route them to the appropriate conference. + routing.StartRouter(matrixClient, connectionFactory, matrixEvents, config.Conference) // Start matrix client sync. This function will block until the sync fails. - matrixClient.RunSyncing(func(e *event.Event) { - routerChannel <- e - }) + if err := matrixClient.RunSync(func(e *event.Event) { matrixEvents <- e }); err != nil { + logrus.WithError(err).Fatal("matrix client sync failed") + return + } + + logrus.Info("SFU stopped") } diff --git a/pkg/.DS_Store b/pkg/.DS_Store new file mode 100644 index 0000000..b2d2818 Binary files /dev/null and b/pkg/.DS_Store differ diff --git a/pkg/channel/sink.go b/pkg/channel/sink.go new file mode 100644 index 0000000..5d9bffa --- /dev/null +++ b/pkg/channel/sink.go @@ -0,0 +1,85 @@ +package channel + +import ( + "errors" + "sync/atomic" +) + +var ErrSinkSealed = errors.New("The channel is sealed") + +// SinkWithSender is a helper struct that allows to send messages to a message sink. +// The SinkWithSender abstracts the message sink which has a certain sender, so that +// the sender does not have to be specified every time a message is sent. +// At the same it guarantees that the caller can't alter the `sender`, which means that +// the sender can't impersonate another sender (and we guarantee this on a compile-time). +type SinkWithSender[SenderType comparable, MessageType any] struct { + // The sender of the messages. This is useful for multiple-producer-single-consumer scenarios. + sender SenderType + // The message sink to which the messages are sent. + messageSink chan<- Message[SenderType, MessageType] + // A channel that is used to indicate that our channel is considered sealed. It's akin + // to a close indication without really closing the channel. We don't want to close + // the channel here since we know that the sink is shared between multiple producers, + // so we only disallow sending to the sink at this point. + sealed chan struct{} + // A "mutex" that is used to protect the act of closing `sealed`. + alreadySealed atomic.Bool +} + +// Creates a new MessageSink. The function is generic allowing us to use it for multiple use cases. +// Note that since the current implementation accepts a channel, it's **not responsible** for closing it. +func NewSink[S comparable, M any](sender S, messageSink chan<- Message[S, M]) *SinkWithSender[S, M] { + return &SinkWithSender[S, M]{ + sender: sender, + messageSink: messageSink, + sealed: make(chan struct{}), + } +} + +// Sends a message to the message sink. Blocks if the sink is full! +func (s *SinkWithSender[S, M]) Send(message M) error { + if s.alreadySealed.Load() { + return ErrSinkSealed + } + + messageWithSender := Message[S, M]{ + Sender: s.sender, + Content: message, + } + + select { + case <-s.sealed: + return ErrSinkSealed + case s.messageSink <- messageWithSender: + return nil + } +} + +// Seals the channel, which means that no messages could be sent via this channel. +// Any attempt to send a message after `Seal()` returns will result in an error. +// Note that it does not mean (does not guarantee) that any existing senders that are +// waiting on the send to unblock won't send the message to the recipient (this case +// can happen if buffered channels are used). The existing senders will either unblock +// at this point and get an error that the channel is sealed or will unblock by sending +// the message to the recipient (should the recipient be ready to consume at this point). +func (s *SinkWithSender[S, M]) Seal() { + if !s.alreadySealed.CompareAndSwap(false, true) { + return + } + + select { + case <-s.sealed: + return + default: + close(s.sealed) + } +} + +// Messages that are sent from the peer to the conference in order to communicate with other peers. +// Since each peer is isolated from others, it can't influence the state of other peers directly. +type Message[SenderType comparable, MessageType any] struct { + // The sender of the message. + Sender SenderType + // The content of the message. + Content MessageType +} diff --git a/pkg/common/channel.go b/pkg/common/channel.go deleted file mode 100644 index 4be62af..0000000 --- a/pkg/common/channel.go +++ /dev/null @@ -1,70 +0,0 @@ -package common - -import "sync/atomic" - -// In Go, unbounded channel means something different than what it means in Rust. -// I.e. unlike Rust, "unbounded" in Go means that the channel has **no buffer**, -// meaning that each attempt to send will block the channel until the receiver -// reads it. Majority of primitives here in `waterfall` are designed under assumption -// that sending is not blocking. -const UnboundedChannelSize = 512 - -// Creates a new channel, returns two counterparts of it where one can only send and another can only receive. -// Unlike traditional Go channels, these allow the receiver to mark the channel as closed which would then fail -// to send any messages to the channel over `Send“. -func NewChannel[M any]() (Sender[M], Receiver[M]) { - channel := make(chan M, UnboundedChannelSize) - closed := &atomic.Bool{} - sender := Sender[M]{channel, closed} - receiver := Receiver[M]{channel, closed} - return sender, receiver -} - -// Sender counterpart of the channel. -type Sender[M any] struct { - // The channel itself. - channel chan<- M - // Atomic variable that indicates whether the channel is closed. - receiverClosed *atomic.Bool -} - -// Tries to send a message if the channel is not closed. -// Returns the message back if the channel is closed. -func (s *Sender[M]) Send(message M) *M { - if !s.receiverClosed.Load() { - s.channel <- message - return nil - } else { - return &message - } -} - -// The receiver counterpart of the channel. -type Receiver[M any] struct { - // The channel itself. It's public, so that we can combine it in `select` statements. - Channel <-chan M - // Atomic variable that indicates whether the channel is closed. - receiverClosed *atomic.Bool -} - -// Marks the channel as closed, which means that no messages could be sent via this channel. -// Any attempt to send a message would result in an error. This is similar to closing the -// channel except that we don't close the underlying channel (since in Go receivers can't -// close the channel). -// -// This function reads (in a non-blocking way) all pending messages until blocking. Otherwise, -// they will stay forver in a channel and get lost. -func (r *Receiver[M]) Close() []M { - r.receiverClosed.Store(true) - - messages := make([]M, 0) - for { - msg, ok := <-r.Channel - if !ok { - break - } - messages = append(messages, msg) - } - - return messages -} diff --git a/pkg/common/message_sink.go b/pkg/common/message_sink.go deleted file mode 100644 index 6853032..0000000 --- a/pkg/common/message_sink.go +++ /dev/null @@ -1,83 +0,0 @@ -package common - -import ( - "errors" - "sync/atomic" -) - -// MessageSink is a helper struct that allows to send messages to a message sink. -// The MessageSink abstracts the message sink which has a certain sender, so that -// the sender does not have to be specified every time a message is sent. -// At the same it guarantees that the caller can't alter the `sender`, which means that -// the sender can't impersonate another sender (and we guarantee this on a compile-time). -type MessageSink[SenderType comparable, MessageType any] struct { - // The sender of the messages. This is useful for multiple-producer-single-consumer scenarios. - sender SenderType - // The message sink to which the messages are sent. - messageSink chan<- Message[SenderType, MessageType] - // Atomic variable that indicates whether the message sink is sealed. - // Basically it means that **the current sender** (but not other senders) - // won't be able to send any more messages to the message sink. The difference - // between this and the channel being closed is that the closed channel is not - // available for writing for all senders. - sealed atomic.Bool -} - -// Creates a new MessageSink. The function is generic allowing us to use it for multiple use cases. -func NewMessageSink[S comparable, M any](sender S, messageSink chan<- Message[S, M]) *MessageSink[S, M] { - return &MessageSink[S, M]{ - sender: sender, - messageSink: messageSink, - } -} - -// Sends a message to the message sink. Blocks if the sink is full! -func (s *MessageSink[S, M]) Send(message M) error { - return s.send(message, false) -} - -// Sends a message to the message sink. Does **not** block if the sink is full, returns an error instead. -func (s *MessageSink[S, M]) TrySend(message M) error { - return s.send(message, true) -} - -// Sends a message to the message sink. -func (s *MessageSink[S, M]) send(message M, nonBlocking bool) error { - if s.sealed.Load() { - return errors.New("The channel is sealed, you can't send any messages over it") - } - - messageWithSender := Message[S, M]{ - Sender: s.sender, - Content: message, - } - - if nonBlocking { - select { - case s.messageSink <- messageWithSender: - return nil - default: - return errors.New("The channel is full, can't send without blocking") - } - } else { - s.messageSink <- messageWithSender - return nil - } -} - -// Seals the channel, which means that no messages could be sent via this channel. -// Any attempt to send a message would result in an error. This is similar to closing the -// channel except that we don't close the underlying channel (since there might be other -// senders that may want to use it). -func (s *MessageSink[S, M]) Seal() { - s.sealed.Store(true) -} - -// Messages that are sent from the peer to the conference in order to communicate with other peers. -// Since each peer is isolated from others, it can't influence the state of other peers directly. -type Message[SenderType comparable, MessageType any] struct { - // The sender of the message. - Sender SenderType - // The content of the message. - Content MessageType -} diff --git a/pkg/conference/matrix_message_processing.go b/pkg/conference/matrix_message_processing.go index a1d548d..8931301 100644 --- a/pkg/conference/matrix_message_processing.go +++ b/pkg/conference/matrix_message_processing.go @@ -3,7 +3,7 @@ package conference import ( "time" - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/channel" "github.com/matrix-org/waterfall/pkg/conference/participant" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" @@ -15,9 +15,8 @@ import ( type MessageContent interface{} type MatrixMessage struct { - Sender participant.ID - Content MessageContent - RawEvent *event.Event + Sender participant.ID + Content MessageContent } // New participant tries to join the conference. @@ -48,7 +47,7 @@ func (c *Conference) onNewParticipant(id participant.ID, inviteEvent *event.Call } sdpAnswer = answer } else { - messageSink := common.NewMessageSink(id, c.peerMessages) + messageSink := channel.NewSink(id, c.peerMessages) peerConnection, answer, err := peer.NewPeer(c.connectionFactory, inviteEvent.Offer.SDP, messageSink, logger) if err != nil { @@ -56,18 +55,16 @@ func (c *Conference) onNewParticipant(id participant.ID, inviteEvent *event.Call return err } - heartbeat := common.Heartbeat{ - Interval: time.Duration(c.config.HeartbeatConfig.Interval) * time.Second, - Timeout: time.Duration(c.config.HeartbeatConfig.Timeout) * time.Second, - SendPing: func() bool { - return p.SendDataChannelMessage(event.Event{ - Type: event.FocusCallPing, - Content: event.Content{}, - }) == nil - }, - OnTimeout: func() { - messageSink.Send(peer.LeftTheCall{event.CallHangupKeepAliveTimeout}) - }, + pingEvent := event.Event{ + Type: event.FocusCallPing, + Content: event.Content{}, + } + + heartbeat := participant.HeartbeatConfig{ + Interval: time.Duration(c.config.HeartbeatConfig.Interval) * time.Second, + Timeout: time.Duration(c.config.HeartbeatConfig.Timeout) * time.Second, + SendPing: func() bool { return p.SendDataChannelMessage(pingEvent) == nil }, + OnTimeout: func() { messageSink.Send(peer.LeftTheCall{event.CallHangupKeepAliveTimeout}) }, } p = &participant.Participant{ @@ -75,7 +72,7 @@ func (c *Conference) onNewParticipant(id participant.ID, inviteEvent *event.Call Peer: peerConnection, Logger: logger, RemoteSessionID: inviteEvent.SenderSessionID, - HeartbeatPong: heartbeat.Start(), + Pong: heartbeat.Start(), } c.tracker.AddParticipant(p) diff --git a/pkg/conference/matrix_worker.go b/pkg/conference/matrix_worker.go index b6d4b70..b6f0b37 100644 --- a/pkg/conference/matrix_worker.go +++ b/pkg/conference/matrix_worker.go @@ -3,19 +3,19 @@ package conference import ( "time" - "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/signaling" + "github.com/matrix-org/waterfall/pkg/worker" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/id" ) type matrixWorker struct { - worker *common.Worker[signaling.MatrixMessage] + worker *worker.Worker[signaling.MatrixMessage] deviceID id.DeviceID } func newMatrixWorker(handler signaling.MatrixSignaler) *matrixWorker { - workerConfig := common.WorkerConfig[signaling.MatrixMessage]{ + workerConfig := worker.Config[signaling.MatrixMessage]{ ChannelSize: 128, Timeout: time.Hour, OnTimeout: func() {}, @@ -23,7 +23,7 @@ func newMatrixWorker(handler signaling.MatrixSignaler) *matrixWorker { } matrixWorker := &matrixWorker{ - worker: common.StartWorker(workerConfig), + worker: worker.StartWorker(workerConfig), deviceID: handler.DeviceID(), } diff --git a/pkg/common/heartbeat.go b/pkg/conference/participant/heartbeat.go similarity index 85% rename from pkg/common/heartbeat.go rename to pkg/conference/participant/heartbeat.go index f2f3146..165aaae 100644 --- a/pkg/common/heartbeat.go +++ b/pkg/conference/participant/heartbeat.go @@ -1,4 +1,4 @@ -package common +package participant import ( "time" @@ -6,8 +6,8 @@ import ( type Pong struct{} -// Heartbeat defines the configuration for a heartbeat. -type Heartbeat struct { +// HeartbeatConfig defines the configuration for a heartbeat. +type HeartbeatConfig struct { // How often to send pings. Interval time.Duration // After which time to consider the communication stalled. @@ -23,8 +23,8 @@ type Heartbeat struct { // on `PongChannel` for `Timeout`. If no response is received within `Timeout`, `OnTimeout` is called. // The goroutine stops once the channel is closed or upon handling the `OnTimeout`. The returned channel // is what the caller should use to inform about the reception of a pong. -func (h *Heartbeat) Start() chan<- Pong { - pong := make(chan Pong, UnboundedChannelSize) +func (h *HeartbeatConfig) Start() chan<- Pong { + pong := make(chan Pong, 1) go func() { ticker := time.NewTicker(h.Interval) @@ -52,7 +52,7 @@ func (h *Heartbeat) Start() chan<- Pong { // Tries to send a ping message using `SendPing` and retry it if it fails. // Returns `true` if the ping was sent successfully. -func (h *Heartbeat) sendWithRetry() bool { +func (h *HeartbeatConfig) sendWithRetry() bool { const retries = 3 retryInterval := h.Timeout / retries diff --git a/pkg/conference/participant/participant.go b/pkg/conference/participant/participant.go index 808f978..0532ff5 100644 --- a/pkg/conference/participant/participant.go +++ b/pkg/conference/participant/participant.go @@ -3,7 +3,6 @@ package participant import ( "fmt" - "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" @@ -25,7 +24,7 @@ type Participant struct { Logger *logrus.Entry Peer *peer.Peer[ID] RemoteSessionID id.SessionID - HeartbeatPong chan<- common.Pong + Pong chan<- Pong } func (p *Participant) AsMatrixRecipient() signaling.MatrixRecipient { diff --git a/pkg/conference/participant/track.go b/pkg/conference/participant/track.go index dad4cd8..bd908c4 100644 --- a/pkg/conference/participant/track.go +++ b/pkg/conference/participant/track.go @@ -1,7 +1,7 @@ package participant import ( - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" "golang.org/x/exp/slices" ) @@ -13,9 +13,9 @@ type PublishedTrack struct { // Owner of a published track. Owner ID // Info about the track. - Info common.TrackInfo + Info webrtc_ext.TrackInfo // Available simulcast Layers. - Layers []common.SimulcastLayer + Layers []webrtc_ext.SimulcastLayer // Track metadata. Metadata TrackMetadata // Output track (if any). I.e. a track that would contain all RTP packets @@ -24,11 +24,11 @@ type PublishedTrack struct { } // Calculate the layer that we can use based on the requirements passed as parameters and available layers. -func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) common.SimulcastLayer { +func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) webrtc_ext.SimulcastLayer { // Audio track. For them we don't have any simulcast. We also don't have any simulcast for video // if there was no simulcast enabled at all. if p.Info.Kind == webrtc.RTPCodecTypeAudio || len(p.Layers) == 0 { - return common.SimulcastLayerNone + return webrtc_ext.SimulcastLayerNone } // Video track. Calculate the optimal layer closest to the requested resolution. @@ -36,16 +36,16 @@ func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) co // Ideally, here we would need to send an error if the desired layer is not available, but we don't // have a way to do it. So we just return the closest available layer. - priority := []common.SimulcastLayer{ + priority := []webrtc_ext.SimulcastLayer{ desiredLayer, - common.SimulcastLayerMedium, - common.SimulcastLayerLow, - common.SimulcastLayerHigh, + webrtc_ext.SimulcastLayerMedium, + webrtc_ext.SimulcastLayerLow, + webrtc_ext.SimulcastLayerHigh, } // More Go boilerplate. for _, desiredLayer := range priority { - layerIndex := slices.IndexFunc(p.Layers, func(simulcast common.SimulcastLayer) bool { + layerIndex := slices.IndexFunc(p.Layers, func(simulcast webrtc_ext.SimulcastLayer) bool { return simulcast == desiredLayer }) @@ -56,7 +56,7 @@ func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) co // Actually this part will never be executed, because if we got to this point, // we know that we at least have one layer available. - return common.SimulcastLayerLow + return webrtc_ext.SimulcastLayerLow } // Metadata that we have received about this track from a user. @@ -69,21 +69,21 @@ type TrackMetadata struct { // maximum resolution that we can get from the user. We assume that a medium quality layer is half the size of // the video (**but not half of the resolution**). I.e. medium quality is high quality divided by 4. And low // quality is medium quality divided by 4 (which is the same as the high quality dividied by 16). -func calculateDesiredLayer(fullWidth, fullHeight int, desiredWidth, desiredHeight int) common.SimulcastLayer { +func calculateDesiredLayer(fullWidth, fullHeight int, desiredWidth, desiredHeight int) webrtc_ext.SimulcastLayer { // Calculate combined length of width and height for the full and desired size videos. fullSize := fullWidth + fullHeight desiredSize := desiredWidth + desiredHeight if fullSize == 0 || desiredSize == 0 { - return common.SimulcastLayerLow + return webrtc_ext.SimulcastLayerLow } // Determine which simulcast desiredLayer to subscribe to based on the requested resolution. if ratio := float32(fullSize) / float32(desiredSize); ratio <= 1 { - return common.SimulcastLayerHigh + return webrtc_ext.SimulcastLayerHigh } else if ratio <= 2 { - return common.SimulcastLayerMedium + return webrtc_ext.SimulcastLayerMedium } - return common.SimulcastLayerLow + return webrtc_ext.SimulcastLayerLow } diff --git a/pkg/conference/participant/track_test.go b/pkg/conference/participant/track_test.go index 612db43..3d8f9a1 100644 --- a/pkg/conference/participant/track_test.go +++ b/pkg/conference/participant/track_test.go @@ -3,25 +3,25 @@ package participant_test import ( "testing" - "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/conference/participant" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" ) func TestGetOptimalLayer(t *testing.T) { // Helper function for a quick an descriptive test case definition. - layers := func(layers ...common.SimulcastLayer) []common.SimulcastLayer { + layers := func(layers ...webrtc_ext.SimulcastLayer) []webrtc_ext.SimulcastLayer { return layers } // Shortcuts for easy and descriptive test case definition. - low, mid, high := common.SimulcastLayerLow, common.SimulcastLayerMedium, common.SimulcastLayerHigh + low, mid, high := webrtc_ext.SimulcastLayerLow, webrtc_ext.SimulcastLayerMedium, webrtc_ext.SimulcastLayerHigh cases := []struct { - availableLayers []common.SimulcastLayer + availableLayers []webrtc_ext.SimulcastLayer fullWidth, fullHeight int desiredWidth, desiredHeight int - expectedOptimalLayer common.SimulcastLayer + expectedOptimalLayer webrtc_ext.SimulcastLayer }{ {layers(low, mid, high), 1728, 1056, 878, 799, mid}, // Screen sharing (Dave's case). {layers(low, mid, high), 1920, 1080, 320, 240, low}, // max=1080p, desired=240p, result=low. @@ -44,7 +44,7 @@ func TestGetOptimalLayer(t *testing.T) { } mock := participant.PublishedTrack{ - Info: common.TrackInfo{ + Info: webrtc_ext.TrackInfo{ Kind: webrtc.RTPCodecTypeVideo, }, } @@ -63,13 +63,13 @@ func TestGetOptimalLayer(t *testing.T) { func TestGetOptimalLayerAudio(t *testing.T) { mock := participant.PublishedTrack{ - Info: common.TrackInfo{ + Info: webrtc_ext.TrackInfo{ Kind: webrtc.RTPCodecTypeAudio, }, } - mock.Layers = []common.SimulcastLayer{common.SimulcastLayerLow} - if mock.GetOptimalLayer(100, 100) != common.SimulcastLayerNone { + mock.Layers = []webrtc_ext.SimulcastLayer{webrtc_ext.SimulcastLayerLow} + if mock.GetOptimalLayer(100, 100) != webrtc_ext.SimulcastLayerNone { t.Fatal("Expected no simulcast layer for audio") } } diff --git a/pkg/conference/participant/tracker.go b/pkg/conference/participant/tracker.go index 40012de..0a3788e 100644 --- a/pkg/conference/participant/tracker.go +++ b/pkg/conference/participant/tracker.go @@ -1,10 +1,8 @@ package participant import ( - "fmt" - - "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/peer/subscription" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/rtp" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" @@ -63,7 +61,7 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // Terminate the participant and remove it from the list. participant.Peer.Terminate() - close(participant.HeartbeatPong) + close(participant.Pong) delete(t.participants, participantID) // Remove the participant's tracks from all participants who might have subscribed to them. @@ -93,16 +91,16 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // that has been published and that we must take into account from now on. func (t *Tracker) AddPublishedTrack( participantID ID, - info common.TrackInfo, - simulcast common.SimulcastLayer, + info webrtc_ext.TrackInfo, + simulcast webrtc_ext.SimulcastLayer, metadata TrackMetadata, outputTrack *webrtc.TrackLocalStaticRTP, ) { // If this is a new track, let's add it to the list of published and inform participants. track, found := t.publishedTracks[info.TrackID] if !found { - layers := []common.SimulcastLayer{} - if simulcast != common.SimulcastLayerNone { + layers := []webrtc_ext.SimulcastLayer{} + if simulcast != webrtc_ext.SimulcastLayerNone { layers = append(layers, simulcast) } @@ -118,8 +116,8 @@ func (t *Tracker) AddPublishedTrack( } // If it's just a new layer, let's add it to the list of layers of the existing published track. - fn := func(layer common.SimulcastLayer) bool { return layer == simulcast } - if simulcast != common.SimulcastLayerNone && slices.IndexFunc(track.Layers, fn) == -1 { + fn := func(layer webrtc_ext.SimulcastLayer) bool { return layer == simulcast } + if simulcast != webrtc_ext.SimulcastLayerNone && slices.IndexFunc(track.Layers, fn) == -1 { track.Layers = append(track.Layers, simulcast) t.publishedTracks[info.TrackID] = track } @@ -164,8 +162,8 @@ func (t *Tracker) RemovePublishedTrack(id TrackID) { } type SubscribeRequest struct { - common.TrackInfo - Simulcast common.SimulcastLayer + webrtc_ext.TrackInfo + Simulcast webrtc_ext.SimulcastLayer } // Subscribes a given participant to the tracks that are passed as a parameter. @@ -181,26 +179,34 @@ func (t *Tracker) Subscribe(participantID ID, requests []SubscribeRequest) { err error ) + published := t.FindPublishedTrack(request.TrackID) + if published == nil { + participant.Logger.Errorf("Can't subscribe to non-existent track %s", request.TrackID) + continue + } + switch request.Kind { case webrtc.RTPCodecTypeVideo: + owner := t.GetParticipant(published.Owner) + if owner == nil { + participant.Logger.Errorf("Can't subscribe to non-existent owner %s", published.Owner) + continue + } + sub, err = subscription.NewVideoSubscription( request.TrackInfo, request.Simulcast, participant.Peer, - func(track common.TrackInfo, simulcast common.SimulcastLayer) error { - return participant.Peer.RequestKeyFrame(track, simulcast) + func(track webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error { + return owner.Peer.RequestKeyFrame(track, simulcast) }, participant.Logger, ) case webrtc.RTPCodecTypeAudio: - if published := t.FindPublishedTrack(request.TrackID); published != nil { - sub, err = subscription.NewAudioSubscription( - published.OutputTrack, - participant.Peer, - ) - } else { - err = fmt.Errorf("Can't subscribe to non-existent track %s", request.TrackID) - } + sub, err = subscription.NewAudioSubscription( + published.OutputTrack, + participant.Peer, + ) } if err != nil { @@ -248,32 +254,12 @@ func (t *Tracker) Unsubscribe(participantID ID, tracks []TrackID) { } // Processes an RTP packet received on a given track. -func (t *Tracker) ProcessRTP(info common.TrackInfo, simulcast common.SimulcastLayer, packet *rtp.Packet) { - for participantID, subscription := range t.subscribers[info.TrackID] { +func (t *Tracker) ProcessRTP(info webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer, packet *rtp.Packet) { + for _, subscription := range t.subscribers[info.TrackID] { if subscription.Simulcast() == simulcast { if err := subscription.WriteRTP(*packet); err != nil { - if participant := t.GetParticipant(participantID); participant != nil { - participant.Logger.Errorf("Error writing RTP to %s (%s): %s", info.TrackID, simulcast, err) - continue - } - logrus.Errorf("Bug: subscription without subscriber") + logrus.Errorf("Dropping an RTP packet on %s (%s): %s", info.TrackID, simulcast, err) } } } } - -// Processes RTCP packets received on a given track. -func (t *Tracker) ProcessKeyFrameRequest(info common.TrackInfo, simulcast common.SimulcastLayer) error { - published, found := t.publishedTracks[info.TrackID] - if !found { - return fmt.Errorf("no such track: %s", info.TrackID) - } - - participant := t.GetParticipant(published.Owner) - if participant == nil { - return fmt.Errorf("no such participant: %s", published.Owner) - } - - // We don't want to send keyframes too often, so we'll send them only once in a while. - return participant.Peer.WritePLI(info, simulcast) -} diff --git a/pkg/conference/peer_message_processing.go b/pkg/conference/peer_message_processing.go index cca8c9f..22ccaed 100644 --- a/pkg/conference/peer_message_processing.go +++ b/pkg/conference/peer_message_processing.go @@ -1,7 +1,6 @@ package conference import ( - "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/conference/participant" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" @@ -147,12 +146,6 @@ func (c *Conference) processDataChannelAvailableMessage(sender participant.ID, m }) } -func (c *Conference) processKeyFrameRequest(msg peer.KeyFrameRequestReceived) { - if err := c.tracker.ProcessKeyFrameRequest(msg.TrackInfo, msg.SimulcastLayer); err != nil { - c.logger.Errorf("Failed to process RTCP on %s (%s): %s", msg.TrackID, msg.SimulcastLayer, err) - } -} - // Handle the `FocusEvent` from the DataChannel message. func (c *Conference) processTrackSubscriptionMessage( p *participant.Participant, @@ -249,10 +242,10 @@ func (c *Conference) processNegotiateMessage(p *participant.Participant, msg eve } func (c *Conference) processPongMessage(p *participant.Participant) { - // New heartbeat received (keep-alive message that is periodically sent by the remote peer). - // We need to update the last heartbeat time. If the peer is not active for too long, we will - // consider peer's connection as stalled and will close it. - p.HeartbeatPong <- common.Pong{} + select { + case p.Pong <- participant.Pong{}: + default: + } } func (c *Conference) processMetadataMessage( diff --git a/pkg/conference/processing.go b/pkg/conference/processing.go index 0b8a27c..eecfc26 100644 --- a/pkg/conference/processing.go +++ b/pkg/conference/processing.go @@ -1,7 +1,7 @@ package conference import ( - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/channel" "github.com/matrix-org/waterfall/pkg/conference/participant" "github.com/matrix-org/waterfall/pkg/peer" "maunium.net/go/mautrix/event" @@ -10,35 +10,29 @@ import ( // Listen on messages from incoming channels and process them. // This is essentially the main loop of the conference. // If this function returns, the conference is over. -func (c *Conference) processMessages() { +func (c *Conference) processMessages(signalDone chan struct{}) { + // When the main loop of the conference ends, clean up the resources. + defer close(signalDone) + defer c.matrixWorker.stop() + for { select { case msg := <-c.peerMessages: c.processPeerMessage(msg) - case msg := <-c.matrixMessages.Channel: + case msg := <-c.matrixEvents: c.processMatrixMessage(msg) } // If there are no more participants, stop the conference. if !c.tracker.HasParticipants() { c.logger.Info("No more participants, stopping the conference") - // Close the channel so that the sender can't push any messages. - unreadMessages := c.matrixMessages.Close() - - // Send the information that we ended to the owner and pass the message - // that we did not process (so that we don't drop it silently). - c.endNotifier.Notify(unreadMessages) - - // Stop the matrix worker. - c.matrixWorker.stop() - return } } } // Process a message from a local peer. -func (c *Conference) processPeerMessage(message common.Message[participant.ID, peer.MessageContent]) { +func (c *Conference) processPeerMessage(message channel.Message[participant.ID, peer.MessageContent]) { // Since Go does not support ADTs, we have to use a switch statement to // determine the actual type of the message. switch msg := message.Content.(type) { @@ -62,8 +56,6 @@ func (c *Conference) processPeerMessage(message common.Message[participant.ID, p c.processDataChannelMessage(message.Sender, msg) case peer.DataChannelAvailable: c.processDataChannelAvailableMessage(message.Sender, msg) - case peer.KeyFrameRequestReceived: - c.processKeyFrameRequest(msg) default: c.logger.Errorf("Unknown message type: %T", msg) } diff --git a/pkg/conference/start.go b/pkg/conference/start.go index 381cebf..cc749c3 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -17,7 +17,7 @@ limitations under the License. package conference import ( - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/channel" "github.com/matrix-org/waterfall/pkg/conference/participant" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" @@ -28,17 +28,16 @@ import ( ) // Starts a new conference or fails and returns an error. +// The conference ends when the last participant leaves. func StartConference( confID string, config Config, peerConnectionFactory *webrtc_ext.PeerConnectionFactory, signaling signaling.MatrixSignaler, - conferenceEndNotifier ConferenceEndNotifier, + matrixEvents <-chan MatrixMessage, userID id.UserID, inviteEvent *event.CallInviteEventContent, -) (*common.Sender[MatrixMessage], error) { - sender, receiver := common.NewChannel[MatrixMessage]() - +) (<-chan struct{}, error) { conference := &Conference{ id: confID, config: config, @@ -47,23 +46,18 @@ func StartConference( matrixWorker: newMatrixWorker(signaling), tracker: *participant.NewParticipantTracker(), streamsMetadata: make(event.CallSDPStreamMetadata), - endNotifier: conferenceEndNotifier, - peerMessages: make(chan common.Message[participant.ID, peer.MessageContent], common.UnboundedChannelSize), - matrixMessages: receiver, + peerMessages: make(chan channel.Message[participant.ID, peer.MessageContent], 100), + matrixEvents: matrixEvents, } participantID := participant.ID{UserID: userID, DeviceID: inviteEvent.DeviceID, CallID: inviteEvent.CallID} if err := conference.onNewParticipant(participantID, inviteEvent); err != nil { - return nil, err + return nil, nil } // Start conference "main loop". - go conference.processMessages() - - return &sender, nil -} + signalDone := make(chan struct{}) + go conference.processMessages(signalDone) -type ConferenceEndNotifier interface { - // Called when the conference ends. - Notify(unread []MatrixMessage) + return signalDone, nil } diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 04a0fed..7fae3d7 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -1,7 +1,7 @@ package conference import ( - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/channel" "github.com/matrix-org/waterfall/pkg/conference/participant" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/webrtc_ext" @@ -11,10 +11,10 @@ import ( // A single conference. Call and conference mean the same in context of Matrix. type Conference struct { - id string - config Config - logger *logrus.Entry - endNotifier ConferenceEndNotifier + id string + config Config + logger *logrus.Entry + conferenceDone chan<- struct{} connectionFactory *webrtc_ext.PeerConnectionFactory matrixWorker *matrixWorker @@ -22,8 +22,8 @@ type Conference struct { tracker participant.Tracker streamsMetadata event.CallSDPStreamMetadata - peerMessages chan common.Message[participant.ID, peer.MessageContent] - matrixMessages common.Receiver[MatrixMessage] + peerMessages chan channel.Message[participant.ID, peer.MessageContent] + matrixEvents <-chan MatrixMessage } func (c *Conference) getParticipant(id participant.ID) *participant.Participant { diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 2064c49..4f67a6d 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -1,7 +1,7 @@ package peer import ( - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/rtp" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" @@ -19,9 +19,9 @@ type LeftTheCall struct { type NewTrackPublished struct { // Information about the track (ID etc). - common.TrackInfo + webrtc_ext.TrackInfo // SimulcastLayer configuration (can be `None` for non-simulcast tracks and for audio tracks). - SimulcastLayer common.SimulcastLayer + SimulcastLayer webrtc_ext.SimulcastLayer // Output track (if any) that could be used to send data to the peer. Will be `nil` if such // track does not exist, in which case the caller is expected to listen to `RtpPacketReceived` // messages. @@ -29,13 +29,13 @@ type NewTrackPublished struct { } type PublishedTrackFailed struct { - common.TrackInfo - SimulcastLayer common.SimulcastLayer + webrtc_ext.TrackInfo + SimulcastLayer webrtc_ext.SimulcastLayer } type RTPPacketReceived struct { - common.TrackInfo - SimulcastLayer common.SimulcastLayer + webrtc_ext.TrackInfo + SimulcastLayer webrtc_ext.SimulcastLayer Packet *rtp.Packet } @@ -54,8 +54,3 @@ type DataChannelMessage struct { } type DataChannelAvailable struct{} - -type KeyFrameRequestReceived struct { - common.TrackInfo - SimulcastLayer common.SimulcastLayer -} diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index d8efc51..0b20650 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/channel" "github.com/matrix-org/waterfall/pkg/peer/state" "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/rtcp" @@ -31,7 +31,7 @@ var ( type Peer[ID comparable] struct { logger *logrus.Entry peerConnection *webrtc.PeerConnection - sink *common.MessageSink[ID, MessageContent] + sink *channel.SinkWithSender[ID, MessageContent] state *state.PeerState } @@ -39,7 +39,7 @@ type Peer[ID comparable] struct { func NewPeer[ID comparable]( connectionFactory *webrtc_ext.PeerConnectionFactory, sdpOffer string, - sink *common.MessageSink[ID, MessageContent], + sink *channel.SinkWithSender[ID, MessageContent], logger *logrus.Entry, ) (*Peer[ID], *webrtc.SessionDescription, error) { peerConnection, err := connectionFactory.CreatePeerConnection() @@ -83,8 +83,8 @@ func (p *Peer[ID]) Terminate() { p.sink.Seal() } -// Writes the specified packets to the `trackID`. -func (p *Peer[ID]) WritePLI(info common.TrackInfo, simulcast common.SimulcastLayer) error { +// Request a key frame from the peer connection. +func (p *Peer[ID]) RequestKeyFrame(info webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error { // Find the right track. track := p.state.GetRemoteTrack(info.TrackID, simulcast) if track == nil { @@ -171,7 +171,3 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, return &answer, nil } - -func (p *Peer[ID]) RequestKeyFrame(info common.TrackInfo, simulcast common.SimulcastLayer) error { - return p.sink.TrySend(KeyFrameRequestReceived{info, simulcast}) -} diff --git a/pkg/peer/remote_track.go b/pkg/peer/remote_track.go index f86792a..97a82ae 100644 --- a/pkg/peer/remote_track.go +++ b/pkg/peer/remote_track.go @@ -4,19 +4,17 @@ import ( "errors" "io" - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/rtp" "github.com/pion/webrtc/v3" ) func (p *Peer[ID]) handleNewVideoTrack( - trackInfo common.TrackInfo, + trackInfo webrtc_ext.TrackInfo, remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, ) { - p.logger.Infof("ontrack got video track %s", trackInfo.TrackID) - - simulcast := common.RIDToSimulcastLayer(remoteTrack.RID()) + simulcast := webrtc_ext.RIDToSimulcastLayer(remoteTrack.RID()) p.handleRemoteTrack(remoteTrack, trackInfo, simulcast, nil, func(packet *rtp.Packet) error { p.sink.Send(RTPPacketReceived{trackInfo, simulcast, packet}) @@ -25,7 +23,7 @@ func (p *Peer[ID]) handleNewVideoTrack( } func (p *Peer[ID]) handleNewAudioTrack( - trackInfo common.TrackInfo, + trackInfo webrtc_ext.TrackInfo, remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, ) { @@ -41,7 +39,7 @@ func (p *Peer[ID]) handleNewAudioTrack( return } - p.handleRemoteTrack(remoteTrack, trackInfo, common.SimulcastLayerNone, localTrack, func(packet *rtp.Packet) error { + p.handleRemoteTrack(remoteTrack, trackInfo, webrtc_ext.SimulcastLayerNone, localTrack, func(packet *rtp.Packet) error { if err = localTrack.WriteRTP(packet); err != nil && !errors.Is(err, io.ErrClosedPipe) { return err } @@ -51,8 +49,8 @@ func (p *Peer[ID]) handleNewAudioTrack( func (p *Peer[ID]) handleRemoteTrack( remoteTrack *webrtc.TrackRemote, - trackInfo common.TrackInfo, - simulcast common.SimulcastLayer, + trackInfo webrtc_ext.TrackInfo, + simulcast webrtc_ext.SimulcastLayer, outputTrack *webrtc.TrackLocalStaticRTP, handleRtpFn func(*rtp.Packet) error, ) { diff --git a/pkg/peer/state/peer_state.go b/pkg/peer/state/peer_state.go index 45eac7d..b277139 100644 --- a/pkg/peer/state/peer_state.go +++ b/pkg/peer/state/peer_state.go @@ -3,13 +3,13 @@ package state import ( "sync" - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" ) type RemoteTrackId struct { id string - simulcast common.SimulcastLayer + simulcast webrtc_ext.SimulcastLayer } type PeerState struct { @@ -28,17 +28,17 @@ func (p *PeerState) AddRemoteTrack(track *webrtc.TrackRemote) { p.mutex.Lock() defer p.mutex.Unlock() - p.remoteTracks[RemoteTrackId{track.ID(), common.RIDToSimulcastLayer(track.RID())}] = track + p.remoteTracks[RemoteTrackId{track.ID(), webrtc_ext.RIDToSimulcastLayer(track.RID())}] = track } func (p *PeerState) RemoveRemoteTrack(track *webrtc.TrackRemote) { p.mutex.Lock() defer p.mutex.Unlock() - delete(p.remoteTracks, RemoteTrackId{track.ID(), common.RIDToSimulcastLayer(track.RID())}) + delete(p.remoteTracks, RemoteTrackId{track.ID(), webrtc_ext.RIDToSimulcastLayer(track.RID())}) } -func (p *PeerState) GetRemoteTrack(id string, simulcast common.SimulcastLayer) *webrtc.TrackRemote { +func (p *PeerState) GetRemoteTrack(id string, simulcast webrtc_ext.SimulcastLayer) *webrtc.TrackRemote { p.mutex.Lock() defer p.mutex.Unlock() diff --git a/pkg/peer/subscription/audio.go b/pkg/peer/subscription/audio.go index 7e99ede..c1b40c3 100644 --- a/pkg/peer/subscription/audio.go +++ b/pkg/peer/subscription/audio.go @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/rtp" "github.com/pion/webrtc/v3" ) @@ -42,11 +42,11 @@ func (s *AudioSubscription) WriteRTP(packet rtp.Packet) error { return fmt.Errorf("Bug: no write RTP logic for an audio subscription!") } -func (s *AudioSubscription) SwitchLayer(simulcast common.SimulcastLayer) { +func (s *AudioSubscription) SwitchLayer(simulcast webrtc_ext.SimulcastLayer) { } -func (s *AudioSubscription) Simulcast() common.SimulcastLayer { - return common.SimulcastLayerNone +func (s *AudioSubscription) Simulcast() webrtc_ext.SimulcastLayer { + return webrtc_ext.SimulcastLayerNone } func (s *AudioSubscription) readRTCP() { diff --git a/pkg/peer/subscription/subscription.go b/pkg/peer/subscription/subscription.go index deb8a04..bf6f6b4 100644 --- a/pkg/peer/subscription/subscription.go +++ b/pkg/peer/subscription/subscription.go @@ -1,7 +1,7 @@ package subscription import ( - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/rtp" "github.com/pion/webrtc/v3" ) @@ -9,8 +9,8 @@ import ( type Subscription interface { Unsubscribe() error WriteRTP(packet rtp.Packet) error - SwitchLayer(simulcast common.SimulcastLayer) - Simulcast() common.SimulcastLayer + SwitchLayer(simulcast webrtc_ext.SimulcastLayer) + Simulcast() webrtc_ext.SimulcastLayer } type SubscriptionController interface { diff --git a/pkg/peer/subscription/video.go b/pkg/peer/subscription/video.go index babb02a..ea945c4 100644 --- a/pkg/peer/subscription/video.go +++ b/pkg/peer/subscription/video.go @@ -7,31 +7,32 @@ import ( "sync/atomic" "time" - "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/peer/subscription/rewriter" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/matrix-org/waterfall/pkg/worker" "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" ) -type RequestKeyFrameFn = func(track common.TrackInfo, simulcast common.SimulcastLayer) error +type RequestKeyFrameFn = func(track webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error type VideoSubscription struct { rtpSender *webrtc.RTPSender - info common.TrackInfo - currentLayer atomic.Int32 // atomic common.SimulcastLayer + info webrtc_ext.TrackInfo + currentLayer atomic.Int32 // atomic webrtc_ext.SimulcastLayer controller SubscriptionController requestKeyFrameFn RequestKeyFrameFn - worker *common.Worker[rtp.Packet] + worker *worker.Worker[rtp.Packet] logger *logrus.Entry } func NewVideoSubscription( - info common.TrackInfo, - simulcast common.SimulcastLayer, + info webrtc_ext.TrackInfo, + simulcast webrtc_ext.SimulcastLayer, controller SubscriptionController, requestKeyFrameFn RequestKeyFrameFn, logger *logrus.Entry, @@ -47,7 +48,7 @@ func NewVideoSubscription( return nil, fmt.Errorf("Failed to add track: %s", err) } - // Atomic version of the common.SimulcastLayer. + // Atomic version of the webrtc_ext.SimulcastLayer. var currentLayer atomic.Int32 currentLayer.Store(int32(simulcast)) @@ -69,11 +70,11 @@ func NewVideoSubscription( } // Configure the worker for the subscription. - workerConfig := common.WorkerConfig[rtp.Packet]{ - ChannelSize: 100, // Approx. 500ms of buffer size, we don't need more - Timeout: 2 * time.Second, + workerConfig := worker.Config[rtp.Packet]{ + ChannelSize: 32, + Timeout: 3 * time.Second, OnTimeout: func() { - layer := common.SimulcastLayer(subscription.currentLayer.Load()) + layer := webrtc_ext.SimulcastLayer(subscription.currentLayer.Load()) logger.Warnf("No RTP on subscription %s (%s)", subscription.info.TrackID, layer) subscription.requestKeyFrame() }, @@ -81,7 +82,7 @@ func NewVideoSubscription( } // Start a worker for the subscription and create a subsription. - subscription.worker = common.StartWorker(workerConfig) + subscription.worker = worker.StartWorker(workerConfig) // Start reading and forwarding RTCP packets. go subscription.readRTCP() @@ -94,7 +95,7 @@ func NewVideoSubscription( func (s *VideoSubscription) Unsubscribe() error { s.worker.Stop() - s.logger.Infof("Unsubscribing from %s (%s)", s.info.TrackID, common.SimulcastLayer(s.currentLayer.Load())) + s.logger.Infof("Unsubscribing from %s (%s)", s.info.TrackID, webrtc_ext.SimulcastLayer(s.currentLayer.Load())) return s.controller.RemoveTrack(s.rtpSender) } @@ -103,18 +104,18 @@ func (s *VideoSubscription) WriteRTP(packet rtp.Packet) error { return s.worker.Send(packet) } -func (s *VideoSubscription) SwitchLayer(simulcast common.SimulcastLayer) { +func (s *VideoSubscription) SwitchLayer(simulcast webrtc_ext.SimulcastLayer) { s.logger.Infof("Switching layer on %s to %s", s.info.TrackID, simulcast) s.currentLayer.Store(int32(simulcast)) s.requestKeyFrame() } -func (s *VideoSubscription) TrackInfo() common.TrackInfo { +func (s *VideoSubscription) TrackInfo() webrtc_ext.TrackInfo { return s.info } -func (s *VideoSubscription) Simulcast() common.SimulcastLayer { - return common.SimulcastLayer(s.currentLayer.Load()) +func (s *VideoSubscription) Simulcast() webrtc_ext.SimulcastLayer { + return webrtc_ext.SimulcastLayer(s.currentLayer.Load()) } // Read incoming RTCP packets. Before these packets are returned they are processed by interceptors. @@ -123,7 +124,7 @@ func (s *VideoSubscription) readRTCP() { packets, _, err := s.rtpSender.ReadRTCP() if err != nil { if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { - layer := common.SimulcastLayer(s.currentLayer.Load()) + layer := webrtc_ext.SimulcastLayer(s.currentLayer.Load()) s.logger.Warnf("failed to read RTCP on track: %s (%s): %s", s.info.TrackID, layer, err) s.worker.Stop() return @@ -142,7 +143,7 @@ func (s *VideoSubscription) readRTCP() { } func (s *VideoSubscription) requestKeyFrame() { - layer := common.SimulcastLayer(s.currentLayer.Load()) + layer := webrtc_ext.SimulcastLayer(s.currentLayer.Load()) if err := s.requestKeyFrameFn(s.info, layer); err != nil { s.logger.Errorf("Failed to request key frame: %s", err) } diff --git a/pkg/peer/webrtc_callbacks.go b/pkg/peer/webrtc_callbacks.go index ff8e54c..31288b8 100644 --- a/pkg/peer/webrtc_callbacks.go +++ b/pkg/peer/webrtc_callbacks.go @@ -1,7 +1,7 @@ package peer import ( - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -10,7 +10,7 @@ import ( // we call this function each time a new track is received. func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { // Construct a new track info assuming that there is no simulcast. - trackInfo := common.TrackInfoFromTrack(remoteTrack) + trackInfo := webrtc_ext.TrackInfoFromTrack(remoteTrack) switch trackInfo.Kind { case webrtc.RTPCodecTypeVideo: diff --git a/pkg/routing/router.go b/pkg/routing/router.go index e568ff2..6c81f43 100644 --- a/pkg/routing/router.go +++ b/pkg/routing/router.go @@ -17,7 +17,6 @@ limitations under the License. package routing import ( - "github.com/matrix-org/waterfall/pkg/common" conf "github.com/matrix-org/waterfall/pkg/conference" "github.com/matrix-org/waterfall/pkg/conference/participant" "github.com/matrix-org/waterfall/pkg/signaling" @@ -27,58 +26,43 @@ import ( "maunium.net/go/mautrix/id" ) -type Conference = common.Sender[conf.MatrixMessage] - // The top-level state of the Router. type Router struct { // Matrix matrix. matrix *signaling.MatrixClient // Sinks of all conferences (all calls that are currently forwarded by this SFU). - conferenceSinks map[string]*Conference + conferenceSinks map[string]*conferenceStage // Configuration for the calls. config conf.Config - // A channel to serialize all incoming events to the Router. - channel chan RouterMessage + // Channel for reading incoming Matrix SDK To-Device events and distributing them to the conferences. + matrixEvents <-chan *event.Event + // Channel for handling conference ended events. // Peer connection factory that can be used to create pre-configured peer connections. connectionFactory *webrtc_ext.PeerConnectionFactory } // Creates a new instance of the SFU with the given configuration. -func NewRouter( +func StartRouter( matrix *signaling.MatrixClient, connectionFactory *webrtc_ext.PeerConnectionFactory, + matrixEvents <-chan *event.Event, config conf.Config, -) chan<- RouterMessage { +) { router := &Router{ matrix: matrix, - conferenceSinks: make(map[string]*common.Sender[conf.MatrixMessage]), + conferenceSinks: make(map[string]*conferenceStage), config: config, - channel: make(chan RouterMessage, common.UnboundedChannelSize), + matrixEvents: matrixEvents, connectionFactory: connectionFactory, } // Start the main loop of the Router. go func() { - for msg := range router.channel { - switch msg := msg.(type) { + for msg := range router.matrixEvents { // To-Device message received from the remote peer. - case MatrixMessage: - router.handleMatrixEvent(msg) - // One of the conferences has ended. - case ConferenceEndedMessage: - // Remove the conference that ended from the list. - delete(router.conferenceSinks, msg.conferenceID) - - // Process the message that was not read by the conference. - for _, msg := range msg.unread { - // TODO: We actually already know the type, so we can do this better. - router.handleMatrixEvent(msg.RawEvent) - } - } + router.handleMatrixEvent(msg) } }() - - return router.channel } // Handles incoming To-Device events that the SFU receives from clients. @@ -120,12 +104,15 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // are expected to operate on an existing conference that is running on the SFU. if conference == nil && evt.Type.Type == event.ToDeviceCallInvite.Type { logger.Infof("creating new conference %s", conferenceID) - conferenceSink, err := conf.StartConference( + + matrixEvents := make(chan conf.MatrixMessage) + + conferenceDone, err := conf.StartConference( conferenceID, r.config, r.connectionFactory, r.matrix.CreateForConference(conferenceID), - createConferenceEndNotifier(conferenceID, r.channel), + matrixEvents, userID, evt.Content.AsCallInvite(), ) @@ -134,7 +121,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - r.conferenceSinks[conferenceID] = conferenceSink + r.conferenceSinks[conferenceID] = &conferenceStage{matrixEvents, conferenceDone} return } @@ -144,70 +131,45 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - // A helper function to deal with messages that can't be sent due to the conference closed. - // Not a function due to the need to capture local environment. - sendToConference := func(eventContent conf.MessageContent) { - sender := participant.ID{userID, id.DeviceID(deviceID), callID} - // At this point the conference is not nil. - // Let's check if the channel is still available. - if conference.Send(conf.MatrixMessage{Content: eventContent, RawEvent: evt, Sender: sender}) != nil { - // If sending failed, then the conference is over. - delete(r.conferenceSinks, conferenceID) - // Since we were not able to send the message, let's re-process it now. - // Note, we probably do not want to block here! - r.handleMatrixEvent(evt) - } - } + // Sender of the To-Device message. + sender := participant.ID{userID, id.DeviceID(deviceID), callID} + var content conf.MessageContent switch evt.Type.Type { // Someone tries to participate in a call (join a call). case event.ToDeviceCallInvite.Type: // If there is an invitation sent and the conference does not exist, create one. - sendToConference(evt.Content.AsCallInvite()) + content = evt.Content.AsCallInvite() case event.ToDeviceCallCandidates.Type: // Someone tries to send ICE candidates to the existing call. - sendToConference(evt.Content.AsCallCandidates()) + content = evt.Content.AsCallCandidates() case event.ToDeviceCallSelectAnswer.Type: // Someone informs us about them accepting our (SFU's) SDP answer for an existing call. - sendToConference(evt.Content.AsCallSelectAnswer()) + content = evt.Content.AsCallSelectAnswer() case event.ToDeviceCallHangup.Type: // Someone tries to inform us about leaving an existing call. - sendToConference(evt.Content.AsCallHangup()) + content = evt.Content.AsCallHangup() default: logger.Warnf("ignoring event that we must not receive: %s", evt.Type.Type) + return } -} -type RouterMessage = interface{} - -type MatrixMessage = *event.Event - -// Message that is sent from the conference when the conference is ended. -type ConferenceEndedMessage struct { - // The ID of the conference that has ended. - conferenceID string - // A message (or messages in future) that has not been processed (if any). - unread []conf.MatrixMessage -} - -// A simple wrapper around channel that contains the ID of the conference that sent the message. -type ConferenceEndNotifier struct { - conferenceID string - channel chan<- interface{} -} - -// Crates a simple notifier with a conference with a given ID. -func createConferenceEndNotifier(conferenceID string, channel chan<- RouterMessage) *ConferenceEndNotifier { - return &ConferenceEndNotifier{ - conferenceID: conferenceID, - channel: channel, + // Send the message to the conference. + select { + case <-conference.done: + // Conference has just gotten closed, let's remove it from the list of conferences. + delete(r.conferenceSinks, conferenceID) + close(conference.sink) + + // Since we were not able to send the message, let's re-process it now. + r.handleMatrixEvent(evt) + case conference.sink <- conf.MatrixMessage{Content: content, Sender: sender}: + // Ok,sent! + return } } -// A function that a conference calls when it is ended. -func (c *ConferenceEndNotifier) Notify(unread []conf.MatrixMessage) { - c.channel <- ConferenceEndedMessage{ - conferenceID: c.conferenceID, - unread: unread, - } +type conferenceStage struct { + sink chan<- conf.MatrixMessage + done <-chan struct{} } diff --git a/pkg/signaling/client.go b/pkg/signaling/client.go index 044d496..47f4a5e 100644 --- a/pkg/signaling/client.go +++ b/pkg/signaling/client.go @@ -1,6 +1,8 @@ package signaling import ( + "fmt" + "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -34,11 +36,11 @@ func NewMatrixClient(config Config) *MatrixClient { } // Starts the Matrix client and connects to the homeserver, -// Returns only when the sync with Matrix fails. -func (m *MatrixClient) RunSyncing(callback func(*event.Event)) { +// Returns only when the sync with Matrix stops or fails. +func (m *MatrixClient) RunSync(callback func(*event.Event)) error { syncer, ok := m.client.Syncer.(*mautrix.DefaultSyncer) if !ok { - logrus.Panic("Syncer is not DefaultSyncer") + return fmt.Errorf("syncer is not a DefaultSyncer") } syncer.ParseEventContent = true @@ -61,7 +63,5 @@ func (m *MatrixClient) RunSyncing(callback func(*event.Event)) { // TODO: We may want to reconnect if `Sync()` fails instead of ending the SFU // as ending here will essentially drop all conferences which may not necessarily // be what we want for the existing running conferences. - if err := m.client.Sync(); err != nil { - logrus.WithError(err).Panic("Sync failed") - } + return m.client.Sync() } diff --git a/pkg/common/track_info.go b/pkg/webrtc_ext/track_info.go similarity index 98% rename from pkg/common/track_info.go rename to pkg/webrtc_ext/track_info.go index c00c123..03b1e60 100644 --- a/pkg/common/track_info.go +++ b/pkg/webrtc_ext/track_info.go @@ -1,4 +1,4 @@ -package common +package webrtc_ext import ( "github.com/pion/webrtc/v3" diff --git a/pkg/common/worker.go b/pkg/worker/worker.go similarity index 95% rename from pkg/common/worker.go rename to pkg/worker/worker.go index 81fbadc..3133457 100644 --- a/pkg/common/worker.go +++ b/pkg/worker/worker.go @@ -1,4 +1,4 @@ -package common +package worker import ( "errors" @@ -13,7 +13,7 @@ var ( ) // Configuration for the worker. -type WorkerConfig[T any] struct { +type Config[T any] struct { // The size of the bounded channel. ChannelSize int // Timeout after which `OnTimeout` is called. @@ -68,7 +68,7 @@ func (c *Worker[T]) Send(task T) error { // Starts a worker that periodically (specified by the configuration) executes a `c.OnTimeout` closure if // no tasks have been received on a channel for a `c.Timeout`. The worker will stop once the channel is closed, // i.e. once the user calls `Stop` explicitly. -func StartWorker[T any](c WorkerConfig[T]) *Worker[T] { +func StartWorker[T any](c Config[T]) *Worker[T] { // The channel that will be used to inform the worker about the reception of a task. // The worker will be stopped once the channel is closed. incoming := make(chan T, c.ChannelSize) diff --git a/pkg/common/worker_test.go b/pkg/worker/worker_test.go similarity index 53% rename from pkg/common/worker_test.go rename to pkg/worker/worker_test.go index a765d41..7e90eb7 100644 --- a/pkg/common/worker_test.go +++ b/pkg/worker/worker_test.go @@ -1,20 +1,20 @@ -package common_test +package worker_test import ( "testing" "time" - "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/worker" ) func BenchmarkWorker(b *testing.B) { - workerConfig := common.WorkerConfig[struct{}]{ - ChannelSize: common.UnboundedChannelSize, + workerConfig := worker.Config[struct{}]{ + ChannelSize: 1, Timeout: 2 * time.Second, OnTimeout: func() {}, OnTask: func(struct{}) {}, } - w := common.StartWorker(workerConfig) + w := worker.StartWorker(workerConfig) for n := 0; n < b.N; n++ { w.Send(struct{}{})