From 0f3063c439ba0770945990a827088dbdaf8a5f9f Mon Sep 17 00:00:00 2001 From: raspi Date: Sun, 17 Jul 2022 10:06:12 +0300 Subject: [PATCH] fix --- example/cmd/client/main.go | 24 ++++++++++-- example/cmd/server/main.go | 2 + pkg/client/client.go | 11 ++++++ pkg/server/internal/serverclient/client.go | 5 +++ pkg/server/server.go | 44 +++++++++++----------- 5 files changed, 62 insertions(+), 24 deletions(-) diff --git a/example/cmd/client/main.go b/example/cmd/client/main.go index f45f762..06f6a39 100644 --- a/example/cmd/client/main.go +++ b/example/cmd/client/main.go @@ -1,9 +1,13 @@ package main import ( + "errors" "fmt" "github.com/raspi/jumiks/pkg/client" + "io" + "net" "os" + "syscall" "time" ) @@ -36,10 +40,14 @@ func (c *ExampleClient) on_msg(b []byte) { c.delay += time.Millisecond * 50 } +func (c *ExampleClient) Close() error { + return c.c.Close() +} + func main() { - errors := make(chan error) + errorch := make(chan error) - c, err := New("@test", errors) + c, err := New("@test", errorch) if err != nil { _, _ = fmt.Fprintf(os.Stderr, `error: %v`, err) os.Exit(1) @@ -47,7 +55,17 @@ func main() { go c.Listen() - for err := range errors { + for err := range errorch { fmt.Printf(`got error: %v`, err) + + if errors.Is(err, io.EOF) { + break + } else if errors.Is(err, syscall.EPIPE) { + break + } else if errors.Is(err, net.ErrClosed) { + break + } } + + c.Close() } diff --git a/example/cmd/server/main.go b/example/cmd/server/main.go index 848f070..e1ae045 100644 --- a/example/cmd/server/main.go +++ b/example/cmd/server/main.go @@ -35,4 +35,6 @@ func main() { } } + l.Close() + } diff --git a/pkg/client/client.go b/pkg/client/client.go index a5182ef..bfedfaf 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -9,6 +9,7 @@ import ( "github.com/raspi/jumiks/pkg/server/header" "io" "net" + "syscall" ) type Client struct { @@ -54,6 +55,12 @@ func (c *Client) Listen() { rb, err := c.conn.Read(buffer) if err != nil { if errors.Is(err, io.EOF) { + // EOF + break + } else if errors.Is(err, net.ErrClosed) { + break + } else if errors.Is(err, syscall.EPIPE) { + // Broken pipe break } @@ -127,6 +134,10 @@ func (c *Client) handleMsg(b []byte) { c.hfn(b) } +func (c *Client) Close() error { + return c.conn.Close() +} + // handshake determines if both server.Server and Client are speaking the same protocol func handshake(conn *net.UnixConn) (err error) { var serverHs header.Handshake diff --git a/pkg/server/internal/serverclient/client.go b/pkg/server/internal/serverclient/client.go index dcd7fd1..e65b6e5 100644 --- a/pkg/server/internal/serverclient/client.go +++ b/pkg/server/internal/serverclient/client.go @@ -41,6 +41,7 @@ func (c *ServerClient) Listen() { defer c.conn.Close() for { + c.logger.Printf(`reading message header `) var hdr header.MessageHeaderFromClient err := binary.Read(c.conn, binary.LittleEndian, &hdr) if err != nil { @@ -60,3 +61,7 @@ func (c *ServerClient) GetId() uint64 { func (c *ServerClient) Write(b []byte) (int, error) { return c.conn.Write(b) } + +func (c *ServerClient) Close() error { + return c.conn.Close() +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 9312c9a..faaa6b8 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -148,23 +148,24 @@ func (s *Server) Listen() { s.logger.Printf(`error: %v`, err) case msg := <-s.messagesCh: // new message s.logger.Printf(`received message from channel`) - s.lock.Lock() - pId := s.packetId - s.lock.Unlock() for clientId, client := range s.clients { + s.logger.Printf(`client %v`, client.GetId()) if client == nil { s.logger.Printf(`client is nil!`) - s.connectionClose <- clientId + + client.Close() + delete(s.clients, clientId) continue } // Send the buffer to client s.logger.Printf(`writing to client`) - wb, err := client.Write(generateMsg(pId, msg)) + wb, err := client.Write(msg) if err != nil { if errors.Is(err, syscall.EPIPE) { - s.connectionClose <- clientId + client.Close() + delete(s.clients, clientId) continue } @@ -183,18 +184,6 @@ func (s *Server) Listen() { s.clients[c.GetId()] = c s.logger.Printf(`client added`) - case cId := <-s.connectionClose: - s.logger.Printf(`!!!! disconnecting client #%v`, cId) - - /* - err := s.clients[cId].Close() - if err != nil { - s.errch <- error2.New(err) - } - - */ - - delete(s.clients, cId) default: } @@ -203,11 +192,24 @@ func (s *Server) Listen() { // SendToAll sends a message to every connected client func (s *Server) SendToAll(msg []byte) { - if len(msg) > 0 { - s.messagesCh <- msg - } + s.messagesCh <- generateMsg(s.packetId, msg) s.lock.Lock() s.packetId++ s.lock.Unlock() } + +func (s *Server) Close() error { + var del []uint64 + + for i, c := range s.clients { + c.Close() + del = append(del, i) + } + + for _, i := range del { + delete(s.clients, i) + } + + return s.listener.Close() +}