Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option to override content type for a client #132

Merged
merged 12 commits into from
Oct 11, 2023
42 changes: 42 additions & 0 deletions _integration-tests/echo_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,39 @@ func testWithEchoService(t *testing.T, serverPreferGRPCWeb bool) {
expectClientStreamOK: false,
expectBidiStreamOK: false,
},
{
targetID: "downgrading-grpc",
behindHTTP1ReverseProxy: false,
useProxy: true,
forceDowngrade: true,
customContentType: "application/grpc-web",
expectUnaryOK: true,
expectServerStreamOK: true,
expectClientStreamOK: false,
expectBidiStreamOK: false,
},
{
targetID: "downgrading-grpc",
behindHTTP1ReverseProxy: true,
useProxy: true,
forceDowngrade: true,
customContentType: "application/grpc-web",
expectUnaryOK: true,
expectServerStreamOK: true,
expectClientStreamOK: false,
expectBidiStreamOK: false,
},
{
targetID: "downgrading-grpc",
behindHTTP1ReverseProxy: true,
useProxy: true,
forceDowngrade: false,
customContentType: "application/grpc-web",
expectUnaryOK: true,
expectServerStreamOK: true,
expectClientStreamOK: false,
expectBidiStreamOK: false,
},
}

for _, c := range cases {
Expand Down Expand Up @@ -317,6 +350,7 @@ type testCase struct {
useProxy bool
useWebSocket bool
forceDowngrade bool
customContentType string

expectUnaryOK bool
expectClientStreamOK bool
Expand All @@ -343,6 +377,10 @@ func (c *testCase) Name() string {
} else {
sb.WriteString("-direct")
}

if len(c.customContentType) > 0 {
sb.WriteString("-custom-content-type")
}
return sb.String()
}

Expand Down Expand Up @@ -381,6 +419,10 @@ func (c *testCase) Run(t *testing.T, cfg *testConfig) {
}
opts = append(opts, client.UseWebSocket(c.useWebSocket), client.ForceDowngrade(c.forceDowngrade))

if len(c.customContentType) > 0 {
opts = append(opts, client.WithContentType(c.customContentType))
}

cc, err = client.ConnectViaProxy(ctx, targetAddr, nil, opts...)
} else {
cc, err = grpc.DialContext(ctx, targetAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand Down
13 changes: 13 additions & 0 deletions client/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type connectOptions struct {
forceHTTP2 bool
forceDowngrade bool
useWebSocket bool
contentType string
}

// ConnectOption is an option that can be passed to the `ConnectViaProxy` method.
Expand Down Expand Up @@ -66,6 +67,12 @@ func ForceDowngrade(force bool) ConnectOption {
return forceDowngradeOption(force)
}

// WithContentType returns a connection option that instructs the
// client to use a custom content type for sending requests to the server.
func WithContentType(contentType string) ConnectOption {
return contentTypeOption(contentType)
}

type dialOptsOption []grpc.DialOption

func (o dialOptsOption) apply(opts *connectOptions) {
Expand Down Expand Up @@ -95,3 +102,9 @@ type forceDowngradeOption bool
func (o forceDowngradeOption) apply(opts *connectOptions) {
opts.forceDowngrade = bool(o)
}

type contentTypeOption string

func (o contentTypeOption) apply(opts *connectOptions) {
opts.contentType = string(o)
}
16 changes: 12 additions & 4 deletions client/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func writeError(w http.ResponseWriter, err error) {
w.Header().Set("Grpc-Message", grpcproto.EncodeGrpcMessage(errMsg))
}

func createReverseProxy(endpoint string, transport http.RoundTripper, insecure, forceDowngrade bool) *httputil.ReverseProxy {
func createReverseProxy(endpoint string, transport http.RoundTripper, insecure, forceDowngrade bool, contentType string) *httputil.ReverseProxy {
scheme := "https"
if insecure {
scheme = "http"
Expand All @@ -95,6 +95,14 @@ func createReverseProxy(endpoint string, transport http.RoundTripper, insecure,
req.Header.Add("Accept", "application/grpc")
}
req.Header.Add("Accept", "application/grpc-web")

if len(contentType) > 0 {
// Replacing old content type (e.g., application/grpc), to an overridden content type.
// Without removing old header, some gRPC-Web servers will not work,
// because an HTTP client will send both old and new header values.
req.Header.Set("Content-Type", contentType)
}

req.URL.Scheme = scheme
req.URL.Host = endpoint
},
Expand Down Expand Up @@ -142,12 +150,12 @@ func createTransport(tlsClientConf *tls.Config, forceHTTP2 bool, extraH2ALPNs []
return transport, nil
}

func createClientProxy(endpoint string, tlsClientConf *tls.Config, forceHTTP2, forceDowngrade bool, extraH2ALPNs []string) (*http.Server, pipeconn.DialContextFunc, error) {
func createClientProxy(endpoint string, tlsClientConf *tls.Config, forceHTTP2, forceDowngrade bool, extraH2ALPNs []string, contentType string) (*http.Server, pipeconn.DialContextFunc, error) {
transport, err := createTransport(tlsClientConf, forceHTTP2, extraH2ALPNs)
if err != nil {
return nil, nil, errors.Wrap(err, "creating transport")
}
proxy := createReverseProxy(endpoint, transport, tlsClientConf == nil, forceDowngrade)
proxy := createReverseProxy(endpoint, transport, tlsClientConf == nil, forceDowngrade, contentType)
return makeProxyServer(proxy)
}

Expand All @@ -171,7 +179,7 @@ func ConnectViaProxy(ctx context.Context, endpoint string, tlsClientConf *tls.Co
if connectOpts.useWebSocket {
proxy, dialCtx, err = createClientWSProxy(endpoint, tlsClientConf)
} else {
proxy, dialCtx, err = createClientProxy(endpoint, tlsClientConf, connectOpts.forceHTTP2, connectOpts.forceDowngrade, connectOpts.extraH2ALPNs)
proxy, dialCtx, err = createClientProxy(endpoint, tlsClientConf, connectOpts.forceHTTP2, connectOpts.forceDowngrade, connectOpts.extraH2ALPNs, connectOpts.contentType)
}

if err != nil {
Expand Down
12 changes: 11 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,26 @@ func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler, op
return
}

if contentType, _ := stringutils.Split2(req.Header.Get("Content-Type"), "+"); contentType != "application/grpc" {
if !isContentTypeValid(req.Header.Get("Content-Type")) {
// Non-gRPC request to the same port.
httpHandler.ServeHTTP(w, req)
return
}

// explicitly set application/grpc content type,
// because underlying grpc library supports only it.
// https://github.com/grpc/grpc-go/blob/9deee9ba5f5b654d38c737c701181dceebb57e44/internal/grpcutil/method.go#L61
vikin91 marked this conversation as resolved.
Show resolved Hide resolved
req.Header.Set("Content-Type", "application/grpc")

handleGRPCWeb(w, req, validGRPCWebPaths, grpcSrv, &serverOpts)
})
}

func isContentTypeValid(contentType string) bool {
ct, _ := stringutils.Split2(contentType, "+")
return ct == "application/grpc" || ct == "application/grpc-web"
}

func isWebSocketUpgrade(header http.Header) (bool, error) {
if header.Get("Sec-Websocket-Protocol") != grpcwebsocket.SubprotocolName {
return false, nil
Expand Down
Loading