diff --git a/pps.go b/pps.go index 39fe839..7d38547 100644 --- a/pps.go +++ b/pps.go @@ -21,6 +21,17 @@ const DefaultAddr = "0.0.0.0" // DefaultPort is the default port the server is listening on const DefaultPort = "10005" +// CtxKey represents the different key ids for values added to contexts +type CtxKey int + +const ( + // CtxConnId represents the connection id in the connection context + CtxConnId CtxKey = iota +) + +// PostfixResp is a possible response value for the policy request +type PostfixResp string + // Possible responses to the postfix server // See: http://www.postfix.org/access.5.html const ( @@ -171,9 +182,6 @@ type polSetFunc func(*PolicySet, string) // ServerOpt is an override function for the New() method type ServerOpt func(*Server) -// PostfixResp is a possible response value for the policy request -type PostfixResp string - // Handler interface for handling incoming policy requests and returning the // corresponding action type Handler interface { @@ -219,9 +227,7 @@ func (s *Server) Run(ctx context.Context, h Handler) error { return err } go func() { - select { - case <-ctx.Done(): - } + <-ctx.Done() if err := l.Close(); err != nil { el.Printf("failed to close listener: %s", err) } @@ -242,7 +248,7 @@ func (s *Server) Run(ctx context.Context, h Handler) error { } connId := xid.New() - conCtx := context.WithValue(ctx, "id", connId) + conCtx := context.WithValue(ctx, CtxConnId, connId) go connHandler(conCtx, conn) } @@ -252,7 +258,7 @@ func (s *Server) Run(ctx context.Context, h Handler) error { // connHandler processes the incoming policy connection request and hands it to the // Handle function of the Handler interface func connHandler(ctx context.Context, c *Connection) { - connId, ok := ctx.Value("id").(xid.ID) + connId, ok := ctx.Value(CtxConnId).(xid.ID) if !ok { log.Print("failed to retrieve connection id from context.") return @@ -294,7 +300,7 @@ func connHandler(ctx context.Context, c *Connection) { c.err = err cc <- true } - l = strings.TrimRight(l, "\n\n") + l = strings.TrimRight(l, "\n") if l == "" { break } diff --git a/pps_test.go b/pps_test.go index 2384ee6..f3ae676 100644 --- a/pps_test.go +++ b/pps_test.go @@ -122,7 +122,7 @@ func TestRun(t *testing.T) { // TestRunDial starts a new server listening for connections and tries to connect to it func TestRunDial(t *testing.T) { - s := New() + s := New(WithPort("44440")) sctx, scancel := context.WithCancel(context.Background()) defer scancel() @@ -148,15 +148,12 @@ func TestRunDial(t *testing.T) { if err := conn.Close(); err != nil { t.Errorf("failed to close client connection: %s", err) } - - // Wait a brief moment for the connection to close - time.Sleep(time.Millisecond * 500) } // TestRunDialWithRequest starts a new server listening for connections and tries to connect to it // and sends example data func TestRunDialWithRequest(t *testing.T) { - s := New() + s := New(WithPort("44441")) sctx, scancel := context.WithCancel(context.Background()) defer scancel()