diff --git a/connection_manager.go b/connection_manager.go index aa7961ba..834a96a5 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -16,8 +16,8 @@ package ouroboros import "sync" -// ConnectionManagerErrorFunc is a function that takes a connection ID and an error -type ConnectionManagerErrorFunc func(int, error) +// ConnectionManagerConnClosedFunc is a function that takes a connection ID and an optional error +type ConnectionManagerConnClosedFunc func(int, error) // ConnectionManagerTag represents the various tags that can be associated with a host or connection type ConnectionManagerTag uint16 @@ -62,7 +62,7 @@ type ConnectionManager struct { } type ConnectionManagerConfig struct { - ErrorFunc ConnectionManagerErrorFunc + ConnClosedFunc ConnectionManagerConnClosedFunc } type ConnectionManagerHost struct { @@ -117,13 +117,9 @@ func (c *ConnectionManager) AddConnection(connId int, conn *Connection) { } c.connectionsMutex.Unlock() go func() { - err, ok := <-conn.ErrorChan() - if !ok { - // Connection has closed normally - return - } - // Call configured error callback func - c.config.ErrorFunc(connId, err) + err := <-conn.ErrorChan() + // Call configured connection closed callback func + c.config.ConnClosedFunc(connId, err) }() } diff --git a/connection_manager_test.go b/connection_manager_test.go index 76704250..b034ea80 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -51,7 +51,7 @@ func TestConnectionManagerConnError(t *testing.T) { doneChan := make(chan any) connManager := ouroboros.NewConnectionManager( ouroboros.ConnectionManagerConfig{ - ErrorFunc: func(connId int, err error) { + ConnClosedFunc: func(connId int, err error) { if connId != expectedConnId { t.Fatalf("did not receive error from expected connection: got %d, wanted %d", connId, expectedConnId) } @@ -96,3 +96,50 @@ func TestConnectionManagerConnError(t *testing.T) { t.Fatalf("did not receive error within timeout") } } + +func TestConnectionManagerConnClosed(t *testing.T) { + expectedConnId := 42 + doneChan := make(chan any) + connManager := ouroboros.NewConnectionManager( + ouroboros.ConnectionManagerConfig{ + ConnClosedFunc: func(connId int, err error) { + if connId != expectedConnId { + t.Fatalf("did not receive closed signal from expected connection: got %d, wanted %d", connId, expectedConnId) + } + if err != nil { + t.Fatalf("received unexpected error: %s", err) + } + close(doneChan) + }, + }, + ) + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ouroboros.WithKeepAlive(false), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + connManager.AddConnection(expectedConnId, oConn) + time.AfterFunc( + 1*time.Second, + func() { + oConn.Close() + }, + ) + select { + case <-doneChan: + return + case <-time.After(10 * time.Second): + t.Fatalf("did not receive error within timeout") + } +}