From 7ac6ef95ddc465a025fe0f114a8525e5cada1ecb Mon Sep 17 00:00:00 2001 From: Sebastien Binet Date: Tue, 21 Jan 2020 10:36:51 +0100 Subject: [PATCH] zmq4: do not panic b/c of invalid connection attempts This CL also makes sure we exercize invalid connection attempts and still end up with a valid, non-corrupted, socket state. Updates go-zeromq/zmq4#56. --- conn.go | 2 +- socket.go | 9 ++- socket_test.go | 174 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 4 deletions(-) create mode 100644 socket_test.go diff --git a/conn.go b/conn.go index 87c32a6..9c67c6e 100644 --- a/conn.go +++ b/conn.go @@ -89,7 +89,7 @@ func Open(rw net.Conn, sec Security, sockType SocketType, sockID SocketIdentity, err := conn.init(sec) if err != nil { - return nil, err + return nil, xerrors.Errorf("zmq4: could not initialize ZMTP connection: %w", err) } return conn, nil diff --git a/socket.go b/socket.go index 53a9d9d..c58b8c6 100644 --- a/socket.go +++ b/socket.go @@ -6,6 +6,7 @@ package zmq4 import ( "context" + "log" "net" "os" "strings" @@ -183,14 +184,16 @@ func (sck *socket) accept() { default: conn, err := sck.listener.Accept() if err != nil { - // log.Printf("zmq4: error accepting connection from %q: %v", sck.ep, err) + // FIXME(sbinet): maybe bubble up this error to application code? + // log.Printf("zmq4: error accepting connection from %q: %+v", sck.ep, err) continue } zconn, err := Open(conn, sck.sec, sck.typ, sck.id, true, sck.scheduleRmConn) if err != nil { - panic(err) - // return xerrors.Errorf("zmq4: could not open a ZMTP connection: %w", err) + // FIXME(sbinet): maybe bubble up this error to application code? + log.Printf("zmq4: could not open a ZMTP connection with %q: %+v", sck.ep, err) + continue } sck.addConn(zconn) diff --git a/socket_test.go b/socket_test.go new file mode 100644 index 0000000..f0877a3 --- /dev/null +++ b/socket_test.go @@ -0,0 +1,174 @@ +// Copyright 2020 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zmq4_test + +import ( + "context" + "io" + "net" + "testing" + "time" + + "github.com/go-zeromq/zmq4" + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" +) + +func TestInvalidConn(t *testing.T) { + t.Parallel() + + ep := must(EndPoint("tcp")) + cleanUp(ep) + + ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second) + defer timeout() + + pub := zmq4.NewPub(ctx) + defer pub.Close() + + err := pub.Listen(ep) + if err != nil { + t.Fatalf("could not listen on end-point: %+v", err) + } + + grp, ctx := errgroup.WithContext(ctx) + grp.Go(func() error { + conn, err := net.Dial("tcp", ep[len("tcp://"):]) + if err != nil { + return xerrors.Errorf("could not dial %q: %w", ep, err) + } + defer conn.Close() + var reply = make([]byte, 64) + _, err = io.ReadFull(conn, reply) + if err != nil { + return xerrors.Errorf("could not read reply bytes...: %w", err) + } + _, err = conn.Write(make([]byte, 64)) + if err != nil { + return xerrors.Errorf("could not send bytes...: %w", err) + } + time.Sleep(1 * time.Second) // FIXME(sbinet): hugly. + return nil + }) + + if err := grp.Wait(); err != nil { + t.Fatalf("error: %+v", err) + } + + if err := ctx.Err(); err != nil && err != context.Canceled { + t.Fatalf("error: %+v", err) + } +} + +func TestConnPairs(t *testing.T) { + t.Parallel() + + bkg := context.Background() + + for _, tc := range []struct { + name string + srv zmq4.Socket + wrong zmq4.Socket + cli zmq4.Socket + }{ + { + name: "pair", + srv: zmq4.NewPair(bkg), + wrong: zmq4.NewSub(bkg), + cli: zmq4.NewPair(bkg), + }, + { + name: "pub", + srv: zmq4.NewPub(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewSub(bkg), + }, + { + name: "sub", + srv: zmq4.NewSub(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewPub(bkg), + }, + { + name: "req", + srv: zmq4.NewReq(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewRep(bkg), + }, + { + name: "rep", + srv: zmq4.NewRep(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewReq(bkg), + }, + { + name: "dealer", + srv: zmq4.NewDealer(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewRouter(bkg), + }, + { + name: "router", + srv: zmq4.NewRouter(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewDealer(bkg), + }, + { + name: "pull", + srv: zmq4.NewPull(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewPush(bkg), + }, + { + name: "push", + srv: zmq4.NewPush(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewPull(bkg), + }, + { + name: "xpub", + srv: zmq4.NewXPub(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewXSub(bkg), + }, + { + name: "xsub", + srv: zmq4.NewXSub(bkg), + wrong: zmq4.NewPair(bkg), + cli: zmq4.NewXPub(bkg), + }, + } { + t.Run(tc.name, func(t *testing.T) { + ep := must(EndPoint("tcp")) + cleanUp(ep) + + _, timeout := context.WithTimeout(bkg, 20*time.Second) + defer timeout() + + defer tc.srv.Close() + defer tc.wrong.Close() + defer tc.cli.Close() + + err := tc.srv.Listen(ep) + if err != nil { + t.Fatalf("could not listen on %q: %+v", ep, err) + } + + err = tc.wrong.Dial(ep) + if err == nil { + t.Fatalf("dialed %q", ep) + } + want := xerrors.Errorf("zmq4: could not open a ZMTP connection: zmq4: could not initialize ZMTP connection: zmq4: peer=%q not compatible with %q", tc.srv.Type(), tc.wrong.Type()) + if got, want := err.Error(), want.Error(); got != want { + t.Fatalf("invalid error:\ngot = %v\nwant= %v", got, want) + } + + err = tc.cli.Dial(ep) + if err != nil { + t.Fatalf("could not dial %q: %+v", ep, err) + } + }) + } +}