diff --git a/connection_manager.go b/connection_manager.go index 834a96a5..e5fe43dc 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -17,7 +17,7 @@ package ouroboros import "sync" // ConnectionManagerConnClosedFunc is a function that takes a connection ID and an optional error -type ConnectionManagerConnClosedFunc func(int, error) +type ConnectionManagerConnClosedFunc func(ConnectionId, error) // ConnectionManagerTag represents the various tags that can be associated with a host or connection type ConnectionManagerTag uint16 @@ -57,7 +57,7 @@ func (c ConnectionManagerTag) String() string { type ConnectionManager struct { config ConnectionManagerConfig hosts []ConnectionManagerHost - connections map[int]*ConnectionManagerConnection + connections map[ConnectionId]*ConnectionManagerConnection connectionsMutex sync.Mutex } @@ -74,7 +74,7 @@ type ConnectionManagerHost struct { func NewConnectionManager(cfg ConnectionManagerConfig) *ConnectionManager { return &ConnectionManager{ config: cfg, - connections: make(map[int]*ConnectionManagerConnection), + connections: make(map[ConnectionId]*ConnectionManagerConnection), } } @@ -109,10 +109,10 @@ func (c *ConnectionManager) AddHostsFromTopology(topology *TopologyConfig) { } } -func (c *ConnectionManager) AddConnection(connId int, conn *Connection) { +func (c *ConnectionManager) AddConnection(conn *Connection) { + connId := conn.Id() c.connectionsMutex.Lock() c.connections[connId] = &ConnectionManagerConnection{ - Id: connId, Conn: conn, } c.connectionsMutex.Unlock() @@ -123,13 +123,13 @@ func (c *ConnectionManager) AddConnection(connId int, conn *Connection) { }() } -func (c *ConnectionManager) RemoveConnection(connId int) { +func (c *ConnectionManager) RemoveConnection(connId ConnectionId) { c.connectionsMutex.Lock() delete(c.connections, connId) c.connectionsMutex.Unlock() } -func (c *ConnectionManager) GetConnectionById(connId int) *ConnectionManagerConnection { +func (c *ConnectionManager) GetConnectionById(connId ConnectionId) *ConnectionManagerConnection { c.connectionsMutex.Lock() defer c.connectionsMutex.Unlock() return c.connections[connId] @@ -155,7 +155,6 @@ func (c *ConnectionManager) GetConnectionsByTags(tags ...ConnectionManagerTag) [ } type ConnectionManagerConnection struct { - Id int Conn *Connection Tags map[ConnectionManagerTag]bool } diff --git a/connection_manager_test.go b/connection_manager_test.go index 66714452..71619ceb 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -48,12 +48,12 @@ func TestConnectionManagerTagString(t *testing.T) { func TestConnectionManagerConnError(t *testing.T) { defer goleak.VerifyNone(t) - expectedConnId := 2 + var expectedConnId ouroboros.ConnectionId expectedErr := io.EOF doneChan := make(chan any) connManager := ouroboros.NewConnectionManager( ouroboros.ConnectionManagerConfig{ - ConnClosedFunc: func(connId int, err error) { + ConnClosedFunc: func(connId ouroboros.ConnectionId, err error) { if err != nil { if connId != expectedConnId { t.Fatalf("did not receive error from expected connection: got %d, wanted %d", connId, expectedConnId) @@ -66,9 +66,10 @@ func TestConnectionManagerConnError(t *testing.T) { }, }, ) + testIdx := 2 for i := 0; i < 3; i++ { mockConversation := ouroboros_mock.ConversationKeepAlive - if i == expectedConnId { + if i == testIdx { mockConversation = ouroboros_mock.ConversationKeepAliveClose } mockConn := ouroboros_mock.NewConnection( @@ -91,13 +92,16 @@ func TestConnectionManagerConnError(t *testing.T) { if err != nil { t.Fatalf("unexpected error when creating Ouroboros object: %s", err) } - connManager.AddConnection(i, oConn) + if i == testIdx { + expectedConnId = oConn.Id() + } + connManager.AddConnection(oConn) } select { case <-doneChan: // Shutdown other connections for _, tmpConn := range connManager.GetConnectionsByTags() { - if tmpConn.Id != expectedConnId { + if tmpConn.Conn.Id() != expectedConnId { tmpConn.Conn.Close() } } @@ -111,11 +115,11 @@ func TestConnectionManagerConnError(t *testing.T) { func TestConnectionManagerConnClosed(t *testing.T) { defer goleak.VerifyNone(t) - expectedConnId := 42 + var expectedConnId ouroboros.ConnectionId doneChan := make(chan any) connManager := ouroboros.NewConnectionManager( ouroboros.ConnectionManagerConfig{ - ConnClosedFunc: func(connId int, err error) { + ConnClosedFunc: func(connId ouroboros.ConnectionId, err error) { if connId != expectedConnId { t.Fatalf("did not receive closed signal from expected connection: got %d, wanted %d", connId, expectedConnId) } @@ -142,7 +146,8 @@ func TestConnectionManagerConnClosed(t *testing.T) { if err != nil { t.Fatalf("unexpected error when creating Ouroboros object: %s", err) } - connManager.AddConnection(expectedConnId, oConn) + expectedConnId = oConn.Id() + connManager.AddConnection(oConn) time.AfterFunc( 1*time.Second, func() { diff --git a/go.mod b/go.mod index 7e66bc41..76c4c1e5 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21 toolchain go1.21.5 require ( - github.com/blinklabs-io/ouroboros-mock v0.2.0 + github.com/blinklabs-io/ouroboros-mock v0.3.0 github.com/fxamacker/cbor/v2 v2.6.0 github.com/jinzhu/copier v0.4.0 github.com/utxorpc/go-codegen v0.4.4 diff --git a/go.sum b/go.sum index 90c92bbe..8e039940 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/blinklabs-io/ouroboros-mock v0.2.0 h1:Wff7mJiFUzktQ5tuWRN9vXNk38wR0ij2Q4bYHwJXaV4= -github.com/blinklabs-io/ouroboros-mock v0.2.0/go.mod h1:t9eIDjmj339GJtfV7jandJnCqmj8WkZsFg2N1TR68io= +github.com/blinklabs-io/ouroboros-mock v0.3.0 h1:6VRWyhAv0k7nQEgzFpuqhS/n8OM+OAaLN/sCT5K2Hbc= +github.com/blinklabs-io/ouroboros-mock v0.3.0/go.mod h1:0dzTNEk/Kvqa7qYHDy7/Nn3OTt+EOosMknB37FRzI1k= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA=