diff --git a/bucket.go b/bucket.go index f780c6e..a6c4daa 100644 --- a/bucket.go +++ b/bucket.go @@ -2,38 +2,42 @@ package tcp import ( "sync" + "time" ) +//TCPConnBucket 用来存放和管理TCPConn连接 type TCPConnBucket struct { m map[string]*TCPConn mu *sync.RWMutex } -func newTCPConnBucket() *TCPConnBucket { - return &TCPConnBucket{ +func NewTCPConnBucket() *TCPConnBucket { + tcb := &TCPConnBucket{ m: make(map[string]*TCPConn), mu: new(sync.RWMutex), } + tcb.removeClosedTCPConnLoop() + return tcb } -func (b *TCPConnBucket) Put(key string, c *TCPConn) { +func (b *TCPConnBucket) Put(id string, c *TCPConn) { b.mu.Lock() - b.m[key] = c + b.m[id] = c b.mu.Unlock() } -func (b *TCPConnBucket) Get(key string) *TCPConn { +func (b *TCPConnBucket) Get(id string) *TCPConn { b.mu.RLock() defer b.mu.RUnlock() - if conn, ok := b.m[key]; ok { + if conn, ok := b.m[id]; ok { return conn } return nil } -func (b *TCPConnBucket) Delete(key string) { +func (b *TCPConnBucket) Delete(id string) { b.mu.Lock() - delete(b.m, key) + delete(b.m, id) b.mu.Unlock() } func (b *TCPConnBucket) GetAll() map[string]*TCPConn { @@ -45,3 +49,18 @@ func (b *TCPConnBucket) GetAll() map[string]*TCPConn { } return m } + +func (b *TCPConnBucket) removeClosedTCPConnLoop() { + go func() { + removeKey := make(map[string]struct{}) + for key, conn := range b.GetAll() { + if conn.IsClosed() { + removeKey[key] = struct{}{} + } + } + for key := range removeKey { + b.Delete(key) + } + time.Sleep(time.Millisecond * 100) + }() +} diff --git a/callback.go b/callback.go index a14b35e..cc330b0 100644 --- a/callback.go +++ b/callback.go @@ -1,9 +1,13 @@ package tcp +//CallBack 是一个回调接口,用于连接的各种事件处理 type CallBack interface { + //链接建立回调 OnConnected(conn *TCPConn) - + //消息处理回调 OnMessage(conn *TCPConn, p Packet) - + //链接断开回调 OnDisconnected(conn *TCPConn) + //错误回调 + // OnError(err error, conn *TCPConn) } diff --git a/tcp_conn.go b/tcp_conn.go index ab054f7..4d585fe 100644 --- a/tcp_conn.go +++ b/tcp_conn.go @@ -66,6 +66,7 @@ func (c *TCPConn) readLoop() { default: p, err := c.protocol.ReadPacket(c.conn) if err != nil { + // c.callback.OnError(err, c) return } c.readChan <- p @@ -95,6 +96,7 @@ func (c *TCPConn) writeLoop() { continue } if err := c.protocol.WritePacket(c.conn, p); err != nil { + // c.callback.OnError(err, c) return } } diff --git a/tcp_server.go b/tcp_server.go index 7629ca1..d90ddb1 100644 --- a/tcp_server.go +++ b/tcp_server.go @@ -24,6 +24,7 @@ func init() { logger = log.New(os.Stdout, "", log.Lshortfile) } +//TCPServer 结构定义 type TCPServer struct { //TCP address to listen on tcpAddr string @@ -44,18 +45,20 @@ type TCPServer struct { bucket *TCPConnBucket } +//NewTCPServer 返回一个TCPServer实例 func NewTCPServer(tcpAddr string, callback CallBack, protocol Protocol) *TCPServer { return &TCPServer{ tcpAddr: tcpAddr, callback: callback, protocol: protocol, - bucket: newTCPConnBucket(), + bucket: NewTCPConnBucket(), exitChan: make(chan struct{}), maxPacketSize: defaultMaxPacketSize, } } +//ListenAndServe 使用TCPServer的tcpAddr创建TCPListner并调用Server()方法开启监听 func (srv *TCPServer) ListenAndServe() error { tcpAddr, err := net.ResolveTCPAddr("tcp4", srv.tcpAddr) if err != nil { @@ -68,6 +71,7 @@ func (srv *TCPServer) ListenAndServe() error { return srv.Serve(ln) } +//Serve 使用指定的TCPListener开启监听 func (srv *TCPServer) Serve(l *net.TCPListener) error { srv.listener = l defer func() { @@ -76,12 +80,12 @@ func (srv *TCPServer) Serve(l *net.TCPListener) error { } srv.listener.Close() }() - go func() { - for { - srv.removeClosedTCPConn() - time.Sleep(time.Millisecond * 10) - } - }() + // go func() { + // for { + // srv.removeClosedTCPConn() + // time.Sleep(time.Millisecond * 10) + // } + // }() var tempDelay time.Duration for { @@ -121,6 +125,7 @@ func (srv *TCPServer) newTCPConn(conn *net.TCPConn, callback CallBack, protocol return NewTCPConn(conn, callback, protocol) } +//Connect 使用指定的callback和protocol连接其他TCPServer,返回TCPConn func (srv *TCPServer) Connect(ip string, callback CallBack, protocol Protocol) (*TCPConn, error) { tcpAddr, err := net.ResolveTCPAddr("tcp", ip) if err != nil { @@ -136,6 +141,7 @@ func (srv *TCPServer) Connect(ip string, callback CallBack, protocol Protocol) ( } +//Close 首先关闭所有连接,然后关闭TCPServer func (srv *TCPServer) Close() { defer srv.listener.Close() for _, c := range srv.bucket.GetAll() { @@ -145,25 +151,27 @@ func (srv *TCPServer) Close() { } } -func (srv *TCPServer) removeClosedTCPConn() { - for { - select { - case <-srv.exitChan: - return - default: - removeKey := make(map[string]struct{}) - for key, conn := range srv.bucket.GetAll() { - if conn.IsClosed() { - removeKey[key] = struct{}{} - } - } - for key, _ := range removeKey { - srv.bucket.Delete(key) - } - time.Sleep(time.Millisecond * 10) - } - } -} +// func (srv *TCPServer) removeClosedTCPConn() { +// for { +// select { +// case <-srv.exitChan: +// return +// default: +// removeKey := make(map[string]struct{}) +// for key, conn := range srv.bucket.GetAll() { +// if conn.IsClosed() { +// removeKey[key] = struct{}{} +// } +// } +// for key := range removeKey { +// srv.bucket.Delete(key) +// } +// time.Sleep(time.Millisecond * 10) +// } +// } +// } + +//GetAllTCPConn 返回所有客户端连接 func (srv *TCPServer) GetAllTCPConn() []*TCPConn { conns := []*TCPConn{} for _, conn := range srv.bucket.GetAll() {