Skip to content

Commit

Permalink
StreamTranslator and FallbackExecutor for WebSockets
Browse files Browse the repository at this point in the history
Kubernetes-commit: 168998e87bfd49a1b0bc6402761fafd5ace3bb3b
  • Loading branch information
seans3 authored and k8s-publishing-bot committed Jul 7, 2023
1 parent 1e138bd commit 20301d1
Show file tree
Hide file tree
Showing 9 changed files with 401 additions and 29 deletions.
21 changes: 21 additions & 0 deletions pkg/util/httpstream/httpstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package httpstream

import (
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -95,6 +96,26 @@ type Stream interface {
Identifier() uint32
}

// UpgradeFailureError encapsulates the cause for why the streaming
// upgrade request failed. Implements error interface.
type UpgradeFailureError struct {
Cause error
}

func (u *UpgradeFailureError) Error() string {
return fmt.Sprintf("unable to upgrade streaming request: %s", u.Cause)
}

// IsUpgradeFailure returns true if the passed error is (or wrapped error contains)
// the UpgradeFailureError.
func IsUpgradeFailure(err error) bool {
if err == nil {
return false
}
var upgradeErr *UpgradeFailureError
return errors.As(err, &upgradeErr)
}

// IsUpgradeRequest returns true if the given request is a connection upgrade request
func IsUpgradeRequest(req *http.Request) bool {
for _, h := range req.Header[http.CanonicalHeaderKey(HeaderConnection)] {
Expand Down
39 changes: 39 additions & 0 deletions pkg/util/httpstream/httpstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package httpstream

import (
"errors"
"fmt"
"net/http"
"reflect"
"testing"
Expand Down Expand Up @@ -129,3 +131,40 @@ func TestHandshake(t *testing.T) {
}
}
}

func TestIsUpgradeFailureError(t *testing.T) {
testCases := map[string]struct {
err error
expected bool
}{
"nil error should return false": {
err: nil,
expected: false,
},
"Non-upgrade error should return false": {
err: fmt.Errorf("this is not an upgrade error"),
expected: false,
},
"UpgradeFailure error should return true": {
err: &UpgradeFailureError{},
expected: true,
},
"Wrapped Non-UpgradeFailure error should return false": {
err: fmt.Errorf("%s: %w", "first error", errors.New("Non-upgrade error")),
expected: false,
},
"Wrapped UpgradeFailure error should return true": {
err: fmt.Errorf("%s: %w", "first error", &UpgradeFailureError{}),
expected: true,
},
}

for name, test := range testCases {
t.Run(name, func(t *testing.T) {
actual := IsUpgradeFailure(test.err)
if test.expected != actual {
t.Errorf("expected upgrade failure %t, got %t", test.expected, actual)
}
})
}
}
55 changes: 42 additions & 13 deletions pkg/util/httpstream/spdy/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/apimachinery/pkg/util/httpstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
apiproxy "k8s.io/apimachinery/pkg/util/proxy"
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)

Expand Down Expand Up @@ -68,6 +69,10 @@ type SpdyRoundTripper struct {
// pingPeriod is a period for sending Ping frames over established
// connections.
pingPeriod time.Duration

// upgradeTransport is an optional substitute for dialing if present. This field is
// mutually exclusive with the "tlsConfig", "Dialer", and "proxier".
upgradeTransport http.RoundTripper
}

var _ utilnet.TLSClientConfigHolder = &SpdyRoundTripper{}
Expand All @@ -76,43 +81,61 @@ var _ utilnet.Dialer = &SpdyRoundTripper{}

// NewRoundTripper creates a new SpdyRoundTripper that will use the specified
// tlsConfig.
func NewRoundTripper(tlsConfig *tls.Config) *SpdyRoundTripper {
func NewRoundTripper(tlsConfig *tls.Config) (*SpdyRoundTripper, error) {
return NewRoundTripperWithConfig(RoundTripperConfig{
TLS: tlsConfig,
TLS: tlsConfig,
UpgradeTransport: nil,
})
}

// NewRoundTripperWithProxy creates a new SpdyRoundTripper that will use the
// specified tlsConfig and proxy func.
func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) *SpdyRoundTripper {
func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) (*SpdyRoundTripper, error) {
return NewRoundTripperWithConfig(RoundTripperConfig{
TLS: tlsConfig,
Proxier: proxier,
TLS: tlsConfig,
Proxier: proxier,
UpgradeTransport: nil,
})
}

// NewRoundTripperWithConfig creates a new SpdyRoundTripper with the specified
// configuration.
func NewRoundTripperWithConfig(cfg RoundTripperConfig) *SpdyRoundTripper {
// configuration. Returns an error if the SpdyRoundTripper is misconfigured.
func NewRoundTripperWithConfig(cfg RoundTripperConfig) (*SpdyRoundTripper, error) {
// Process UpgradeTransport, which is mutually exclusive to TLSConfig and Proxier.
if cfg.UpgradeTransport != nil {
if cfg.TLS != nil || cfg.Proxier != nil {
return nil, fmt.Errorf("SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier")
}
tlsConfig, err := utilnet.TLSClientConfig(cfg.UpgradeTransport)
if err != nil {
return nil, fmt.Errorf("SpdyRoundTripper: Unable to retrieve TLSConfig from UpgradeTransport: %v", err)
}
cfg.TLS = tlsConfig
}
if cfg.Proxier == nil {
cfg.Proxier = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
}
return &SpdyRoundTripper{
tlsConfig: cfg.TLS,
proxier: cfg.Proxier,
pingPeriod: cfg.PingPeriod,
}
tlsConfig: cfg.TLS,
proxier: cfg.Proxier,
pingPeriod: cfg.PingPeriod,
upgradeTransport: cfg.UpgradeTransport,
}, nil
}

// RoundTripperConfig is a set of options for an SpdyRoundTripper.
type RoundTripperConfig struct {
// TLS configuration used by the round tripper.
// TLS configuration used by the round tripper if UpgradeTransport not present.
TLS *tls.Config
// Proxier is a proxy function invoked on each request. Optional.
Proxier func(*http.Request) (*url.URL, error)
// PingPeriod is a period for sending SPDY Pings on the connection.
// Optional.
PingPeriod time.Duration
// UpgradeTransport is a subtitute transport used for dialing. If set,
// this field will be used instead of "TLS" and "Proxier" for connection creation.
// Optional.
UpgradeTransport http.RoundTripper
}

// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during
Expand All @@ -123,7 +146,13 @@ func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config {

// Dial implements k8s.io/apimachinery/pkg/util/net.Dialer.
func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) {
conn, err := s.dial(req)
var conn net.Conn
var err error
if s.upgradeTransport != nil {
conn, err = apiproxy.DialURL(req.Context(), req.URL, s.upgradeTransport)
} else {
conn, err = s.dial(req)
}
if err != nil {
return nil, err
}
Expand Down
85 changes: 82 additions & 3 deletions pkg/util/httpstream/spdy/roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strconv"
"strings"
"testing"

"github.com/armon/go-socks5"
Expand Down Expand Up @@ -324,7 +326,10 @@ func TestRoundTripAndNewConnection(t *testing.T) {
t.Fatalf("error creating request: %s", err)
}

spdyTransport := NewRoundTripper(testCase.clientTLS)
spdyTransport, err := NewRoundTripper(testCase.clientTLS)
if err != nil {
t.Fatalf("error creating SpdyRoundTripper: %v", err)
}

var proxierCalled bool
var proxyCalledWithHost string
Expand Down Expand Up @@ -428,6 +433,74 @@ func TestRoundTripAndNewConnection(t *testing.T) {
}
}

// Tests SpdyRoundTripper constructors
func TestRoundTripConstuctor(t *testing.T) {
testCases := map[string]struct {
tlsConfig *tls.Config
proxier func(req *http.Request) (*url.URL, error)
upgradeTransport http.RoundTripper
expectedTLSConfig *tls.Config
errMsg string
}{
"Basic TLSConfig; no error": {
tlsConfig: &tls.Config{InsecureSkipVerify: true},
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
upgradeTransport: nil,
},
"Basic TLSConfig and Proxier: no error": {
tlsConfig: &tls.Config{InsecureSkipVerify: true},
proxier: func(req *http.Request) (*url.URL, error) { return nil, nil },
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
upgradeTransport: nil,
},
"TLSConfig with UpgradeTransport: error": {
tlsConfig: &tls.Config{InsecureSkipVerify: true},
upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
errMsg: "SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier",
},
"Proxier with UpgradeTransport: error": {
proxier: func(req *http.Request) (*url.URL, error) { return nil, nil },
upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
errMsg: "SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier",
},
"Only UpgradeTransport: no error": {
upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
expectedTLSConfig: &tls.Config{InsecureSkipVerify: true},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
spdyRoundTripper, err := NewRoundTripperWithConfig(
RoundTripperConfig{
TLS: testCase.tlsConfig,
Proxier: testCase.proxier,
UpgradeTransport: testCase.upgradeTransport,
},
)
if testCase.errMsg != "" {
if err == nil {
t.Fatalf("expected error but received none")
}
if !strings.Contains(err.Error(), testCase.errMsg) {
t.Fatalf("expected error message (%s), got (%s)", err.Error(), testCase.errMsg)
}
}
if testCase.errMsg == "" {
if err != nil {
t.Fatalf("unexpected error received: %v", err)
}
actualTLSConfig := spdyRoundTripper.TLSClientConfig()
if !reflect.DeepEqual(testCase.expectedTLSConfig, actualTLSConfig) {
t.Errorf("expected TLSConfig (%v), got (%v)",
testCase.expectedTLSConfig, actualTLSConfig)
}
}
})
}
}

type Interceptor struct {
Authorization socks5.AuthContext
proxyCalledWithHost *string
Expand Down Expand Up @@ -544,7 +617,10 @@ func TestRoundTripSocks5AndNewConnection(t *testing.T) {
t.Fatalf("error creating request: %s", err)
}

spdyTransport := NewRoundTripper(testCase.clientTLS)
spdyTransport, err := NewRoundTripper(testCase.clientTLS)
if err != nil {
t.Fatalf("error creating SpdyRoundTripper: %v", err)
}
var proxierCalled bool
var proxyCalledWithHost string

Expand Down Expand Up @@ -704,7 +780,10 @@ func TestRoundTripPassesContextToDialer(t *testing.T) {
cancel()
req, err := http.NewRequestWithContext(ctx, "GET", u, nil)
require.NoError(t, err)
spdyTransport := NewRoundTripper(&tls.Config{})
spdyTransport, err := NewRoundTripper(&tls.Config{})
if err != nil {
t.Fatalf("error creating SpdyRoundTripper: %v", err)
}
_, err = spdyTransport.Dial(req)
assert.EqualError(t, err, "dial tcp 127.0.0.1:1233: operation was canceled")
})
Expand Down
Loading

0 comments on commit 20301d1

Please sign in to comment.