Skip to content

Commit

Permalink
Provide SCTP Association OnClose callback
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 12, 2024
1 parent c4d56d4 commit 6cfa00f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 4 deletions.
32 changes: 29 additions & 3 deletions sctptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type SCTPTransport struct {
// OnStateChange func()

onErrorHandler func(error)
onCloseHandler func(error)

sctpAssociation *sctp.Association
onDataChannelHandler func(*DataChannel)
Expand Down Expand Up @@ -176,6 +177,7 @@ func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
dataChannels = append(dataChannels, dc.dataChannel)
}
r.lock.RUnlock()

ACCEPT:
for {
dc, err := datachannel.Accept(a, &datachannel.Config{
Expand All @@ -185,6 +187,9 @@ ACCEPT:
if !errors.Is(err, io.EOF) {
r.log.Errorf("Failed to accept data channel: %v", err)
r.onError(err)
r.onClose(err)
} else {
r.onClose(nil)
}
return
}
Expand Down Expand Up @@ -232,9 +237,14 @@ ACCEPT:
MaxRetransmits: maxRetransmits,
}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
if err != nil {
// This data channel is invalid. Close it and log an error.
if err1 := dc.Close(); err1 != nil {
r.log.Errorf("Failed to close invalid data channel: %v", err1)
}
r.log.Errorf("Failed to accept data channel: %v", err)
r.onError(err)
return
// We've received a datachannel with invalid configuration. We can still receive other datachannels.
continue ACCEPT
}

<-r.onDataChannel(rtcDC)
Expand All @@ -251,8 +261,7 @@ ACCEPT:
}
}

// OnError sets an event handler which is invoked when
// the SCTP connection error occurs.
// OnError sets an event handler which is invoked when the SCTP Association errors.
func (r *SCTPTransport) OnError(f func(err error)) {
r.lock.Lock()
defer r.lock.Unlock()
Expand All @@ -269,6 +278,23 @@ func (r *SCTPTransport) onError(err error) {
}
}

// OnClose sets an event handler which is invoked when the SCTP Association closes.
func (r *SCTPTransport) OnClose(f func(err error)) {
r.lock.Lock()
defer r.lock.Unlock()
r.onCloseHandler = f
}

func (r *SCTPTransport) onClose(err error) {
r.lock.RLock()
handler := r.onCloseHandler
r.lock.RUnlock()

if handler != nil {
go handler(err)
}
}

// OnDataChannel sets an event handler which is invoked when a data
// channel message arrives from a remote peer.
func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
Expand Down
71 changes: 70 additions & 1 deletion sctptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@

package webrtc

import "testing"
import (
"bytes"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestGenerateDataChannelID(t *testing.T) {
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
Expand Down Expand Up @@ -55,3 +61,66 @@ func TestGenerateDataChannelID(t *testing.T) {
}
}
}

func TestSCTPTransportOnClose(t *testing.T) {
offerPC, answerPC, err := newPair()
require.NoError(t, err)

answerPC.OnDataChannel(func(dc *DataChannel) {
dc.OnMessage(func(_ DataChannelMessage) {
if err1 := dc.Send([]byte("hello")); err1 != nil {
t.Error("failed to send message")
}
})
})

recvMsg := make(chan struct{}, 1)
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
if state == PeerConnectionStateConnected {
defer func() {
offerPC.OnConnectionStateChange(nil)
}()

dc, createErr := offerPC.CreateDataChannel(expectedLabel, nil)
if createErr != nil {
t.Errorf("Failed to create a PC pair for testing")
return
}
dc.OnMessage(func(msg DataChannelMessage) {
if !bytes.Equal(msg.Data, []byte("hello")) {
t.Error("invalid msg received")
}
recvMsg <- struct{}{}
})
dc.OnOpen(func() {
if err1 := dc.Send([]byte("hello")); err1 != nil {
t.Error("failed to send initial msg", err1)
}
})
}
})

err = signalPair(offerPC, answerPC)
require.NoError(t, err)

select {
case <-recvMsg:
case <-time.After(5 * time.Second):
t.Fatal("timed out")
}

// setup SCTP OnClose callback
ch := make(chan error, 1)
answerPC.SCTP().OnClose(func(err error) {
ch <- err
})

err = offerPC.Close() // This will trigger sctp onclose callback on remote
require.NoError(t, err)

select {
case <-ch:
case <-time.After(5 * time.Second):
t.Fatal("timed out")
}
}

0 comments on commit 6cfa00f

Please sign in to comment.