diff --git a/client/pkg/transport/transport.go b/client/pkg/transport/transport.go
index 91462dcdb08b..3daa71b52948 100644
--- a/client/pkg/transport/transport.go
+++ b/client/pkg/transport/transport.go
@@ -18,6 +18,8 @@ import (
 	"context"
 	"net"
 	"net/http"
+	"net/url"
+	"os"
 	"strings"
 	"time"
 )
@@ -31,7 +33,17 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er
 	}
 
 	t := &http.Transport{
-		Proxy: http.ProxyFromEnvironment,
+		Proxy: func(req *http.Request) (*url.URL, error) {
+			// Per the documentation of http.ProxyFromEnvironment: if the
+			// proxy URL is "localhost" (with or without a port number),
+			// a nil URL and nil error are returned.
+			// We work around this by reading a dedicated FORWARD_PROXY
+			// environment variable and parsing its URL (a localhost address in our case).
+			if forwardProxy, exists := os.LookupEnv("FORWARD_PROXY"); exists {
+				return url.Parse(forwardProxy)
+			}
+			return http.ProxyFromEnvironment(req)
+		},
 		DialContext: (&net.Dialer{
 			Timeout: dialtimeoutd,
 			// value taken from http.DefaultTransport
diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go
index 6d7931b4e33a..c90f7a53fe43 100644
--- a/pkg/proxy/server.go
+++ b/pkg/proxy/server.go
@@ -15,6 +15,8 @@
 package proxy
 
 import (
+	"bufio"
+	"bytes"
 	"context"
 	"fmt"
 	"io"
@@ -130,18 +132,21 @@ type Server interface {
 
 // ServerConfig defines proxy server configuration.
 type ServerConfig struct {
-	Logger        *zap.Logger
-	From          url.URL
-	To            url.URL
-	TLSInfo       transport.TLSInfo
-	DialTimeout   time.Duration
-	BufferSize    int
-	RetryInterval time.Duration
+	Logger         *zap.Logger
+	From           url.URL
+	To             url.URL
+	TLSInfo        transport.TLSInfo
+	DialTimeout    time.Duration
+	BufferSize     int
+	RetryInterval  time.Duration
+	IsForwardProxy bool
 }
 
 type server struct {
 	lg *zap.Logger
 
+	isForwardProxy bool
+
 	from     url.URL
 	fromPort int
 	to       url.URL
@@ -194,6 +199,8 @@ func NewServer(cfg ServerConfig) Server {
 	s := &server{
 		lg: cfg.Logger,
 
+		isForwardProxy: cfg.IsForwardProxy,
+
 		from: cfg.From,
 		to:   cfg.To,
 
@@ -216,10 +223,12 @@ func NewServer(cfg ServerConfig) Server {
 	if err == nil {
 		s.fromPort, _ = strconv.Atoi(fromPort)
 	}
-	var toPort string
-	_, toPort, err = net.SplitHostPort(cfg.To.Host)
-	if err == nil {
-		s.toPort, _ = strconv.Atoi(toPort)
+	if !s.isForwardProxy {
+		var toPort string
+		_, toPort, err = net.SplitHostPort(cfg.To.Host)
+		if err == nil {
+			s.toPort, _ = strconv.Atoi(toPort)
+		}
 	}
 
 	if s.dialTimeout == 0 {
@@ -239,8 +248,10 @@ func NewServer(cfg ServerConfig) Server {
 	if strings.HasPrefix(s.from.Scheme, "http") {
 		s.from.Scheme = "tcp"
 	}
-	if strings.HasPrefix(s.to.Scheme, "http") {
-		s.to.Scheme = "tcp"
+	if !s.isForwardProxy {
+		if strings.HasPrefix(s.to.Scheme, "http") {
+			s.to.Scheme = "tcp"
+		}
 	}
 
 	addr := fmt.Sprintf(":%d", s.fromPort)
@@ -273,7 +284,10 @@ func (s *server) From() string {
 }
 
 func (s *server) To() string {
-	return fmt.Sprintf("%s://%s", s.to.Scheme, s.to.Host)
+	if !s.isForwardProxy {
+		return fmt.Sprintf("%s://%s", s.to.Scheme, s.to.Host)
+	}
+	return ""
 }
 
 // TODO: implement packet reordering from multiple TCP connections
@@ -353,6 +367,40 @@ func (s *server) listenAndServe() {
 			continue
 		}
 
+		parseHeaderForDestination := func() string {
+			// the first request should always be a CONNECT request, since
+			// the client transport is configured to tunnel traffic through this proxy
+			buf := make([]byte, s.bufferSize)
+			var data []byte
+			if nr1, err := in.Read(buf); err != nil {
+				if err == io.EOF {
+					panic("No data available for forward proxy to work on")
+				}
+			} else {
+				data = buf[:nr1]
+			}
+
+			// attempt to parse the host out of the CONNECT request
+			var req *http.Request
+			if req, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(data))); err != nil {
+				panic("Failed to parse the CONNECT request in forward proxy")
+			}
+
+			if req.Method == http.MethodConnect {
+				// make sure a reply is sent back to the client
+				connectResponse := &http.Response{
+					StatusCode: 200,
+					ProtoMajor: 1,
+					ProtoMinor: 1,
+				}
+				connectResponse.Write(in)
+
+				return req.URL.Host
+			}
+
+			panic("Wrong header type to start the connection")
+		}
+
 		var out net.Conn
 		if !s.tlsInfo.Empty() {
 			var tp *http.Transport
@@ -370,9 +418,19 @@ func (s *server) listenAndServe() {
 				}
 				continue
 			}
-			out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host)
+			if s.isForwardProxy {
+				dest := parseHeaderForDestination()
+				out, err = tp.DialContext(ctx, "tcp", dest)
+			} else {
+				out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host)
+			}
 		} else {
-			out, err = net.Dial(s.to.Scheme, s.to.Host)
+			if s.isForwardProxy {
+				dest := parseHeaderForDestination()
+				out, err = net.Dial("tcp", dest)
+			} else {
+				out, err = net.Dial(s.to.Scheme, s.to.Host)
+			}
 		}
 		if err != nil {
 			select {
diff --git a/tests/e2e/blackhole_test.go b/tests/e2e/blackhole_test.go
index 68150cb8608a..00d8f9695a86 100644
--- a/tests/e2e/blackhole_test.go
+++ b/tests/e2e/blackhole_test.go
@@ -51,17 +51,20 @@ func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLea
 		require.NoError(t, epc.Close(), "failed to close etcd cluster")
 	}()
 
-	leaderId := epc.WaitLeader(t)
-	mockPartitionNodeIndex := leaderId
+	leaderID := epc.WaitLeader(t)
+	mockPartitionNodeIndex := leaderID
 	if !partitionLeader {
-		mockPartitionNodeIndex = (leaderId + 1) % (clusterSize)
+		mockPartitionNodeIndex = (leaderID + 1) % (clusterSize)
 	}
 	partitionedMember := epc.Procs[mockPartitionNodeIndex]
 	// Mock partition
 	proxy := partitionedMember.PeerProxy()
+	forwardProxy := partitionedMember.PeerForwardProxy()
 	t.Logf("Blackholing traffic from and to member %q", partitionedMember.Config().Name)
 	proxy.BlackholeTx()
 	proxy.BlackholeRx()
+	forwardProxy.BlackholeTx()
+	forwardProxy.BlackholeRx()
 	t.Logf("Wait 5s for any open connections to expire")
 	time.Sleep(5 * time.Second)
 
@@ -81,6 +84,8 @@ func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLea
 	t.Logf("Unblackholing traffic from and to member %q", partitionedMember.Config().Name)
 	proxy.UnblackholeTx()
 	proxy.UnblackholeRx()
+	forwardProxy.UnblackholeTx()
+	forwardProxy.UnblackholeRx()
 
 	leaderEPC = epc.Procs[epc.WaitLeader(t)]
 	time.Sleep(5 * time.Second)
diff --git a/tests/framework/e2e/cluster.go b/tests/framework/e2e/cluster.go
index 8f3a102c3059..6f9e6a848d88 100644
--- a/tests/framework/e2e/cluster.go
+++ b/tests/framework/e2e/cluster.go
@@ -481,12 +481,13 @@ func (cfg *EtcdProcessClusterConfig) SetInitialOrDiscovery(serverCfg *EtcdServer
 func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i int) *EtcdServerProcessConfig {
 	var curls []string
 	var curl string
-	port := cfg.BasePort + 5*i
+	port := cfg.BasePort + 6*i
 	clientPort := port
-	peerPort := port + 1
+	peerPort := port + 1 // the port that the peer actually listens on
 	metricsPort := port + 2
-	peer2Port := port + 3
+	peer2Port := port + 3 // the port that the peer advertises
 	clientHTTPPort := port + 4
+	forwardProxyPort := port + 5 // the port of the forward proxy
 
 	if cfg.Client.ConnectionType == ClientTLSAndNonTLS {
 		curl = clientURL(cfg.ClientScheme(), clientPort, ClientNonTLS)
@@ -499,6 +500,7 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
 	peerListenURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)}
 	peerAdvertiseURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)}
 	var proxyCfg *proxy.ServerConfig
+	var forwardProxyCfg *proxy.ServerConfig
 	if cfg.PeerProxy {
 		if !cfg.IsPeerTLS {
 			panic("Can't use peer proxy without peer TLS as it can result in malformed packets")
@@ -509,6 +511,19 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
 			To:   peerListenURL,
 			From: peerAdvertiseURL,
 		}
+
+		// set up the forward proxy
+		forwardProxyURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", forwardProxyPort)}
+		forwardProxyCfg = &proxy.ServerConfig{
+			Logger:         zap.NewNop(),
+			From:           forwardProxyURL,
+			IsForwardProxy: true,
+		}
+
+		if cfg.EnvVars == nil {
+			cfg.EnvVars = make(map[string]string)
+		}
+		cfg.EnvVars["FORWARD_PROXY"] = fmt.Sprintf("http://127.0.0.1:%d", forwardProxyPort)
 	}
 
 	name := fmt.Sprintf("%s-test-%d", testNameCleanRegex.ReplaceAllString(tb.Name(), ""), i)
@@ -631,6 +646,7 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
 		GoFailPort:          gofailPort,
 		GoFailClientTimeout: cfg.GoFailClientTimeout,
 		Proxy:               proxyCfg,
+		ForwardProxy:        forwardProxyCfg,
 		LazyFSEnabled:       cfg.LazyFSEnabled,
 	}
 }
diff --git a/tests/framework/e2e/etcd_process.go b/tests/framework/e2e/etcd_process.go
index f9d2089a3e3b..97daabfff562 100644
--- a/tests/framework/e2e/etcd_process.go
+++ b/tests/framework/e2e/etcd_process.go
@@ -56,6 +56,7 @@ type EtcdProcess interface {
 	Close() error
 	Config() *EtcdServerProcessConfig
 	PeerProxy() proxy.Server
+	PeerForwardProxy() proxy.Server
 	Failpoints() *BinaryFailpoints
 	LazyFS() *LazyFS
 	Logs() LogsExpect
@@ -69,12 +70,13 @@ type LogsExpect interface {
 }
 
 type EtcdServerProcess struct {
-	cfg        *EtcdServerProcessConfig
-	proc       *expect.ExpectProcess
-	proxy      proxy.Server
-	lazyfs     *LazyFS
-	failpoints *BinaryFailpoints
-	donec      chan struct{} // closed when Interact() terminates
+	cfg          *EtcdServerProcessConfig
+	proc         *expect.ExpectProcess
+	proxy        proxy.Server
+	forwardProxy proxy.Server
+	lazyfs       *LazyFS
+	failpoints   *BinaryFailpoints
+	donec        chan struct{} // closed when Interact() terminates
 }
 
 type EtcdServerProcessConfig struct {
@@ -102,6 +104,7 @@ type EtcdServerProcessConfig struct {
 
 	LazyFSEnabled bool
 	Proxy         *proxy.ServerConfig
+	ForwardProxy  *proxy.ServerConfig
 }
 
 func NewEtcdServerProcess(t testing.TB, cfg *EtcdServerProcessConfig) (*EtcdServerProcess, error) {
@@ -159,6 +162,14 @@ func (ep *EtcdServerProcess) Start(ctx context.Context) error {
 		case err := <-ep.proxy.Error():
 			return err
 		}
+
+		ep.cfg.lg.Info("starting forward proxy...", zap.String("name", ep.cfg.Name), zap.String("from", ep.cfg.ForwardProxy.From.String()), zap.String("to", ep.cfg.ForwardProxy.To.String()))
+		ep.forwardProxy = proxy.NewServer(*ep.cfg.ForwardProxy)
+		select {
+		case <-ep.forwardProxy.Ready():
+		case err := <-ep.forwardProxy.Error():
+			return err
+		}
 	}
 	if ep.lazyfs != nil {
 		ep.cfg.lg.Info("starting lazyfs...", zap.String("name", ep.cfg.Name))
@@ -222,6 +233,13 @@ func (ep *EtcdServerProcess) Stop() (err error) {
 	}
 	ep.cfg.lg.Info("stopped server.", zap.String("name", ep.cfg.Name))
 	if ep.proxy != nil {
+		ep.cfg.lg.Info("stopping forward proxy...", zap.String("name", ep.cfg.Name))
+		err = ep.forwardProxy.Close()
+		ep.forwardProxy = nil
+		if err != nil {
+			return err
+		}
+
 		ep.cfg.lg.Info("stopping proxy...", zap.String("name", ep.cfg.Name))
 		err = ep.proxy.Close()
 		ep.proxy = nil
@@ -330,6 +348,10 @@ func (ep *EtcdServerProcess) PeerProxy() proxy.Server {
 	return ep.proxy
 }
 
+func (ep *EtcdServerProcess) PeerForwardProxy() proxy.Server {
+	return ep.forwardProxy
+}
+
 func (ep *EtcdServerProcess) LazyFS() *LazyFS {
 	return ep.lazyfs
 }
diff --git a/tests/robustness/failpoint/network.go b/tests/robustness/failpoint/network.go
index 5d59fba3d99c..541a6db17928 100644
--- a/tests/robustness/failpoint/network.go
+++ b/tests/robustness/failpoint/network.go
@@ -63,18 +63,21 @@ func (tb triggerBlackhole) Available(config e2e.EtcdProcessClusterConfig, proces
 
 func Blackhole(ctx context.Context, t *testing.T, member e2e.EtcdProcess, clus *e2e.EtcdProcessCluster, shouldWaitTillSnapshot bool) error {
 	proxy := member.PeerProxy()
+	forwardProxy := member.PeerForwardProxy()
 
-	// Blackholing will cause peers to not be able to use streamWriters registered with member
-	// but peer traffic is still possible because member has 'pipeline' with peers
-	// TODO: find a way to stop all traffic
 	t.Logf("Blackholing traffic from and to member %q", member.Config().Name)
 	proxy.BlackholeTx()
 	proxy.BlackholeRx()
+	forwardProxy.BlackholeTx()
+	forwardProxy.BlackholeRx()
 	defer func() {
 		t.Logf("Traffic restored from and to member %q", member.Config().Name)
 		proxy.UnblackholeTx()
 		proxy.UnblackholeRx()
+		forwardProxy.UnblackholeTx()
+		forwardProxy.UnblackholeRx()
 	}()
+
 	if shouldWaitTillSnapshot {
 		return waitTillSnapshot(ctx, t, clus, member)
 	}
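---

Illustrative sketch (not part of the patch): the changes above rely on the standard HTTP CONNECT tunneling flow — the client transport resolves FORWARD_PROXY, issues a CONNECT to the forward proxy, and the proxy (via parseHeaderForDestination) replies 200 and then blindly relays bytes. The snippet below mirrors the patched Proxy func from transport.go; the port numbers and the /health endpoint are hypothetical, and in the e2e harness FORWARD_PROXY is set per member by EtcdServerProcessConfig rather than by hand.

```go
package main

import (
	"crypto/tls"
	"fmt"
	"net/http"
	"net/url"
	"os"
)

func main() {
	// http.ProxyFromEnvironment silently ignores a localhost proxy URL
	// (e.g. HTTPS_PROXY=http://127.0.0.1:20005), which is exactly the
	// limitation the patched Proxy func works around.
	os.Setenv("FORWARD_PROXY", "http://127.0.0.1:20005") // hypothetical forward-proxy port

	tr := &http.Transport{
		Proxy: func(req *http.Request) (*url.URL, error) {
			if fp, ok := os.LookupEnv("FORWARD_PROXY"); ok {
				return url.Parse(fp)
			}
			return http.ProxyFromEnvironment(req)
		},
		// Test-only: e2e clusters use self-signed certificate fixtures.
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}

	// For an https:// target the transport first sends
	//     CONNECT localhost:20001 HTTP/1.1
	// to the forward proxy; parseHeaderForDestination answers "200 OK"
	// and then relays raw bytes between the two connections, so
	// blackholing the forward proxy really does drop all peer traffic.
	resp, err := (&http.Client{Transport: tr}).Get("https://localhost:20001/health") // hypothetical peer port
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
```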