Skip to content

Commit

Permalink
zmq4: do not panic b/c of invalid connection attempts
Browse files Browse the repository at this point in the history
This CL also makes sure we exercize invalid connection attempts and
still end up with a valid, non-corrupted, socket state.

Updates go-zeromq#56.
  • Loading branch information
sbinet committed Jan 21, 2020
1 parent 28043d4 commit 7ac6ef9
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 4 deletions.
2 changes: 1 addition & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func Open(rw net.Conn, sec Security, sockType SocketType, sockID SocketIdentity,

err := conn.init(sec)
if err != nil {
return nil, err
return nil, xerrors.Errorf("zmq4: could not initialize ZMTP connection: %w", err)
}

return conn, nil
Expand Down
9 changes: 6 additions & 3 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zmq4

import (
"context"
"log"
"net"
"os"
"strings"
Expand Down Expand Up @@ -183,14 +184,16 @@ func (sck *socket) accept() {
default:
conn, err := sck.listener.Accept()
if err != nil {
// log.Printf("zmq4: error accepting connection from %q: %v", sck.ep, err)
// FIXME(sbinet): maybe bubble up this error to application code?
// log.Printf("zmq4: error accepting connection from %q: %+v", sck.ep, err)
continue
}

zconn, err := Open(conn, sck.sec, sck.typ, sck.id, true, sck.scheduleRmConn)
if err != nil {
panic(err)
// return xerrors.Errorf("zmq4: could not open a ZMTP connection: %w", err)
// FIXME(sbinet): maybe bubble up this error to application code?
log.Printf("zmq4: could not open a ZMTP connection with %q: %+v", sck.ep, err)
continue
}

sck.addConn(zconn)
Expand Down
174 changes: 174 additions & 0 deletions socket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Copyright 2020 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zmq4_test

import (
"context"
"io"
"net"
"testing"
"time"

"github.com/go-zeromq/zmq4"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
)

func TestInvalidConn(t *testing.T) {
t.Parallel()

ep := must(EndPoint("tcp"))
cleanUp(ep)

ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second)
defer timeout()

pub := zmq4.NewPub(ctx)
defer pub.Close()

err := pub.Listen(ep)
if err != nil {
t.Fatalf("could not listen on end-point: %+v", err)
}

grp, ctx := errgroup.WithContext(ctx)
grp.Go(func() error {
conn, err := net.Dial("tcp", ep[len("tcp://"):])
if err != nil {
return xerrors.Errorf("could not dial %q: %w", ep, err)
}
defer conn.Close()
var reply = make([]byte, 64)
_, err = io.ReadFull(conn, reply)
if err != nil {
return xerrors.Errorf("could not read reply bytes...: %w", err)
}
_, err = conn.Write(make([]byte, 64))
if err != nil {
return xerrors.Errorf("could not send bytes...: %w", err)
}
time.Sleep(1 * time.Second) // FIXME(sbinet): hugly.
return nil
})

if err := grp.Wait(); err != nil {
t.Fatalf("error: %+v", err)
}

if err := ctx.Err(); err != nil && err != context.Canceled {
t.Fatalf("error: %+v", err)
}
}

func TestConnPairs(t *testing.T) {
t.Parallel()

bkg := context.Background()

for _, tc := range []struct {
name string
srv zmq4.Socket
wrong zmq4.Socket
cli zmq4.Socket
}{
{
name: "pair",
srv: zmq4.NewPair(bkg),
wrong: zmq4.NewSub(bkg),
cli: zmq4.NewPair(bkg),
},
{
name: "pub",
srv: zmq4.NewPub(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewSub(bkg),
},
{
name: "sub",
srv: zmq4.NewSub(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewPub(bkg),
},
{
name: "req",
srv: zmq4.NewReq(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewRep(bkg),
},
{
name: "rep",
srv: zmq4.NewRep(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewReq(bkg),
},
{
name: "dealer",
srv: zmq4.NewDealer(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewRouter(bkg),
},
{
name: "router",
srv: zmq4.NewRouter(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewDealer(bkg),
},
{
name: "pull",
srv: zmq4.NewPull(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewPush(bkg),
},
{
name: "push",
srv: zmq4.NewPush(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewPull(bkg),
},
{
name: "xpub",
srv: zmq4.NewXPub(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewXSub(bkg),
},
{
name: "xsub",
srv: zmq4.NewXSub(bkg),
wrong: zmq4.NewPair(bkg),
cli: zmq4.NewXPub(bkg),
},
} {
t.Run(tc.name, func(t *testing.T) {
ep := must(EndPoint("tcp"))
cleanUp(ep)

_, timeout := context.WithTimeout(bkg, 20*time.Second)
defer timeout()

defer tc.srv.Close()
defer tc.wrong.Close()
defer tc.cli.Close()

err := tc.srv.Listen(ep)
if err != nil {
t.Fatalf("could not listen on %q: %+v", ep, err)
}

err = tc.wrong.Dial(ep)
if err == nil {
t.Fatalf("dialed %q", ep)
}
want := xerrors.Errorf("zmq4: could not open a ZMTP connection: zmq4: could not initialize ZMTP connection: zmq4: peer=%q not compatible with %q", tc.srv.Type(), tc.wrong.Type())
if got, want := err.Error(), want.Error(); got != want {
t.Fatalf("invalid error:\ngot = %v\nwant= %v", got, want)
}

err = tc.cli.Dial(ep)
if err != nil {
t.Fatalf("could not dial %q: %+v", ep, err)
}
})
}
}

0 comments on commit 7ac6ef9

Please sign in to comment.