diff --git a/client/pkg/transport/transport.go b/client/pkg/transport/transport.go index 67170d7436d..74e63272ec1 100644 --- a/client/pkg/transport/transport.go +++ b/client/pkg/transport/transport.go @@ -18,12 +18,29 @@ import ( "context" "net" "net/http" + "net/url" + "os" "strings" "time" ) type unixTransport struct{ *http.Transport } +var httpTransportProxyParsingFunc = determineHTTPTransportProxyParsingFunc + +func determineHTTPTransportProxyParsingFunc() func(req *http.Request) (*url.URL, error) { + // according to the comment of http.ProxyFromEnvironment: if the proxy URL is "localhost" + // (with or without a port number), then a nil URL and nil error will be returned. + // Thus, we workaround this limitation by manually setting an ENV named E2E_TEST_FORWARD_PROXY_IP + // and parse the URL (which is a localhost in our case) + if forwardProxy, exists := os.LookupEnv("E2E_TEST_FORWARD_PROXY_IP"); exists { + return func(req *http.Request) (*url.URL, error) { + return url.Parse(forwardProxy) + } + } + return http.ProxyFromEnvironment +} + func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) { cfg, err := info.ClientConfig() if err != nil { @@ -39,7 +56,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er } t := &http.Transport{ - Proxy: http.ProxyFromEnvironment, + Proxy: httpTransportProxyParsingFunc(), DialContext: (&net.Dialer{ Timeout: dialtimeoutd, LocalAddr: ipAddr, @@ -60,7 +77,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er return dialer.DialContext(ctx, "unix", addr) } tu := &http.Transport{ - Proxy: http.ProxyFromEnvironment, + Proxy: httpTransportProxyParsingFunc(), DialContext: dialContext, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: cfg, diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index cd84e4e64b5..84b154892fb 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -19,6 +19,8 @@ import ( "errors" "fmt" "io" + "log" + "math/bits" mrand "math/rand" "net" "net/http" @@ -44,11 +46,18 @@ var ( // latency spikes and packet drop or corruption. The proxy overhead is very // small overhead (<500μs per request). Please run tests to compute actual // overhead. +// +// Note that the current implementation is a forward proxy, thus, unix socket +// is not supported, due to the forwarding is done in L7, which requires +// properly constructed HTTP header and body +// +// Also, because we are forced to use TLS to communicate with the proxy server +// and using well-formed header to talk to the destination server, +// so in the L7 forward proxy design we drop features such as random packet +// modification, etc. type Server interface { - // From returns proxy source address in "scheme://host:port" format. - From() string - // To returns proxy destination address in "scheme://host:port" format. - To() string + // Listen returns proxy listen address in "scheme://host:port" format. + Listen() string // Ready returns when proxy is ready to serve. Ready() <-chan struct{} @@ -101,28 +110,38 @@ type Server interface { // UnblackholeRx removes blackhole operation on "receiving". UnblackholeRx() - // ResetListener closes and restarts listener. - ResetListener() error + // BlackholePeerTx drops all outgoing traffic of a peer. + BlackholePeerTx(peer url.URL) + // UnblackholePeerTx removes blackhole operation on "sending". + UnblackholePeerTx(peer url.URL) + + // BlackholePeerTx drops all incoming traffic of a peer. + BlackholePeerRx(peer url.URL) + // UnblackholePeerRx removes blackhole operation on "receiving". + UnblackholePeerRx(peer url.URL) } // ServerConfig defines proxy server configuration. type ServerConfig struct { Logger *zap.Logger - From url.URL - To url.URL + Listen url.URL TLSInfo transport.TLSInfo DialTimeout time.Duration BufferSize int RetryInterval time.Duration } +const ( + blackholePeerTypeNone uint8 = iota + blackholePeerTypeTx + blackholePeerTypeRx +) + type server struct { lg *zap.Logger - from url.URL - fromPort int - to url.URL - toPort int + listen url.URL + listenPort int tlsInfo transport.TLSInfo dialTimeout time.Duration @@ -134,11 +153,12 @@ type server struct { donec chan struct{} errc chan error - closeOnce sync.Once - closeWg sync.WaitGroup + closeOnce sync.Once + closeWg sync.WaitGroup + closeHijackedConn sync.WaitGroup listenerMu sync.RWMutex - listener net.Listener + listener *net.Listener modifyTxMu sync.RWMutex modifyTx func(data []byte) []byte @@ -151,6 +171,11 @@ type server struct { latencyRxMu sync.RWMutex latencyRx time.Duration + + blackholePeerMap map[int]uint8 // port number, blackhole type + blackholePeerMapMu sync.RWMutex + + httpServer *http.Server } // NewServer returns a proxy implementation with no iptables/tc dependencies. @@ -159,8 +184,7 @@ func NewServer(cfg ServerConfig) Server { s := &server{ lg: cfg.Logger, - from: cfg.From, - to: cfg.To, + listen: cfg.Listen, tlsInfo: cfg.TLSInfo, dialTimeout: cfg.DialTimeout, @@ -171,18 +195,13 @@ func NewServer(cfg ServerConfig) Server { readyc: make(chan struct{}), donec: make(chan struct{}), errc: make(chan error, 16), - } - _, fromPort, err := net.SplitHostPort(cfg.From.Host) - if err == nil { - s.fromPort, _ = strconv.Atoi(fromPort) - } - var toPort string - _, toPort, err = net.SplitHostPort(cfg.To.Host) - if err == nil { - s.toPort, _ = strconv.Atoi(toPort) + blackholePeerMap: make(map[int]uint8), } + var err error + var fromPort string + if s.dialTimeout == 0 { s.dialTimeout = defaultDialTimeout } @@ -193,163 +212,196 @@ func NewServer(cfg ServerConfig) Server { s.retryInterval = defaultRetryInterval } - if strings.HasPrefix(s.from.Scheme, "http") { - s.from.Scheme = "tcp" - } - if strings.HasPrefix(s.to.Scheme, "http") { - s.to.Scheme = "tcp" - } + // L7 is http (scheme), L4 is tcp (network listener) + addr := "" + if strings.HasPrefix(s.listen.Scheme, "http") { + s.listen.Scheme = "tcp" + + if _, fromPort, err = net.SplitHostPort(cfg.Listen.Host); err != nil { + s.errc <- err + s.Close() + return nil + } + if s.listenPort, err = strconv.Atoi(fromPort); err != nil { + s.errc <- err + s.Close() + return nil + } - addr := fmt.Sprintf(":%d", s.fromPort) - if s.fromPort == 0 { // unix - addr = s.from.Host + addr = fmt.Sprintf(":%d", s.listenPort) + } else { + panic(fmt.Sprintf("%s is not supported", s.listen.Scheme)) } + s.closeWg.Add(1) var ln net.Listener if !s.tlsInfo.Empty() { - ln, err = transport.NewListener(addr, s.from.Scheme, &s.tlsInfo) + ln, err = transport.NewListener(addr, s.listen.Scheme, &s.tlsInfo) } else { - ln, err = net.Listen(s.from.Scheme, addr) + ln, err = net.Listen(s.listen.Scheme, addr) } if err != nil { s.errc <- err s.Close() - return s + return nil } - s.listener = ln - s.closeWg.Add(1) - go s.listenAndServe() + s.listener = &ln - s.lg.Info("started proxying", zap.String("from", s.From()), zap.String("to", s.To())) - return s -} + go func() { + defer s.closeWg.Done() -func (s *server) From() string { - return fmt.Sprintf("%s://%s", s.from.Scheme, s.from.Host) -} + s.httpServer = &http.Server{ + Handler: &serverHandler{s: s}, + } + + s.lg.Info("proxy is listening on", zap.String("listen on", s.Listen())) + close(s.readyc) + if err := s.httpServer.Serve(*s.listener); err != http.ErrServerClosed { + // always returns error. ErrServerClosed on graceful close + panic(fmt.Sprintf("startHTTPServer Serve(): %v", err)) + } + }() -func (s *server) To() string { - return fmt.Sprintf("%s://%s", s.to.Scheme, s.to.Host) + s.lg.Info("started proxying", zap.String("listen on", s.Listen())) + return s } -// TODO: implement packet reordering from multiple TCP connections -// buffer packets per connection for awhile, reorder before transmit -// - https://github.com/etcd-io/etcd/issues/5614 -// - https://github.com/etcd-io/etcd/pull/6918#issuecomment-264093034 +type serverHandler struct { + s *server +} -func (s *server) listenAndServe() { - defer s.closeWg.Done() +func (sh *serverHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + hijacker, _ := resp.(http.Hijacker) + in, _, err := hijacker.Hijack() + if err != nil { + select { + case sh.s.errc <- err: + select { + case <-sh.s.donec: + return + default: + } + case <-sh.s.donec: + return + } + sh.s.lg.Debug("ServeHTTP hijack error", zap.Error(err)) + panic(err) + } + targetScheme := "tcp" + targetHost := req.URL.Host ctx := context.Background() - s.lg.Info("proxy is listening on", zap.String("from", s.From())) - close(s.readyc) - for { - s.listenerMu.RLock() - ln := s.listener - s.listenerMu.RUnlock() + /* + If the traffic to the destination is HTTPS, a CONNECT request will be sent + first (containing the intended destination HOST). + + If the traffic to the destination is HTTP, no CONNECT request will be sent + first. Only normal HTTP request is sent, with the HOST set to the final destination. + This will be troublesome since we need to manually forward the request to the + destination, and we can't do bte stream manipulation. + + Thus, we need to send the traffic to destination with HTTPS, allowing us to + handle byte streams. + */ + if req.Method == "CONNECT" { + // for CONNECT, we need to send 200 response back first + in.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n")) + } - in, err := ln.Accept() + var out net.Conn + if !sh.s.tlsInfo.Empty() { + var tp *http.Transport + tp, err = transport.NewTransport(sh.s.tlsInfo, sh.s.dialTimeout) if err != nil { select { - case s.errc <- err: + case sh.s.errc <- err: select { - case <-s.donec: + case <-sh.s.donec: return default: } - case <-s.donec: + case <-sh.s.donec: return } - s.lg.Debug("listener accept error", zap.Error(err)) - - if strings.HasSuffix(err.Error(), "use of closed network connection") { - select { - case <-time.After(s.retryInterval): - case <-s.donec: - return - } - s.lg.Debug("listener is closed; retry listening on", zap.String("from", s.From())) - - if err = s.ResetListener(); err != nil { - select { - case s.errc <- err: - select { - case <-s.donec: - return - default: - } - case <-s.donec: - return - } - s.lg.Warn("failed to reset listener", zap.Error(err)) - } - } - - continue + sh.s.lg.Debug("failed to get new Transport", zap.Error(err)) + return } - - var out net.Conn - if !s.tlsInfo.Empty() { - var tp *http.Transport - tp, err = transport.NewTransport(s.tlsInfo, s.dialTimeout) - if err != nil { - select { - case s.errc <- err: - select { - case <-s.donec: - return - default: - } - case <-s.donec: - return - } - continue + out, err = tp.DialContext(ctx, targetScheme, targetHost) + } else { + out, err = net.Dial(targetScheme, targetHost) + } + if err != nil { + select { + case sh.s.errc <- err: + select { + case <-sh.s.donec: + return + default: } - out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host) - } else { - out, err = net.Dial(s.to.Scheme, s.to.Host) + case <-sh.s.donec: + return } - if err != nil { + sh.s.lg.Debug("failed to dial", zap.Error(err)) + return + } + + var dstPort int + dstPort, err = getPort(out.RemoteAddr()) + if err != nil { + select { + case sh.s.errc <- err: select { - case s.errc <- err: - select { - case <-s.donec: - return - default: - } - case <-s.donec: + case <-sh.s.donec: return + default: } - s.lg.Debug("failed to dial", zap.Error(err)) - continue + case <-sh.s.donec: + return } + sh.s.lg.Debug("failed to parse port in transmit", zap.Error(err)) + return + } + + sh.s.closeHijackedConn.Add(2) + go func() { + defer sh.s.closeHijackedConn.Done() + // read incoming bytes from listener, dispatch to outgoing connection + sh.s.transmit(out, in, dstPort) + out.Close() + in.Close() + }() + go func() { + defer sh.s.closeHijackedConn.Done() + // read response from outgoing connection, write back to listener + sh.s.receive(in, out, dstPort) + in.Close() + out.Close() + }() +} - s.closeWg.Add(2) - go func() { - defer s.closeWg.Done() - // read incoming bytes from listener, dispatch to outgoing connection - s.transmit(out, in) - out.Close() - in.Close() - }() - go func() { - defer s.closeWg.Done() - // read response from outgoing connection, write back to listener - s.receive(in, out) - in.Close() - out.Close() - }() +func (s *server) Listen() string { + return fmt.Sprintf("%s://%s", s.listen.Scheme, s.listen.Host) +} + +func getPort(addr net.Addr) (int, error) { + switch addr := addr.(type) { + case *net.TCPAddr: + return addr.Port, nil + case *net.UDPAddr: + return addr.Port, nil + default: + return 0, fmt.Errorf("unsupported address type: %T", addr) } } -func (s *server) transmit(dst io.Writer, src io.Reader) { - s.ioCopy(dst, src, proxyTx) +func (s *server) transmit(dst, src net.Conn, port int) { + s.ioCopy(dst, src, proxyTx, port) } -func (s *server) receive(dst io.Writer, src io.Reader) { - s.ioCopy(dst, src, proxyRx) +func (s *server) receive(dst, src net.Conn, port int) { + s.ioCopy(dst, src, proxyRx, port) } type proxyType uint8 @@ -359,7 +411,7 @@ const ( proxyRx ) -func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) { +func (s *server) ioCopy(dst, src net.Conn, ptype proxyType, peerPort int) { buf := make([]byte, s.bufferSize) for { nr1, err := src.Read(buf) @@ -400,12 +452,30 @@ func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) { data = s.modifyTx(data) } s.modifyTxMu.RUnlock() + + s.blackholePeerMapMu.RLock() + // Tx from other peers is Rx for the target peer + if val, exist := s.blackholePeerMap[peerPort]; exist { + if (val & blackholePeerTypeRx) > 0 { + data = nil + } + } + s.blackholePeerMapMu.RUnlock() case proxyRx: s.modifyRxMu.RLock() if s.modifyRx != nil { data = s.modifyRx(data) } s.modifyRxMu.RUnlock() + + s.blackholePeerMapMu.RLock() + // Rx from other peers is Tx for the target peer + if val, exist := s.blackholePeerMap[peerPort]; exist { + if (val & blackholePeerTypeTx) > 0 { + data = nil + } + } + s.blackholePeerMapMu.RUnlock() default: panic("unknown proxy type") } @@ -413,19 +483,19 @@ func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) { switch ptype { case proxyTx: s.lg.Debug( - "modified tx", + "proxyTx", zap.String("data-received", humanize.Bytes(uint64(nr1))), zap.String("data-modified", humanize.Bytes(uint64(nr2))), - zap.String("from", s.From()), - zap.String("to", s.To()), + zap.String("proxy listening on", s.Listen()), + zap.Int("to peer port", peerPort), ) case proxyRx: s.lg.Debug( - "modified rx", + "proxyRx", zap.String("data-received", humanize.Bytes(uint64(nr1))), zap.String("data-modified", humanize.Bytes(uint64(nr2))), - zap.String("from", s.To()), - zap.String("to", s.From()), + zap.String("proxy listening on", s.Listen()), + zap.Int("to peer port", peerPort), ) default: panic("unknown proxy type") @@ -450,11 +520,27 @@ func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) { panic("unknown proxy type") } if lat > 0 { + s.lg.Debug( + "before delay TX/RX", + zap.String("data-received", humanize.Bytes(uint64(nr1))), + zap.String("data-modified", humanize.Bytes(uint64(nr2))), + zap.String("proxy listening on", s.Listen()), + zap.Int("to peer port", peerPort), + zap.Duration("latency", lat), + ) select { case <-time.After(lat): case <-s.donec: return } + s.lg.Debug( + "after delay TX/RX", + zap.String("data-received", humanize.Bytes(uint64(nr1))), + zap.String("data-modified", humanize.Bytes(uint64(nr2))), + zap.String("proxy listening on", s.Listen()), + zap.Int("to peer port", peerPort), + zap.Duration("latency", lat), + ) } // now forward packets to target @@ -522,15 +608,15 @@ func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) { s.lg.Debug( "transmitted", zap.String("data-size", humanize.Bytes(uint64(nr1))), - zap.String("from", s.From()), - zap.String("to", s.To()), + zap.String("proxy listening on", s.Listen()), + zap.Int("to peer port", peerPort), ) case proxyRx: s.lg.Debug( "received", zap.String("data-size", humanize.Bytes(uint64(nr1))), - zap.String("from", s.To()), - zap.String("to", s.From()), + zap.String("proxy listening on", s.Listen()), + zap.Int("to peer port", peerPort), ) default: panic("unknown proxy type") @@ -544,19 +630,27 @@ func (s *server) Error() <-chan error { return s.errc } func (s *server) Close() (err error) { s.closeOnce.Do(func() { close(s.donec) - s.listenerMu.Lock() - if s.listener != nil { - err = s.listener.Close() - s.lg.Info( - "closed proxy listener", - zap.String("from", s.From()), - zap.String("to", s.To()), - ) + + // we shutdown the server + log.Println("we shutdown the server") + if err = s.httpServer.Shutdown(context.TODO()); err != nil { + return } + s.httpServer = nil + + log.Println("waiting for listenerMu") + // listener was closed by the Shutdown() call + s.listenerMu.Lock() + s.listener = nil s.lg.Sync() s.listenerMu.Unlock() + + // the hijacked connections aren't tracked by the server so we need to wait for them + log.Println("waiting for closeHijackedConn") + s.closeHijackedConn.Wait() }) s.closeWg.Wait() + return err } @@ -574,8 +668,7 @@ func (s *server) DelayTx(latency, rv time.Duration) { zap.Duration("latency", d), zap.Duration("given-latency", latency), zap.Duration("given-latency-random-variable", rv), - zap.String("from", s.From()), - zap.String("to", s.To()), + zap.String("proxy listening on", s.Listen()), ) } @@ -588,8 +681,7 @@ func (s *server) UndelayTx() { s.lg.Info( "removed transmit latency", zap.Duration("latency", d), - zap.String("from", s.From()), - zap.String("to", s.To()), + zap.String("proxy listening on", s.Listen()), ) } @@ -614,8 +706,7 @@ func (s *server) DelayRx(latency, rv time.Duration) { zap.Duration("latency", d), zap.Duration("given-latency", latency), zap.Duration("given-latency-random-variable", rv), - zap.String("from", s.To()), - zap.String("to", s.From()), + zap.String("proxy listening on", s.Listen()), ) } @@ -628,8 +719,7 @@ func (s *server) UndelayRx() { s.lg.Info( "removed receive latency", zap.Duration("latency", d), - zap.String("from", s.To()), - zap.String("to", s.From()), + zap.String("proxy listening on", s.Listen()), ) } @@ -665,8 +755,7 @@ func (s *server) ModifyTx(f func([]byte) []byte) { s.lg.Info( "modifying tx", - zap.String("from", s.From()), - zap.String("to", s.To()), + zap.String("proxy listening on", s.Listen()), ) } @@ -677,8 +766,7 @@ func (s *server) UnmodifyTx() { s.lg.Info( "unmodifyed tx", - zap.String("from", s.From()), - zap.String("to", s.To()), + zap.String("proxy listening on", s.Listen()), ) } @@ -688,8 +776,7 @@ func (s *server) ModifyRx(f func([]byte) []byte) { s.modifyRxMu.Unlock() s.lg.Info( "modifying rx", - zap.String("from", s.To()), - zap.String("to", s.From()), + zap.String("proxy listening on", s.Listen()), ) } @@ -700,8 +787,7 @@ func (s *server) UnmodifyRx() { s.lg.Info( "unmodifyed rx", - zap.String("from", s.To()), - zap.String("to", s.From()), + zap.String("proxy listening on", s.Listen()), ) } @@ -709,8 +795,7 @@ func (s *server) BlackholeTx() { s.ModifyTx(func([]byte) []byte { return nil }) s.lg.Info( "blackholed tx", - zap.String("from", s.From()), - zap.String("to", s.To()), + zap.String("proxy listening on", s.Listen()), ) } @@ -718,8 +803,7 @@ func (s *server) UnblackholeTx() { s.UnmodifyTx() s.lg.Info( "unblackholed tx", - zap.String("from", s.From()), - zap.String("to", s.To()), + zap.String("proxy listening on", s.Listen()), ) } @@ -727,8 +811,7 @@ func (s *server) BlackholeRx() { s.ModifyRx(func([]byte) []byte { return nil }) s.lg.Info( "blackholed rx", - zap.String("from", s.To()), - zap.String("to", s.From()), + zap.String("proxy listening on", s.Listen()), ) } @@ -736,37 +819,66 @@ func (s *server) UnblackholeRx() { s.UnmodifyRx() s.lg.Info( "unblackholed rx", - zap.String("from", s.To()), - zap.String("to", s.From()), + zap.String("proxy listening on", s.Listen()), ) } -func (s *server) ResetListener() error { - s.listenerMu.Lock() - defer s.listenerMu.Unlock() +func (s *server) BlackholePeerTx(peer url.URL) { + s.blackholePeerMapMu.Lock() + defer s.blackholePeerMapMu.Unlock() - if err := s.listener.Close(); err != nil { - // already closed - if !strings.HasSuffix(err.Error(), "use of closed network connection") { - return err - } + port, err := strconv.Atoi(peer.Port()) + if err != nil { + panic("port parsing failed") } - - var ln net.Listener - var err error - if !s.tlsInfo.Empty() { - ln, err = transport.NewListener(s.from.Host, s.from.Scheme, &s.tlsInfo) + if val, exist := s.blackholePeerMap[port]; exist { + val |= blackholePeerTypeTx + s.blackholePeerMap[port] = val } else { - ln, err = net.Listen(s.from.Scheme, s.from.Host) + s.blackholePeerMap[port] = blackholePeerTypeTx + } +} + +func (s *server) UnblackholePeerTx(peer url.URL) { + s.blackholePeerMapMu.Lock() + defer s.blackholePeerMapMu.Unlock() + + port, err := strconv.Atoi(peer.Port()) + if err != nil { + panic("port parsing failed") + } + if val, exist := s.blackholePeerMap[port]; exist { + val &= bits.Reverse8(blackholePeerTypeTx) + s.blackholePeerMap[port] = val } +} + +func (s *server) BlackholePeerRx(peer url.URL) { + s.blackholePeerMapMu.Lock() + defer s.blackholePeerMapMu.Unlock() + + port, err := strconv.Atoi(peer.Port()) if err != nil { - return err + panic("port parsing failed") + } + if val, exist := s.blackholePeerMap[port]; exist { + val |= blackholePeerTypeRx + s.blackholePeerMap[port] = val + } else { + s.blackholePeerMap[port] = blackholePeerTypeTx } - s.listener = ln +} - s.lg.Info( - "reset listener on", - zap.String("from", s.From()), - ) - return nil +func (s *server) UnblackholePeerRx(peer url.URL) { + s.blackholePeerMapMu.Lock() + defer s.blackholePeerMapMu.Unlock() + + port, err := strconv.Atoi(peer.Port()) + if err != nil { + panic("port parsing failed") + } + if val, exist := s.blackholePeerMap[port]; exist { + val &= bits.Reverse8(blackholePeerTypeRx) + s.blackholePeerMap[port] = val + } } diff --git a/pkg/proxy/server_test.go b/pkg/proxy/server_test.go index baabfebe488..b06e3b4bc79 100644 --- a/pkg/proxy/server_test.go +++ b/pkg/proxy/server_test.go @@ -17,575 +17,351 @@ package proxy import ( "bytes" "context" - "crypto/tls" - "fmt" "io" "log" - "math/rand" "net" "net/http" "net/url" - "os" "strings" "testing" "time" - "github.com/stretchr/testify/assert" "go.uber.org/zap" "go.uber.org/zap/zaptest" "go.etcd.io/etcd/client/pkg/v3/transport" ) -func TestServer_Unix_Insecure(t *testing.T) { testServer(t, "unix", false, false) } -func TestServer_TCP_Insecure(t *testing.T) { testServer(t, "tcp", false, false) } -func TestServer_Unix_Secure(t *testing.T) { testServer(t, "unix", true, false) } -func TestServer_TCP_Secure(t *testing.T) { testServer(t, "tcp", true, false) } -func TestServer_Unix_Insecure_DelayTx(t *testing.T) { testServer(t, "unix", false, true) } -func TestServer_TCP_Insecure_DelayTx(t *testing.T) { testServer(t, "tcp", false, true) } -func TestServer_Unix_Secure_DelayTx(t *testing.T) { testServer(t, "unix", true, true) } -func TestServer_TCP_Secure_DelayTx(t *testing.T) { testServer(t, "tcp", true, true) } +/* dummyServerHandler is a helper struct */ +type dummyServerHandler struct { + t *testing.T + output chan<- []byte +} -func testServer(t *testing.T, scheme string, secure bool, delayTx bool) { - lg := zaptest.NewLogger(t) - srcAddr, dstAddr := newUnixAddr(), newUnixAddr() - if scheme == "tcp" { - ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{}) - srcAddr, dstAddr = ln1.Addr().String(), ln2.Addr().String() - ln1.Close() - ln2.Close() - } else { - defer func() { - os.RemoveAll(srcAddr) - os.RemoveAll(dstAddr) - }() - } - tlsInfo := createTLSInfo(lg, secure) - ln := listen(t, scheme, dstAddr, tlsInfo) - defer ln.Close() +// ServeHTTP read the request body and write back to the response object +func (sh *dummyServerHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + defer req.Body.Close() + resp.WriteHeader(200) - cfg := ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, - } - if secure { - cfg.TLSInfo = tlsInfo + if data, err := io.ReadAll(req.Body); err != nil { + sh.t.Fatal(err) + } else { + sh.output <- data } - p := NewServer(cfg) - - waitForServer(t, p) +} - defer p.Close() +func prepare(t *testing.T, serverIsClosed bool) (chan []byte, chan struct{}, Server, *http.Server, func(data []byte)) { + lg := zaptest.NewLogger(t) + scheme := "tcp" + L7Scheme := "http" - data1 := []byte("Hello World!") - donec, writec := make(chan struct{}), make(chan []byte) + // we always send the traffic to destination with HTTPS + // this will force the CONNECT header to be sent first + tlsInfo := createTLSInfo(lg) - go func() { - defer close(donec) - for data := range writec { - send(t, data, scheme, srcAddr, tlsInfo) - } - }() + ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{}) + forwardProxyAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String() + ln1.Close() + ln2.Close() recvc := make(chan []byte, 1) - go func() { - for i := 0; i < 2; i++ { - recvc <- receive(t, ln) - } - }() - - writec <- data1 - now := time.Now() - if d := <-recvc; !bytes.Equal(data1, d) { - close(writec) - t.Fatalf("expected %q, got %q", string(data1), string(d)) + httpServer := &http.Server{ + Handler: &dummyServerHandler{ + t: t, + output: recvc, + }, } - took1 := time.Since(now) - t.Logf("took %v with no latency", took1) + go startHTTPServer(scheme, dstAddr, tlsInfo, httpServer) - lat, rv := 50*time.Millisecond, 5*time.Millisecond - if delayTx { - p.DelayTx(lat, rv) + // we connect to the proxy without TLS + proxyURL := url.URL{Scheme: L7Scheme, Host: forwardProxyAddr} + cfg := ServerConfig{ + Logger: lg, + Listen: proxyURL, } + proxyServer := NewServer(cfg) + waitForServer(t, proxyServer) - data2 := []byte("new data") - writec <- data2 - now = time.Now() - if d := <-recvc; !bytes.Equal(data2, d) { - close(writec) - t.Fatalf("expected %q, got %q", string(data2), string(d)) - } - took2 := time.Since(now) - if delayTx { - t.Logf("took %v with latency %v+-%v", took2, lat, rv) + // setup forward proxy + t.Setenv("E2E_TEST_FORWARD_PROXY_IP", proxyURL.String()) + t.Logf("Proxy URL %s", proxyURL.String()) + + donec := make(chan struct{}) + + var tp *http.Transport + var err error + if !tlsInfo.Empty() { + tp, err = transport.NewTransport(tlsInfo, 1*time.Second) } else { - t.Logf("took %v with no latency", took2) + tp, err = transport.NewTransport(tlsInfo, 1*time.Second) } - - if delayTx { - p.UndelayTx() - if took2 < lat-rv { - close(writec) - t.Fatalf("expected took2 %v (with latency) > delay: %v", took2, lat-rv) - } + if err != nil { + t.Fatal(err) } + tp.IdleConnTimeout = 100 * time.Microsecond - close(writec) - select { - case <-donec: - case <-time.After(3 * time.Second): - t.Fatal("took too long to write") + sendData := func(data []byte) { + send(tp, t, data, scheme, dstAddr, tlsInfo, serverIsClosed) } - select { - case <-p.Done(): - t.Fatal("unexpected done") - case err := <-p.Error(): - t.Fatal(err) - default: - } + return recvc, donec, proxyServer, httpServer, sendData +} - if err := p.Close(); err != nil { +func destroy(t *testing.T, donec chan struct{}, proxyServer Server, serverIsClosed bool, httpServer *http.Server) { + if err := httpServer.Shutdown(context.Background()); err != nil { t.Fatal(err) } select { - case <-p.Done(): - case err := <-p.Error(): - if !strings.HasPrefix(err.Error(), "accept ") && - !strings.HasSuffix(err.Error(), "use of closed network connection") { - t.Fatal(err) - } + case <-donec: case <-time.After(3 * time.Second): - t.Fatal("took too long to close") + t.Fatal("took too long to write") } -} -func createTLSInfo(lg *zap.Logger, secure bool) transport.TLSInfo { - if secure { - return transport.TLSInfo{ - KeyFile: "../../tests/fixtures/server.key.insecure", - CertFile: "../../tests/fixtures/server.crt", - TrustedCAFile: "../../tests/fixtures/ca.crt", - ClientCertAuth: true, - Logger: lg, + if !serverIsClosed { + select { + case <-proxyServer.Done(): + t.Fatal("unexpected done") + case err := <-proxyServer.Error(): + if !strings.HasSuffix(err.Error(), "use of closed network connection") { + t.Fatal(err) + } + default: } - } - return transport.TLSInfo{Logger: lg} -} - -func TestServer_ModifyTx_corrupt(t *testing.T) { - lg := zaptest.NewLogger(t) - scheme := "unix" - srcAddr, dstAddr := newUnixAddr(), newUnixAddr() - defer func() { - os.RemoveAll(srcAddr) - os.RemoveAll(dstAddr) - }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) - defer ln.Close() - - p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, - }) - - waitForServer(t, p) - defer p.Close() + if err := proxyServer.Close(); err != nil { + t.Fatal(err) + } - p.ModifyTx(func(d []byte) []byte { - d[len(d)/2]++ - return d - }) - data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); bytes.Equal(d, data) { - t.Fatalf("expected corrupted data, got %q", string(d)) + select { + case <-proxyServer.Done(): + case err := <-proxyServer.Error(): + if !strings.HasSuffix(err.Error(), "use of closed network connection") { + t.Fatal(err) + } + case <-time.After(3 * time.Second): + t.Fatal("took too long to close") + } } +} - p.UnmodifyTx() - send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); !bytes.Equal(d, data) { - t.Fatalf("expected uncorrupted data, got %q", string(d)) +func createTLSInfo(lg *zap.Logger) transport.TLSInfo { + return transport.TLSInfo{ + KeyFile: "../../tests/fixtures/server.key.insecure", + CertFile: "../../tests/fixtures/server.crt", + TrustedCAFile: "../../tests/fixtures/ca.crt", + ClientCertAuth: true, + Logger: lg, } } -func TestServer_ModifyTx_packet_loss(t *testing.T) { - lg := zaptest.NewLogger(t) - scheme := "unix" - srcAddr, dstAddr := newUnixAddr(), newUnixAddr() - defer func() { - os.RemoveAll(srcAddr) - os.RemoveAll(dstAddr) - }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) - defer ln.Close() - - p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, - }) - - waitForServer(t, p) - - defer p.Close() - - // 50% packet loss - p.ModifyTx(func(d []byte) []byte { - half := len(d) / 2 - return d[:half:half] - }) - data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); bytes.Equal(d, data) { - t.Fatalf("expected corrupted data, got %q", string(d)) +func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) { + var err error + if !tlsInfo.Empty() { + ln, err = transport.NewListener(addr, scheme, &tlsInfo) + } else { + ln, err = net.Listen(scheme, addr) } - - p.UnmodifyTx() - send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); !bytes.Equal(d, data) { - t.Fatalf("expected uncorrupted data, got %q", string(d)) + if err != nil { + t.Fatal(err) } + return ln } -func TestServer_BlackholeTx(t *testing.T) { - lg := zaptest.NewLogger(t) - scheme := "unix" - srcAddr, dstAddr := newUnixAddr(), newUnixAddr() - defer func() { - os.RemoveAll(srcAddr) - os.RemoveAll(dstAddr) - }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) - defer ln.Close() - - p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, - }) - - waitForServer(t, p) - - defer p.Close() - - p.BlackholeTx() - - data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) - - recvc := make(chan []byte, 1) - go func() { - recvc <- receive(t, ln) - }() +func startHTTPServer(scheme, addr string, tlsInfo transport.TLSInfo, httpServer *http.Server) { + var err error + var ln net.Listener - select { - case d := <-recvc: - t.Fatalf("unexpected data receive %q during blackhole", string(d)) - case <-time.After(200 * time.Millisecond): + ln, err = net.Listen(scheme, addr) + if err != nil { + log.Fatal(err) } - p.UnblackholeTx() - - // expect different data, old data dropped - data[0]++ - send(t, data, scheme, srcAddr, transport.TLSInfo{}) - - select { - case d := <-recvc: - if !bytes.Equal(data, d) { - t.Fatalf("expected %q, got %q", string(data), string(d)) - } - case <-time.After(2 * time.Second): - t.Fatal("took too long to receive after unblackhole") + log.Println("HTTP Server started on", addr) + if err := httpServer.ServeTLS(ln, tlsInfo.CertFile, tlsInfo.KeyFile); err != http.ErrServerClosed { + // always returns error. ErrServerClosed on graceful close + log.Fatalf("startHTTPServer ServeTLS(): %v", err) } } -func TestServer_Shutdown(t *testing.T) { - lg := zaptest.NewLogger(t) - scheme := "unix" - srcAddr, dstAddr := newUnixAddr(), newUnixAddr() +func send(tp *http.Transport, t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo, serverIsClosed bool) { defer func() { - os.RemoveAll(srcAddr) - os.RemoveAll(dstAddr) + tp.CloseIdleConnections() }() - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) - defer ln.Close() - - p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, - }) - - waitForServer(t, p) - - defer p.Close() - - s, _ := p.(*server) - s.listener.Close() - time.Sleep(200 * time.Millisecond) - data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); !bytes.Equal(d, data) { - t.Fatalf("expected %q, got %q", string(data), string(d)) + // If you call Dial(), you will get a Conn that you can write the byte stream directly + // If you call RoundTrip(), you will get a connection managed for you, but you need to send valid HTTP request + dataReader := bytes.NewReader(data) + protocolScheme := scheme + if scheme == "tcp" { + if !tlsInfo.Empty() { + protocolScheme = "https" + } else { + panic("only https is supported") + } + } else { + panic("scheme not supported") } -} - -func TestServer_ShutdownListener(t *testing.T) { - lg := zaptest.NewLogger(t) - scheme := "unix" - srcAddr, dstAddr := newUnixAddr(), newUnixAddr() - defer func() { - os.RemoveAll(srcAddr) - os.RemoveAll(dstAddr) - }() - - ln := listen(t, scheme, dstAddr, transport.TLSInfo{}) - defer ln.Close() - - p := NewServer(ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, - }) - - waitForServer(t, p) - - defer p.Close() - - // shut down destination - ln.Close() - time.Sleep(200 * time.Millisecond) - - ln = listen(t, scheme, dstAddr, transport.TLSInfo{}) - defer ln.Close() - - data := []byte("Hello World!") - send(t, data, scheme, srcAddr, transport.TLSInfo{}) - if d := receive(t, ln); !bytes.Equal(d, data) { - t.Fatalf("expected %q, got %q", string(data), string(d)) + rawURL := url.URL{ + Scheme: protocolScheme, + Host: addr, } -} - -func TestServerHTTP_Insecure_DelayTx(t *testing.T) { testServerHTTP(t, false, true) } -func TestServerHTTP_Secure_DelayTx(t *testing.T) { testServerHTTP(t, true, true) } -func TestServerHTTP_Insecure_DelayRx(t *testing.T) { testServerHTTP(t, false, false) } -func TestServerHTTP_Secure_DelayRx(t *testing.T) { testServerHTTP(t, true, false) } -func testServerHTTP(t *testing.T, secure, delayTx bool) { - lg := zaptest.NewLogger(t) - scheme := "tcp" - ln1, ln2 := listen(t, scheme, "localhost:0", transport.TLSInfo{}), listen(t, scheme, "localhost:0", transport.TLSInfo{}) - srcAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String() - ln1.Close() - ln2.Close() - mux := http.NewServeMux() - mux.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) { - d, err := io.ReadAll(req.Body) - req.Body.Close() - if err != nil { - t.Fatal(err) - } - if _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))); err != nil { - t.Fatal(err) + req, err := http.NewRequest("POST", rawURL.String(), dataReader) + if err != nil { + t.Fatal(err) + } + res, err := tp.RoundTrip(req) + if err != nil { + if strings.Contains(err.Error(), "TLS handshake timeout") { + t.Logf("TLS handshake timeout") + return } - }) - tlsInfo := createTLSInfo(lg, secure) - var tlsConfig *tls.Config - if secure { - _, err := tlsInfo.ServerConfig() - if err != nil { - t.Fatal(err) + if serverIsClosed { + // when the proxy server is closed before sending, we will get this error message + if strings.Contains(err.Error(), "connect: connection refused") { + t.Logf("connect: connection refused") + return + } } + panic(err) } - srv := &http.Server{ - Addr: dstAddr, - Handler: mux, - TLSConfig: tlsConfig, - ErrorLog: log.New(io.Discard, "net/http", 0), - } - - donec := make(chan struct{}) defer func() { - srv.Close() - <-donec - }() - go func() { - if !secure { - srv.ListenAndServe() - } else { - srv.ListenAndServeTLS(tlsInfo.CertFile, tlsInfo.KeyFile) + if err := res.Body.Close(); err != nil { + panic(err) } - defer close(donec) }() - time.Sleep(200 * time.Millisecond) - cfg := ServerConfig{ - Logger: lg, - From: url.URL{Scheme: scheme, Host: srcAddr}, - To: url.URL{Scheme: scheme, Host: dstAddr}, - } - if secure { - cfg.TLSInfo = tlsInfo + if res.StatusCode != 200 { + t.Fatalf("status code not 200") } - p := NewServer(cfg) - - waitForServer(t, p) +} - defer func() { - lg.Info("closing Proxy server...") - p.Close() - lg.Info("closed Proxy server.") - }() +// Waits until a proxy is ready to serve. +// Aborts test on proxy start-up error. +func waitForServer(t *testing.T, s Server) { + select { + case <-s.Ready(): + case err := <-s.Error(): + t.Fatal(err) + } +} - data := "Hello World!" +func TestServer_TCP(t *testing.T) { testServer(t, false, false) } +func TestServer_TCP_DelayTx(t *testing.T) { testServer(t, true, false) } +func TestServer_TCP_DelayRx(t *testing.T) { testServer(t, false, true) } +func testServer(t *testing.T, delayTx bool, delayRx bool) { + recvc, donec, proxyServer, httpServer, sendData := prepare(t, false) + defer destroy(t, donec, proxyServer, false, httpServer) + defer close(donec) - var resp *http.Response - var err error + data1 := []byte("Hello World!") + sendData(data1) now := time.Now() - if secure { - tp, terr := transport.NewTransport(tlsInfo, 3*time.Second) - assert.NoError(t, terr) - cli := &http.Client{Transport: tp} - resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data)) - defer cli.CloseIdleConnections() - defer tp.CloseIdleConnections() - } else { - resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data)) - defer http.DefaultClient.CloseIdleConnections() - } - assert.NoError(t, err) - d, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) + if d := <-recvc; !bytes.Equal(data1, d) { + t.Fatalf("expected %q, got %q", string(data1), string(d)) } - resp.Body.Close() took1 := time.Since(now) t.Logf("took %v with no latency", took1) - rs1 := string(d) - exp := fmt.Sprintf("%q(confirmed)", data) - if rs1 != exp { - t.Fatalf("got %q, expected %q", rs1, exp) - } - - lat, rv := 100*time.Millisecond, 10*time.Millisecond + lat, rv := 50*time.Millisecond, 5*time.Millisecond if delayTx { - p.DelayTx(lat, rv) - defer p.UndelayTx() - } else { - p.DelayRx(lat, rv) - defer p.UndelayRx() + proxyServer.DelayTx(lat, rv) + } + if delayRx { + proxyServer.DelayRx(lat, rv) } + data2 := []byte("new data") now = time.Now() - if secure { - tp, terr := transport.NewTransport(tlsInfo, 3*time.Second) - if terr != nil { - t.Fatal(terr) - } - cli := &http.Client{Transport: tp} - resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data)) - defer cli.CloseIdleConnections() - defer tp.CloseIdleConnections() - } else { - resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data)) - defer http.DefaultClient.CloseIdleConnections() - } - if err != nil { - t.Fatal(err) - } - d, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) + sendData(data2) + if d := <-recvc; !bytes.Equal(data2, d) { + t.Fatalf("expected %q, got %q", string(data2), string(d)) } - resp.Body.Close() took2 := time.Since(now) - t.Logf("took %v with latency %v±%v", took2, lat, rv) + if delayTx { + t.Logf("took %v with latency %v+-%v", took2, lat, rv) + } else { + t.Logf("took %v with no latency", took2) + } - rs2 := string(d) - if rs2 != exp { - t.Fatalf("got %q, expected %q", rs2, exp) + if delayTx { + proxyServer.UndelayTx() + if took2 < lat-rv { + t.Fatalf("[delayTx] expected took2 %v (with latency) > delay: %v", took2, lat-rv) + } } - if took1 > took2 { - t.Fatalf("expected took1 %v < took2 %v", took1, took2) + if delayRx { + proxyServer.UndelayRx() + if took2 < lat-rv { + t.Fatalf("[delayRx] expected took2 %v (with latency) > delay: %v", took2, lat-rv) + } } } -func newUnixAddr() string { - now := time.Now().UnixNano() - addr := fmt.Sprintf("%X%X.unix-conn", now, rand.Intn(35000)) - os.RemoveAll(addr) - return addr -} +func TestServer_BlackholeTx(t *testing.T) { + recvc, donec, proxyServer, httpServer, sendData := prepare(t, false) + defer destroy(t, donec, proxyServer, false, httpServer) + defer close(donec) -func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) { - var err error - if !tlsInfo.Empty() { - ln, err = transport.NewListener(addr, scheme, &tlsInfo) - } else { - ln, err = net.Listen(scheme, addr) + // before enabling blacklhole + data := []byte("Hello World!") + sendData(data) + if d := <-recvc; !bytes.Equal(data, d) { + t.Fatalf("expected %q, got %q", string(data), string(d)) } - if err != nil { - t.Fatal(err) + + // enable blackhole + // note that the transport is set to use 10s for TLSHandshakeTimeout, so + // this test will require at least 10s to execute, since send() is a + // blocking call thus we need to wait for ssl handshake to timeout + proxyServer.BlackholeTx() + + sendData(data) + select { + case d := <-recvc: + t.Fatalf("unexpected data receive %q during blackhole", string(d)) + case <-time.After(200 * time.Millisecond): } - return ln -} -func send(t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo) { - var out net.Conn - var err error - if !tlsInfo.Empty() { - tp, terr := transport.NewTransport(tlsInfo, 3*time.Second) - if terr != nil { - t.Fatal(terr) + proxyServer.UnblackholeTx() + + // disable blackhole + // TODO: figure out why HTTPS won't attempt to reconnect when the blackhole is disabled + + // expect different data, old data dropped + data[0]++ + sendData(data) + select { + case d := <-recvc: + if !bytes.Equal(data, d) { + t.Fatalf("expected %q, got %q", string(data), string(d)) } - out, err = tp.DialContext(context.Background(), scheme, addr) - } else { - out, err = net.Dial(scheme, addr) - } - if err != nil { - t.Fatal(err) - } - if _, err = out.Write(data); err != nil { - t.Fatal(err) - } - if err = out.Close(); err != nil { - t.Fatal(err) + case <-time.After(2 * time.Second): + t.Fatal("took too long to receive after unblackhole") } } -func receive(t *testing.T, ln net.Listener) (data []byte) { - buf := bytes.NewBuffer(make([]byte, 0, 1024)) - for { - in, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - var n int64 - n, err = buf.ReadFrom(in) - if err != nil { - t.Fatal(err) - } - if n > 0 { - break - } +func TestServer_Shutdown(t *testing.T) { + recvc, donec, proxyServer, httpServer, sendData := prepare(t, true) + defer destroy(t, donec, proxyServer, true, httpServer) + defer close(donec) + + s, _ := proxyServer.(*server) + if err := s.Close(); err != nil { + t.Fatal(err) } - return buf.Bytes() -} + time.Sleep(200 * time.Millisecond) + + data := []byte("Hello World!") + sendData(data) -// Waits until a proxy is ready to serve. -// Aborts test on proxy start-up error. -func waitForServer(t *testing.T, s Server) { select { - case <-s.Ready(): - case err := <-s.Error(): - t.Fatal(err) + case d := <-recvc: + if bytes.Equal(data, d) { + t.Fatalf("expected nothing, got %q", string(d)) + } + case <-time.After(2 * time.Second): + t.Log("nothing was received, proxy server seems to be closed so no traffic is forwarded") } } diff --git a/tests/e2e/blackhole_test.go b/tests/e2e/blackhole_test.go new file mode 100644 index 00000000000..b617d0d4897 --- /dev/null +++ b/tests/e2e/blackhole_test.go @@ -0,0 +1,105 @@ +// Copyright 2024 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !cluster_proxy + +package e2e + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.etcd.io/etcd/tests/v3/framework/e2e" +) + +func TestBlackholeByMockingPartitionLeader(t *testing.T) { + blackholeTestByMockingPartition(t, 3, true) +} + +func TestBlackholeByMockingPartitionFollower(t *testing.T) { + blackholeTestByMockingPartition(t, 3, false) +} + +func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLeader bool) { + e2e.BeforeTest(t) + + t.Logf("Create an etcd cluster with %d member\n", clusterSize) + epc, err := e2e.NewEtcdProcessCluster(context.TODO(), t, + e2e.WithClusterSize(clusterSize), + e2e.WithSnapshotCount(10), + e2e.WithSnapshotCatchUpEntries(10), + e2e.WithIsPeerTLS(true), + e2e.WithPeerProxy(true), + ) + require.NoError(t, err, "failed to start etcd cluster: %v", err) + defer func() { + require.NoError(t, epc.Close(), "failed to close etcd cluster") + }() + + leaderID := epc.WaitLeader(t) + mockPartitionNodeIndex := leaderID + if !partitionLeader { + mockPartitionNodeIndex = (leaderID + 1) % (clusterSize) + } + partitionedMember := epc.Procs[mockPartitionNodeIndex] + // Mock partition + t.Logf("Blackholing traffic from and to member %q", partitionedMember.Config().Name) + epc.BlackholePeer(partitionedMember) + + t.Logf("Wait 1s for any open connections to expire") + time.Sleep(1 * time.Second) + + t.Logf("Wait for new leader election with remaining members") + leaderEPC := epc.Procs[waitLeader(t, epc, mockPartitionNodeIndex)] + t.Log("Writing 20 keys to the cluster (more than SnapshotCount entries to trigger at least a snapshot.)") + writeKVs(t, leaderEPC.Etcdctl(), 0, 20) + e2e.AssertProcessLogs(t, leaderEPC, "saved snapshot") + + t.Log("Verifying the partitionedMember is missing new writes") + assertRevision(t, leaderEPC, 21) + assertRevision(t, partitionedMember, 1) + + // Wait for some time to restore the network + time.Sleep(1 * time.Second) + t.Logf("Unblackholing traffic from and to member %q", partitionedMember.Config().Name) + epc.UnblackholePeer(partitionedMember) + + leaderEPC = epc.Procs[epc.WaitLeader(t)] + time.Sleep(1 * time.Second) + assertRevision(t, leaderEPC, 21) + assertRevision(t, partitionedMember, 21) +} + +func waitLeader(t testing.TB, epc *e2e.EtcdProcessCluster, excludeNode int) int { + var membs []e2e.EtcdProcess + for i := 0; i < len(epc.Procs); i++ { + if i == excludeNode { + continue + } + membs = append(membs, epc.Procs[i]) + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + return epc.WaitMembersForLeader(ctx, t, membs) +} + +func assertRevision(t testing.TB, member e2e.EtcdProcess, expectedRevision int64) { + responses, err := member.Etcdctl().Status(context.TODO()) + require.NoError(t, err) + assert.Equal(t, expectedRevision, responses[0].Header.Revision, "revision mismatch") +} diff --git a/tests/e2e/http_health_check_test.go b/tests/e2e/http_health_check_test.go index 8aa2694344f..baa5ad81110 100644 --- a/tests/e2e/http_health_check_test.go +++ b/tests/e2e/http_health_check_test.go @@ -384,10 +384,10 @@ func triggerSlowApply(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCl func blackhole(_ context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, _ time.Duration) { member := clus.Procs[0] - proxy := member.PeerProxy() + forwardProxy := member.PeerForwardProxy() t.Logf("Blackholing traffic from and to member %q", member.Config().Name) - proxy.BlackholeTx() - proxy.BlackholeRx() + forwardProxy.BlackholeTx() + forwardProxy.BlackholeRx() } func triggerRaftLoopDeadLock(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, duration time.Duration) { diff --git a/tests/framework/e2e/cluster.go b/tests/framework/e2e/cluster.go index 23d422aa313..b9b03d9aa07 100644 --- a/tests/framework/e2e/cluster.go +++ b/tests/framework/e2e/cluster.go @@ -513,10 +513,10 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in var curl string port := cfg.BasePort + 5*i clientPort := port - peerPort := port + 1 + peerPort := port + 1 // the port that the peer actually listens on metricsPort := port + 2 - peer2Port := port + 3 - clientHTTPPort := port + 4 + clientHTTPPort := port + 3 + forwardProxyPort := port + 4 if cfg.Client.ConnectionType == ClientTLSAndNonTLS { curl = clientURL(cfg.ClientScheme(), clientPort, ClientNonTLS) @@ -528,17 +528,23 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in peerListenURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)} peerAdvertiseURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)} - var proxyCfg *proxy.ServerConfig + var forwardProxyCfg *proxy.ServerConfig if cfg.PeerProxy { if !cfg.IsPeerTLS { panic("Can't use peer proxy without peer TLS as it can result in malformed packets") } - peerAdvertiseURL.Host = fmt.Sprintf("localhost:%d", peer2Port) - proxyCfg = &proxy.ServerConfig{ + + // setup forward proxy + forwardProxyURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", forwardProxyPort)} + forwardProxyCfg = &proxy.ServerConfig{ Logger: zap.NewNop(), - To: peerListenURL, - From: peerAdvertiseURL, + Listen: forwardProxyURL, + } + + if cfg.EnvVars == nil { + cfg.EnvVars = make(map[string]string) } + cfg.EnvVars["E2E_TEST_FORWARD_PROXY_IP"] = fmt.Sprintf("http://127.0.0.1:%d", forwardProxyPort) } name := fmt.Sprintf("%s-test-%d", testNameCleanRegex.ReplaceAllString(tb.Name(), ""), i) @@ -660,7 +666,7 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in InitialToken: cfg.ServerConfig.InitialClusterToken, GoFailPort: gofailPort, GoFailClientTimeout: cfg.GoFailClientTimeout, - Proxy: proxyCfg, + ForwardProxy: forwardProxyCfg, LazyFSEnabled: cfg.LazyFSEnabled, } } @@ -910,6 +916,38 @@ func (epc *EtcdProcessCluster) Restart(ctx context.Context) error { return epc.start(func(ep EtcdProcess) error { return ep.Restart(ctx) }) } +func (epc *EtcdProcessCluster) BlackholePeer(blackholePeer EtcdProcess) error { + blackholePeer.PeerForwardProxy().BlackholeRx() + blackholePeer.PeerForwardProxy().BlackholeTx() + + for _, peer := range epc.Procs { + if peer.Config().Name == blackholePeer.Config().Name { + continue + } + + peer.PeerForwardProxy().BlackholePeerRx(blackholePeer.Config().PeerURL) + peer.PeerForwardProxy().BlackholePeerTx(blackholePeer.Config().PeerURL) + } + + return nil +} + +func (epc *EtcdProcessCluster) UnblackholePeer(blackholePeer EtcdProcess) error { + blackholePeer.PeerForwardProxy().UnblackholeRx() + blackholePeer.PeerForwardProxy().UnblackholeTx() + + for _, peer := range epc.Procs { + if peer.Config().Name == blackholePeer.Config().Name { + continue + } + + peer.PeerForwardProxy().UnblackholePeerRx(blackholePeer.Config().PeerURL) + peer.PeerForwardProxy().UnblackholePeerTx(blackholePeer.Config().PeerURL) + } + + return nil +} + func (epc *EtcdProcessCluster) start(f func(ep EtcdProcess) error) error { readyC := make(chan error, len(epc.Procs)) for i := range epc.Procs { diff --git a/tests/framework/e2e/etcd_process.go b/tests/framework/e2e/etcd_process.go index b55ef9e90a6..55ab7053686 100644 --- a/tests/framework/e2e/etcd_process.go +++ b/tests/framework/e2e/etcd_process.go @@ -55,7 +55,7 @@ type EtcdProcess interface { Stop() error Close() error Config() *EtcdServerProcessConfig - PeerProxy() proxy.Server + PeerForwardProxy() proxy.Server Failpoints() *BinaryFailpoints LazyFS() *LazyFS Logs() LogsExpect @@ -69,12 +69,12 @@ type LogsExpect interface { } type EtcdServerProcess struct { - cfg *EtcdServerProcessConfig - proc *expect.ExpectProcess - proxy proxy.Server - lazyfs *LazyFS - failpoints *BinaryFailpoints - donec chan struct{} // closed when Interact() terminates + cfg *EtcdServerProcessConfig + proc *expect.ExpectProcess + forwardProxy proxy.Server + lazyfs *LazyFS + failpoints *BinaryFailpoints + donec chan struct{} // closed when Interact() terminates } type EtcdServerProcessConfig struct { @@ -101,7 +101,7 @@ type EtcdServerProcessConfig struct { GoFailClientTimeout time.Duration LazyFSEnabled bool - Proxy *proxy.ServerConfig + ForwardProxy *proxy.ServerConfig } func NewEtcdServerProcess(t testing.TB, cfg *EtcdServerProcessConfig) (*EtcdServerProcess, error) { @@ -151,12 +151,13 @@ func (ep *EtcdServerProcess) Start(ctx context.Context) error { if ep.proc != nil { panic("already started") } - if ep.cfg.Proxy != nil && ep.proxy == nil { - ep.cfg.lg.Info("starting proxy...", zap.String("name", ep.cfg.Name), zap.String("from", ep.cfg.Proxy.From.String()), zap.String("to", ep.cfg.Proxy.To.String())) - ep.proxy = proxy.NewServer(*ep.cfg.Proxy) + + if ep.cfg.ForwardProxy != nil && ep.forwardProxy == nil { + ep.cfg.lg.Info("starting forward proxy...", zap.String("name", ep.cfg.Name), zap.String("listen on", ep.cfg.ForwardProxy.Listen.String())) + ep.forwardProxy = proxy.NewServer(*ep.cfg.ForwardProxy) select { - case <-ep.proxy.Ready(): - case err := <-ep.proxy.Error(): + case <-ep.forwardProxy.Ready(): + case err := <-ep.forwardProxy.Error(): return err } } @@ -221,10 +222,10 @@ func (ep *EtcdServerProcess) Stop() (err error) { } } ep.cfg.lg.Info("stopped server.", zap.String("name", ep.cfg.Name)) - if ep.proxy != nil { - ep.cfg.lg.Info("stopping proxy...", zap.String("name", ep.cfg.Name)) - err = ep.proxy.Close() - ep.proxy = nil + if ep.forwardProxy != nil { + ep.cfg.lg.Info("stopping forward proxy...", zap.String("name", ep.cfg.Name)) + err = ep.forwardProxy.Close() + ep.forwardProxy = nil if err != nil { return err } @@ -330,8 +331,8 @@ func AssertProcessLogs(t *testing.T, ep EtcdProcess, expectLog string) { } } -func (ep *EtcdServerProcess) PeerProxy() proxy.Server { - return ep.proxy +func (ep *EtcdServerProcess) PeerForwardProxy() proxy.Server { + return ep.forwardProxy } func (ep *EtcdServerProcess) LazyFS() *LazyFS { diff --git a/tests/robustness/failpoint/network.go b/tests/robustness/failpoint/network.go index 27504c396b9..67f8520dfe9 100644 --- a/tests/robustness/failpoint/network.go +++ b/tests/robustness/failpoint/network.go @@ -63,23 +63,17 @@ func (tb triggerBlackhole) Available(config e2e.EtcdProcessClusterConfig, proces if tb.waitTillSnapshot && (entriesToGuaranteeSnapshot(config) > 200 || !e2e.CouldSetSnapshotCatchupEntries(process.Config().ExecPath)) { return false } - return config.ClusterSize > 1 && process.PeerProxy() != nil + return config.ClusterSize > 1 && process.PeerForwardProxy() != nil } func Blackhole(ctx context.Context, t *testing.T, member e2e.EtcdProcess, clus *e2e.EtcdProcessCluster, shouldWaitTillSnapshot bool) error { - proxy := member.PeerProxy() - - // Blackholing will cause peers to not be able to use streamWriters registered with member - // but peer traffic is still possible because member has 'pipeline' with peers - // TODO: find a way to stop all traffic t.Logf("Blackholing traffic from and to member %q", member.Config().Name) - proxy.BlackholeTx() - proxy.BlackholeRx() + clus.BlackholePeer(member) defer func() { t.Logf("Traffic restored from and to member %q", member.Config().Name) - proxy.UnblackholeTx() - proxy.UnblackholeRx() + clus.UnblackholePeer(member) }() + if shouldWaitTillSnapshot { return waitTillSnapshot(ctx, t, clus, member) } @@ -164,15 +158,15 @@ type delayPeerNetworkFailpoint struct { func (f delayPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) { member := clus.Procs[rand.Int()%len(clus.Procs)] - proxy := member.PeerProxy() + forwardProxy := member.PeerForwardProxy() - proxy.DelayRx(f.baseLatency, f.randomizedLatency) - proxy.DelayTx(f.baseLatency, f.randomizedLatency) + forwardProxy.DelayRx(f.baseLatency, f.randomizedLatency) + forwardProxy.DelayTx(f.baseLatency, f.randomizedLatency) lg.Info("Delaying traffic from and to member", zap.String("member", member.Config().Name), zap.Duration("baseLatency", f.baseLatency), zap.Duration("randomizedLatency", f.randomizedLatency)) time.Sleep(f.duration) lg.Info("Traffic delay removed", zap.String("member", member.Config().Name)) - proxy.UndelayRx() - proxy.UndelayTx() + forwardProxy.UndelayRx() + forwardProxy.UndelayTx() return nil, nil } @@ -181,7 +175,7 @@ func (f delayPeerNetworkFailpoint) Name() string { } func (f delayPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool { - return config.ClusterSize > 1 && clus.PeerProxy() != nil + return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil } type dropPeerNetworkFailpoint struct { @@ -191,15 +185,15 @@ type dropPeerNetworkFailpoint struct { func (f dropPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) { member := clus.Procs[rand.Int()%len(clus.Procs)] - proxy := member.PeerProxy() + forwardProxy := member.PeerForwardProxy() - proxy.ModifyRx(f.modifyPacket) - proxy.ModifyTx(f.modifyPacket) + forwardProxy.ModifyRx(f.modifyPacket) + forwardProxy.ModifyTx(f.modifyPacket) lg.Info("Dropping traffic from and to member", zap.String("member", member.Config().Name), zap.Int("probability", f.dropProbabilityPercent)) time.Sleep(f.duration) lg.Info("Traffic drop removed", zap.String("member", member.Config().Name)) - proxy.UnmodifyRx() - proxy.UnmodifyTx() + forwardProxy.UnmodifyRx() + forwardProxy.UnmodifyTx() return nil, nil } @@ -215,5 +209,5 @@ func (f dropPeerNetworkFailpoint) Name() string { } func (f dropPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool { - return config.ClusterSize > 1 && clus.PeerProxy() != nil + return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil }