diff --git a/_integration-tests/echo_service_test.go b/_integration-tests/echo_service_test.go index f7ab463..72cb914 100644 --- a/_integration-tests/echo_service_test.go +++ b/_integration-tests/echo_service_test.go @@ -192,6 +192,50 @@ func testWithEchoService(t *testing.T, serverPreferGRPCWeb bool) { expectClientStreamOK: false, expectBidiStreamOK: false, }, + { + targetID: "downgrading-grpc", + behindHTTP1ReverseProxy: false, + useProxy: true, + forceDowngrade: true, + customContentType: "application/grpc-web", + expectUnaryOK: true, + expectServerStreamOK: true, + expectClientStreamOK: false, + expectBidiStreamOK: false, + }, + { + targetID: "downgrading-grpc", + behindHTTP1ReverseProxy: true, + useProxy: true, + forceDowngrade: true, + customContentType: "application/grpc-web", + expectUnaryOK: true, + expectServerStreamOK: true, + expectClientStreamOK: false, + expectBidiStreamOK: false, + }, + { + targetID: "downgrading-grpc", + behindHTTP1ReverseProxy: true, + useProxy: true, + forceDowngrade: false, + customContentType: "application/grpc-web", + expectUnaryOK: true, + expectServerStreamOK: true, + expectClientStreamOK: false, + expectBidiStreamOK: false, + }, + { + targetID: "downgrading-grpc", + behindHTTP1ReverseProxy: true, + useProxy: true, + forceDowngrade: true, + customContentType: "dummy", + expectUnaryOK: false, + expectServerStreamOK: false, + expectClientStreamOK: false, + expectBidiStreamOK: false, + }, } for _, c := range cases { @@ -317,6 +361,7 @@ type testCase struct { useProxy bool useWebSocket bool forceDowngrade bool + customContentType string expectUnaryOK bool expectClientStreamOK bool @@ -343,6 +388,10 @@ func (c *testCase) Name() string { } else { sb.WriteString("-direct") } + + if len(c.customContentType) > 0 { + sb.WriteString("-custom-content-type") + } return sb.String() } @@ -381,6 +430,10 @@ func (c *testCase) Run(t *testing.T, cfg *testConfig) { } opts = append(opts, client.UseWebSocket(c.useWebSocket), client.ForceDowngrade(c.forceDowngrade)) + if len(c.customContentType) > 0 { + opts = append(opts, client.WithContentType(c.customContentType)) + } + cc, err = client.ConnectViaProxy(ctx, targetAddr, nil, opts...) } else { cc, err = grpc.DialContext(ctx, targetAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) diff --git a/client/options.go b/client/options.go index 449fbd2..fdfe6c8 100644 --- a/client/options.go +++ b/client/options.go @@ -22,6 +22,7 @@ type connectOptions struct { forceHTTP2 bool forceDowngrade bool useWebSocket bool + contentType string } // ConnectOption is an option that can be passed to the `ConnectViaProxy` method. @@ -66,6 +67,12 @@ func ForceDowngrade(force bool) ConnectOption { return forceDowngradeOption(force) } +// WithContentType returns a connection option that instructs the +// client to use a custom content type for sending requests to the server. +func WithContentType(contentType string) ConnectOption { + return contentTypeOption(contentType) +} + type dialOptsOption []grpc.DialOption func (o dialOptsOption) apply(opts *connectOptions) { @@ -95,3 +102,9 @@ type forceDowngradeOption bool func (o forceDowngradeOption) apply(opts *connectOptions) { opts.forceDowngrade = bool(o) } + +type contentTypeOption string + +func (o contentTypeOption) apply(opts *connectOptions) { + opts.contentType = string(o) +} diff --git a/client/proxy.go b/client/proxy.go index 1ba7489..1b9270e 100644 --- a/client/proxy.go +++ b/client/proxy.go @@ -79,7 +79,7 @@ func writeError(w http.ResponseWriter, err error) { w.Header().Set("Grpc-Message", grpcproto.EncodeGrpcMessage(errMsg)) } -func createReverseProxy(endpoint string, transport http.RoundTripper, insecure, forceDowngrade bool) *httputil.ReverseProxy { +func createReverseProxy(endpoint string, transport http.RoundTripper, insecure, forceDowngrade bool, contentType string) *httputil.ReverseProxy { scheme := "https" if insecure { scheme = "http" @@ -95,6 +95,14 @@ func createReverseProxy(endpoint string, transport http.RoundTripper, insecure, req.Header.Add("Accept", "application/grpc") } req.Header.Add("Accept", "application/grpc-web") + + if len(contentType) > 0 { + // Replacing old content type (e.g., application/grpc), to an overridden content type. + // Without removing old header, some gRPC-Web servers will not work, + // because an HTTP client will send both old and new header values. + req.Header.Set("Content-Type", contentType) + } + req.URL.Scheme = scheme req.URL.Host = endpoint }, @@ -142,12 +150,12 @@ func createTransport(tlsClientConf *tls.Config, forceHTTP2 bool, extraH2ALPNs [] return transport, nil } -func createClientProxy(endpoint string, tlsClientConf *tls.Config, forceHTTP2, forceDowngrade bool, extraH2ALPNs []string) (*http.Server, pipeconn.DialContextFunc, error) { +func createClientProxy(endpoint string, tlsClientConf *tls.Config, forceHTTP2, forceDowngrade bool, extraH2ALPNs []string, contentType string) (*http.Server, pipeconn.DialContextFunc, error) { transport, err := createTransport(tlsClientConf, forceHTTP2, extraH2ALPNs) if err != nil { return nil, nil, errors.Wrap(err, "creating transport") } - proxy := createReverseProxy(endpoint, transport, tlsClientConf == nil, forceDowngrade) + proxy := createReverseProxy(endpoint, transport, tlsClientConf == nil, forceDowngrade, contentType) return makeProxyServer(proxy) } @@ -171,7 +179,7 @@ func ConnectViaProxy(ctx context.Context, endpoint string, tlsClientConf *tls.Co if connectOpts.useWebSocket { proxy, dialCtx, err = createClientWSProxy(endpoint, tlsClientConf) } else { - proxy, dialCtx, err = createClientProxy(endpoint, tlsClientConf, connectOpts.forceHTTP2, connectOpts.forceDowngrade, connectOpts.extraH2ALPNs) + proxy, dialCtx, err = createClientProxy(endpoint, tlsClientConf, connectOpts.forceHTTP2, connectOpts.forceDowngrade, connectOpts.extraH2ALPNs, connectOpts.contentType) } if err != nil { diff --git a/server/server.go b/server/server.go index 140aadb..b9dd356 100644 --- a/server/server.go +++ b/server/server.go @@ -182,16 +182,25 @@ func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler, op return } - if contentType, _ := stringutils.Split2(req.Header.Get("Content-Type"), "+"); contentType != "application/grpc" { + if !isContentTypeValid(req.Header.Get("Content-Type")) { // Non-gRPC request to the same port. httpHandler.ServeHTTP(w, req) return } + // Internally content type must be application/grpc, + // See: https://github.com/grpc/grpc-go/blob/9deee9b/internal/grpcutil/method.go#L61 + req.Header.Set("Content-Type", "application/grpc") + handleGRPCWeb(w, req, validGRPCWebPaths, grpcSrv, &serverOpts) }) } +func isContentTypeValid(contentType string) bool { + ct, _ := stringutils.Split2(contentType, "+") + return ct == "application/grpc" || ct == "application/grpc-web" +} + func isWebSocketUpgrade(header http.Header) (bool, error) { if header.Get("Sec-Websocket-Protocol") != grpcwebsocket.SubprotocolName { return false, nil