diff --git a/conn.go b/conn.go deleted file mode 100644 index 9678ccc..0000000 --- a/conn.go +++ /dev/null @@ -1,190 +0,0 @@ -package tcp - -import ( - "errors" - "io" - "net" - "strings" - "sync" - "sync/atomic" - "time" -) - -var ( - ErrConnClosing = errors.New("use of closed network connection") - ErrBufferFull = errors.New("the async send buffer is full") -) - -type TCPConn struct { - callback CallBack - protocol Protocol - - conn *net.TCPConn - readChan chan Packet - writeChan chan Packet - errChan chan error - - readDeadline time.Duration - writeDeadline time.Duration - maxPacketSize uint32 - - exitChan chan struct{} - closeOnce sync.Once - exitFlag int32 -} - -func NewTCPConn(conn *net.TCPConn, callback CallBack, protocol Protocol) *TCPConn { - c := &TCPConn{ - conn: conn, - callback: callback, - protocol: protocol, - - readChan: make(chan Packet, readChanSize), - writeChan: make(chan Packet, writeChanSize), - errChan: make(chan error, 10), - - exitChan: make(chan struct{}), - exitFlag: 0, - } - return c -} - -func (c *TCPConn) Serve() { - defer func() { - if r := recover(); r != nil { - logger.Println("tcp conn(%v) Serve error, %v ", c.GetRemoteIPAddress(), r) - } - }() - atomic.StoreInt32(&c.exitFlag, 1) - c.callback.OnConnected(c) - go c.readLoop() - go c.writeLoop() - go c.handleLoop() -} - -func (c *TCPConn) readLoop() { - defer func() { - recover() - c.Close() - }() - - for { - select { - case <-c.exitChan: - return - default: - if c.readDeadline > 0 { - c.conn.SetReadDeadline(time.Now().Add(c.readDeadline)) - } - p, err := c.protocol.ReadPacket(c.conn) - if err != nil { - if err != io.EOF { - c.errChan <- err - } - return - } - c.readChan <- p - } - } -} - -func (c *TCPConn) ReadPacket() (Packet, error) { - if c.protocol == nil { - return nil, errors.New("no protocol impl") - } - return c.protocol.ReadPacket(c.conn) -} - -func (c *TCPConn) writeLoop() { - defer func() { - recover() - c.Close() - }() - - for pkt := range c.writeChan { - if pkt == nil { - continue - } - if c.writeDeadline > 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeDeadline)) - } - if err := c.protocol.WritePacket(c.conn, pkt); err != nil { - c.errChan <- err - return - } - } -} - -func (c *TCPConn) handleLoop() { - defer func() { - recover() - c.Close() - }() - for p := range c.readChan { - if p == nil { - continue - } - c.callback.OnMessage(c, p) - } -} - -func (c *TCPConn) AsyncWritePacket(p Packet) error { - if c.IsClosed() { - return ErrConnClosing - } - select { - case c.writeChan <- p: - return nil - default: - return ErrBufferFull - } -} - -func (c *TCPConn) Close() { - c.closeOnce.Do(func() { - close(c.exitChan) - close(c.errChan) - close(c.writeChan) - close(c.readChan) - c.callback.OnDisconnected(c) - atomic.StoreInt32(&c.exitFlag, 0) - c.conn.Close() - }) -} - -func (c *TCPConn) Errors() <-chan error { - return c.errChan -} - -func (c *TCPConn) GetRawConn() *net.TCPConn { - return c.conn -} - -func (c *TCPConn) IsClosed() bool { - return atomic.LoadInt32(&c.exitFlag) == 0 -} - -func (c *TCPConn) GetLocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -//LocalIPAddress 返回socket连接本地的ip地址 -func (c *TCPConn) GetLocalIPAddress() string { - return strings.Split(c.GetLocalAddr().String(), ":")[0] -} - -func (c *TCPConn) GetRemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *TCPConn) GetRemoteIPAddress() string { - return strings.Split(c.GetRemoteAddr().String(), ":")[0] -} - -func (c *TCPConn) setReadDeadline(t time.Duration) { - c.readDeadline = t -} - -func (c *TCPConn) setWriteDeadline(t time.Duration) { - c.writeDeadline = t -} diff --git a/tcp_conn.go b/tcp_conn.go index 9678ccc..4df98da 100644 --- a/tcp_conn.go +++ b/tcp_conn.go @@ -2,6 +2,7 @@ package tcp import ( "errors" + "fmt" "io" "net" "strings" @@ -26,7 +27,6 @@ type TCPConn struct { readDeadline time.Duration writeDeadline time.Duration - maxPacketSize uint32 exitChan chan struct{} closeOnce sync.Once @@ -49,17 +49,23 @@ func NewTCPConn(conn *net.TCPConn, callback CallBack, protocol Protocol) *TCPCon return c } -func (c *TCPConn) Serve() { +func (c *TCPConn) Serve() error { defer func() { if r := recover(); r != nil { logger.Println("tcp conn(%v) Serve error, %v ", c.GetRemoteIPAddress(), r) } }() + if c.callback == nil || c.protocol == nil { + err := fmt.Errorf("callback and protocol are not allowed to be nil") + c.Close() + return err + } atomic.StoreInt32(&c.exitFlag, 1) c.callback.OnConnected(c) go c.readLoop() go c.writeLoop() go c.handleLoop() + return nil } func (c *TCPConn) readLoop() { @@ -146,7 +152,9 @@ func (c *TCPConn) Close() { close(c.errChan) close(c.writeChan) close(c.readChan) - c.callback.OnDisconnected(c) + if c.callback != nil { + c.callback.OnDisconnected(c) + } atomic.StoreInt32(&c.exitFlag, 0) c.conn.Close() })