Skip to content

Commit

Permalink
fix reading directly from connection
Browse files Browse the repository at this point in the history
  • Loading branch information
raspi committed Jul 17, 2022
1 parent 0f3063c commit ec95c15
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 45 deletions.
2 changes: 1 addition & 1 deletion example/cmd/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func main() {
_, _ = fmt.Fprintf(os.Stderr, `error: %v`, err)
os.Exit(1)
}
defer c.Close()

go c.Listen()

Expand All @@ -67,5 +68,4 @@ func main() {
}
}

c.Close()
}
3 changes: 1 addition & 2 deletions example/cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ func main() {
if err != nil {
panic(err)
}
defer l.Close()

go l.Listen()

Expand All @@ -35,6 +36,4 @@ func main() {
}
}

l.Close()

}
55 changes: 36 additions & 19 deletions pkg/server/internal/serverclient/client.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@
package serverclient

import (
"bytes"
"encoding/binary"
error2 "github.com/raspi/jumiks/pkg/server/error"
"github.com/raspi/jumiks/pkg/server/header"
"log"
"math/rand"
"net"
"os"
"time"
)

// ServerClient is a client connected to a server
// it is used internally to track connected client
type ServerClient struct {
conn *net.UnixConn
connectedAt time.Time
lastPacket uint64
uuid uint64
writeMessages chan []byte
errors chan error2.Error
logger *log.Logger
conn *net.UnixConn
connectedAt time.Time
lastPacket uint64 // Last packet ID processed
uuid uint64 // UUID for this client
writeMessages chan []byte
errors chan error2.Error
tooSlowPacketsBehind uint64
}

func NewClient(conn *net.UnixConn, errors chan error2.Error) (c *ServerClient) {
func NewClient(conn *net.UnixConn, errors chan error2.Error, tooSlowPacketsBehind uint64) (c *ServerClient) {
c = &ServerClient{
logger: log.New(os.Stdout, `client `, log.LstdFlags),
conn: conn,
connectedAt: time.Now(),
uuid: rand.Uint64(),
errors: errors,
writeMessages: make(chan []byte),
lastPacket: 0,
conn: conn,
connectedAt: time.Now(),
uuid: rand.Uint64(),
errors: errors,
writeMessages: make(chan []byte),
lastPacket: 0,
tooSlowPacketsBehind: tooSlowPacketsBehind,
}

return c
Expand All @@ -40,15 +39,33 @@ func NewClient(conn *net.UnixConn, errors chan error2.Error) (c *ServerClient) {
func (c *ServerClient) Listen() {
defer c.conn.Close()

buffer := make([]byte, 1048576)

for {
c.logger.Printf(`reading message header `)
rb, err := c.conn.Read(buffer)
if err != nil {
c.errors <- error2.New(err)
break
}

if rb == 0 {
continue
}

buf := bytes.NewBuffer(buffer[:rb])

var hdr header.MessageHeaderFromClient
err := binary.Read(c.conn, binary.LittleEndian, &hdr)
err = binary.Read(buf, binary.LittleEndian, &hdr)
if err != nil {
c.errors <- error2.New(err)
break
}

if c.lastPacket != 0 && (hdr.PacketId-c.lastPacket) > c.tooSlowPacketsBehind {
// We are too slow
break
}

// Track packet ID
c.lastPacket = hdr.PacketId
}
Expand Down
33 changes: 10 additions & 23 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import (
error2 "github.com/raspi/jumiks/pkg/server/error"
"github.com/raspi/jumiks/pkg/server/header"
"github.com/raspi/jumiks/pkg/server/internal/serverclient"
"log"
"net"
"os"
"sync"
"syscall"
)
Expand All @@ -22,15 +20,13 @@ const ConnType = `unixpacket`
const StartPacketId = uint64(100000)

type Server struct {
logger *log.Logger
listener *net.UnixListener // Listening unix domain socket
clients map[uint64]*serverclient.ServerClient // Connected clients
packetId uint64 // Packet tracking ID
errch chan error2.Error // Errors
tooSlowPacketsBehind uint64 // How many packets can connected client lag behind
messagesCh chan []byte // messages sent to connected clients
connectionNew chan *net.UnixConn
connectionClose chan uint64
lock sync.Mutex
}

Expand All @@ -53,16 +49,14 @@ func New(name string, tooSlowPacketsBehind uint64, errch chan error2.Error) (s *
conn.SetUnlinkOnClose(true)

s = &Server{
logger: log.New(os.Stdout, ``, log.LstdFlags),
listener: conn,
packetId: StartPacketId,
errch: errch,
tooSlowPacketsBehind: tooSlowPacketsBehind,
clients: make(map[uint64]*serverclient.ServerClient),
clients: make(map[uint64]*serverclient.ServerClient), // connected clients
connectionNew: make(chan *net.UnixConn),
connectionClose: make(chan uint64),
messagesCh: make(chan []byte),
lock: sync.Mutex{},
messagesCh: make(chan []byte), // Messages sent to all connected clients
lock: sync.Mutex{}, // Lock when global state changes
}

return s, nil
Expand All @@ -73,13 +67,11 @@ func (s *Server) listenConnections() {

for {
// New connection
s.logger.Printf(`listening for new connection`)
conn, err := s.listener.AcceptUnix()
if err != nil {
s.errch <- error2.New(err)
continue
}
s.logger.Printf(`new connection, sending handshake`)

// handshake for determining that client speaks the same protocol

Expand Down Expand Up @@ -111,7 +103,6 @@ func (s *Server) listenConnections() {
continue
}

s.logger.Printf(`handshake ok`)
s.connectionNew <- conn
}
}
Expand Down Expand Up @@ -140,33 +131,30 @@ func generateMsg(pId uint64, msg []byte) []byte {
}

func (s *Server) Listen() {
// Listen on new connections (non-blocking)
go s.listenConnections()

for {
select {
case err := <-s.errch:
s.logger.Printf(`error: %v`, err)
case msg := <-s.messagesCh: // new message
s.logger.Printf(`received message from channel`)

for clientId, client := range s.clients {
s.logger.Printf(`client %v`, client.GetId())
if client == nil {
s.logger.Printf(`client is nil!`)

client.Close()
delete(s.clients, clientId)
continue
}

// Send the buffer to client
s.logger.Printf(`writing to client`)
wb, err := client.Write(msg)
if err != nil {
if errors.Is(err, syscall.EPIPE) {
client.Close()
delete(s.clients, clientId)
continue
} else if errors.Is(err, net.ErrClosed) {
client.Close()
delete(s.clients, clientId)
continue
}

panic(err)
Expand All @@ -178,11 +166,9 @@ func (s *Server) Listen() {
}

case conn := <-s.connectionNew:
s.logger.Printf(`adding client`)
c := serverclient.NewClient(conn, s.errch)
c := serverclient.NewClient(conn, s.errch, s.tooSlowPacketsBehind)
go c.Listen()
s.clients[c.GetId()] = c
s.logger.Printf(`client added`)

default:

Expand All @@ -199,6 +185,7 @@ func (s *Server) SendToAll(msg []byte) {
s.lock.Unlock()
}

// Close closes all connected clients and then closes the listener
func (s *Server) Close() error {
var del []uint64

Expand Down

0 comments on commit ec95c15

Please sign in to comment.