Skip to content

Commit

Permalink
feat: Add WithForwardResponseRewriter to allow easier/more useful res…
Browse files Browse the repository at this point in the history
…ponse control (#4622)

## Context/Background
I am working with grpc-gateway to mimick an older REST API while
implementing the core as a gRPC server. This REST API, by convention,
emits a response envelope.

After extensively researching grpc-gateway's code and options, I think
that there is not a good enough way to address my use-case; for a few
more nuanced reasons.

The sum of my requirements are as follows:
- 1) Allow a particular field of a response message to be used as the main
response. ✅ This is handled with `response_body` annotation.
- 2) Be able to run interface checks on the response to extract useful
information, like `next_page_token` [0] and surface it in the final
response envelope. (`npt, ok := resp.(interface { GetNextPageToken() string })`).
- 3) Take the true result and place it into a response envelope along with
other parts of the response by convention and let that be encoded and
sent as the response instead.

 ### Implementing a response envelope with `Marshaler`
My first attempt at getting my gRPC server's responses in an envelope
led me to implement my own Marshaler, I have seen this approach
discussed in #4483.

This does satisfy requirements 1 and 3 just fine, since the HTTP
annotations helpfully allow the code to only receive the true result,
and the Marshal interface has enough capabilities to take that and wrap
it in a response envelope.

However, requirements 1 and 2 are not _both_ satisfiable with the
current grpc-gateway code because of how the `XXX_ResponseBody()` is
called _before_ passing to the `Marshal(v)` function. This strips out
the other fields that I would normally be able to detect and place in
the response envelope.

I even tried creating my _own_ protobuf message extension that would let
me define another way of defining the "true result" field. But the
options for implementing that are either a _ton_ of protoreflect at
runtime to detect and extract that, or I am writing another protobuf
generator plugin (which I have done before [1]), but both of those
options seem quite complex.

 ### Other non-starter options
Just to get ahead of the discussion, `WithForwardResponseOption` clearly
was not meant for this use-case. At best, it seems to only be a way to
take information that might be in the response and add it as a header.

[0]: https://google.aip.dev/158#:~:text=Response%20messages%20for%20collections%20should%20define%20a%20string%20next_page_token%20field
[1]: https://github.com/nkcmr/protoc-gen-twirp_js

 ### In practice
This change fulfills my requirements by allowing logic to be inserted
right before the Marshal is called:

```go
gatewayMux := runtime.NewServeMux(
  runtime.WithForwardResponseRewriter(func(ctx context.Context, response proto.Message) (interface{}, error) {
    if s, ok := response.(*statuspb.Status); ok {
      return rewriteStatusToErrorEnvelope(ctx, s)
    }
    return rewriteResultToEnvelope(ctx, response)
  }),
)
```

 ## In this PR
This PR introduces a new `ServeMuxOption` called
`WithForwardResponseRewriter` that allows for a user-provided function
to be supplied that can take a response `proto.Message` and return `any`
during unary response forwarding, stream response forwarding, and error
response forwarding.

The code generation was also updated to make the `XXX_ResponseBody()`
response wrappers embed the concrete type instead of just
`proto.Message`. This allows any code in response rewriter functions to
be able to have access to the original type, so that interface checks
against it should pass as if it was the original message.

Updated the "Customizing Your Gateway" documentation to use
`WithForwardResponseRewriter` in the `Fully Overriding Custom HTTP
Responses` sections.

 ## Testing
Added some basic unit tests to ensure Unary/Stream and error handlers
invoke `ForwardResponseRewriter` correctly.
  • Loading branch information
nkcmr authored Aug 16, 2024
1 parent 169370b commit a1b0988
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 74 deletions.
31 changes: 14 additions & 17 deletions docs/docs/mapping/customizing_your_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,15 @@ First, set up the gRPC-Gateway with the custom options:

```go
mux := runtime.NewServeMux(
runtime.WithMarshalerOption(runtime.MIMEWildcard, &ResponseWrapper{}),
runtime.WithForwardResponseOption(forwardResponse),
runtime.WithForwardResponseOption(setStatus),
runtime.WithForwardResponseRewriter(responseEnvelope),
)
```

Define the `forwardResponse` function to handle specific response types:
Define the `setStatus` function to handle specific response types:

```go
func forwardResponse(ctx context.Context, w http.ResponseWriter, m protoreflect.ProtoMessage) error {
func setStatus(ctx context.Context, w http.ResponseWriter, m protoreflect.ProtoMessage) error {
switch v := m.(type) {
case *pb.CreateUserResponse:
w.WriteHeader(http.StatusCreated)
Expand All @@ -342,32 +342,29 @@ func forwardResponse(ctx context.Context, w http.ResponseWriter, m protoreflect.
}
```

Create a custom marshaler to format the response data which utilizes the `JSONPb` marshaler as a fallback:
Define the `responseEnvelope` function to rewrite the response to a different type/shape:

```go
type ResponseWrapper struct {
runtime.JSONPb
}

func (c *ResponseWrapper) Marshal(data any) ([]byte, error) {
resp := data
func responseEnvelope(_ context.Context, response proto.Message) (interface{}, error) {
switch v := data.(type) {
case *pb.CreateUserResponse:
// wrap the response in a custom structure
resp = map[string]any{
return map[string]any{
"success": true,
"data": data,
}
}, nil
}
// otherwise, use the default JSON marshaller
return c.JSONPb.Marshal(resp)
return response, nil
}
```

In this setup:

- The `forwardResponse` function intercepts the response and formats it as needed.
- The `CustomPB` marshaller ensures that specific types of responses are wrapped in a custom structure before being sent to the client.
- The `setStatus` function intercepts the response and uses its type to send `201 Created` only when it sees `*pb.CreateUserResponse`.
- The `responseEnvelope` function ensures that specific types of responses are wrapped in a custom structure before being sent to the client.

**NOTE:** Using `WithForwardResponseRewriter` is partially incompatible with OpenAPI annotations. Because response
rewriting happens at runtime, it is not possible to represent that in `protoc-gen-openapiv2` output.

## Error handler

Expand Down
32 changes: 14 additions & 18 deletions examples/internal/proto/examplepb/response_body_service.pb.gw.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 4 additions & 5 deletions protoc-gen-grpc-gateway/internal/gengateway/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Server(ctx context.Context,
}
{{ if $b.ResponseBody }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})}, mux.GetForwardResponseOptions()...)
{{ else }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
{{end}}
Expand Down Expand Up @@ -744,7 +744,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
{{end}}
{{else}}
{{ if $b.ResponseBody }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})}, mux.GetForwardResponseOptions()...)
{{ else }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
{{end}}
Expand All @@ -759,12 +759,11 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
{{range $b := $m.Bindings}}
{{if $b.ResponseBody}}
type response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} struct {
proto.Message
*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}}
}
func (m response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}) XXX_ResponseBody() interface{} {
response := m.Message.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})
return {{$b.ResponseBody.AssignableExpr "response" $m.Service.File.GoPkg.Path}}
return {{$b.ResponseBody.AssignableExpr "m" $m.Service.File.GoPkg.Path}}
}
{{end}}
{{end}}
Expand Down
16 changes: 13 additions & 3 deletions runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,36 @@ func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.R
func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
// return Internal when Marshal failed
const fallback = `{"code": 13, "message": "failed to marshal error message"}`
const fallbackRewriter = `{"code": 13, "message": "failed to rewrite error message"}`

var customStatus *HTTPStatusError
if errors.As(err, &customStatus) {
err = customStatus.Err
}

s := status.Convert(err)
pb := s.Proto()

w.Header().Del("Trailer")
w.Header().Del("Transfer-Encoding")

contentType := marshaler.ContentType(pb)
respRw, err := mux.forwardResponseRewriter(ctx, s.Proto())
if err != nil {
grpclog.Errorf("Failed to rewrite error message %q: %v", s, err)
w.WriteHeader(http.StatusInternalServerError)
if _, err := io.WriteString(w, fallbackRewriter); err != nil {
grpclog.Errorf("Failed to write response: %v", err)
}
return
}

contentType := marshaler.ContentType(respRw)
w.Header().Set("Content-Type", contentType)

if s.Code() == codes.Unauthenticated {
w.Header().Set("WWW-Authenticate", s.Message())
}

buf, merr := marshaler.Marshal(pb)
buf, merr := marshaler.Marshal(respRw)
if merr != nil {
grpclog.Errorf("Failed to marshal error message %q: %v", s, merr)
w.WriteHeader(http.StatusInternalServerError)
Expand Down
47 changes: 36 additions & 11 deletions runtime/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
statuspb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)

func TestDefaultHTTPError(t *testing.T) {
Expand All @@ -24,12 +25,14 @@ func TestDefaultHTTPError(t *testing.T) {
)

for i, spec := range []struct {
err error
status int
msg string
marshaler runtime.Marshaler
contentType string
details string
err error
status int
msg string
marshaler runtime.Marshaler
contentType string
details string
fordwardRespRewriter runtime.ForwardResponseRewriter
extractMessage func(*testing.T)
}{
{
err: errors.New("example error"),
Expand Down Expand Up @@ -70,23 +73,45 @@ func TestDefaultHTTPError(t *testing.T) {
contentType: "application/json",
msg: "Method Not Allowed",
},
{
err: status.Error(codes.InvalidArgument, "example error"),
status: http.StatusBadRequest,
marshaler: &runtime.JSONPb{},
contentType: "application/json",
msg: "bad request: example error",
fordwardRespRewriter: func(ctx context.Context, response proto.Message) (any, error) {
if s, ok := response.(*statuspb.Status); ok && strings.HasPrefix(s.Message, "example") {
return &statuspb.Status{
Code: s.Code,
Message: "bad request: " + s.Message,
Details: s.Details,
}, nil
}
return response, nil
},
},
} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, "", "", nil) // Pass in an empty request to match the signature
mux := runtime.NewServeMux()
marshaler := &runtime.JSONPb{}
runtime.HTTPError(ctx, mux, marshaler, w, req, spec.err)

if got, want := w.Header().Get("Content-Type"), "application/json"; got != want {
opts := []runtime.ServeMuxOption{}
if spec.fordwardRespRewriter != nil {
opts = append(opts, runtime.WithForwardResponseRewriter(spec.fordwardRespRewriter))
}
mux := runtime.NewServeMux(opts...)

runtime.HTTPError(ctx, mux, spec.marshaler, w, req, spec.err)

if got, want := w.Header().Get("Content-Type"), spec.contentType; got != want {
t.Errorf(`w.Header().Get("Content-Type") = %q; want %q; on spec.err=%v`, got, want, spec.err)
}
if got, want := w.Code, spec.status; got != want {
t.Errorf("w.Code = %d; want %d", got, want)
}

var st statuspb.Status
if err := marshaler.Unmarshal(w.Body.Bytes(), &st); err != nil {
if err := spec.marshaler.Unmarshal(w.Body.Bytes(), &st); err != nil {
t.Errorf("marshaler.Unmarshal(%q, &body) failed with %v; want success", w.Body.Bytes(), err)
return
}
Expand Down
28 changes: 20 additions & 8 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,27 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
return
}

respRw, err := mux.forwardResponseRewriter(ctx, resp)
if err != nil {
grpclog.Errorf("Rewrite error: %v", err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}

if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(resp))
w.Header().Set("Content-Type", marshaler.ContentType(respRw))
}

var buf []byte
httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
httpBody, isHTTPBody := respRw.(*httpbody.HttpBody)
switch {
case resp == nil:
case respRw == nil:
buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
case isHTTPBody:
buf = httpBody.GetData()
default:
result := map[string]interface{}{"result": resp}
if rb, ok := resp.(responseBody); ok {
result := map[string]interface{}{"result": respRw}
if rb, ok := respRw.(responseBody); ok {
result["result"] = rb.XXX_ResponseBody()
}

Expand Down Expand Up @@ -165,12 +172,17 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
respRw, err := mux.forwardResponseRewriter(ctx, resp)
if err != nil {
grpclog.Errorf("Rewrite error: %v", err)
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var buf []byte
var err error
if rb, ok := resp.(responseBody); ok {
if rb, ok := respRw.(responseBody); ok {
buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
} else {
buf, err = marshaler.Marshal(resp)
buf, err = marshaler.Marshal(respRw)
}
if err != nil {
grpclog.Errorf("Marshal error: %v", err)
Expand Down
Loading

0 comments on commit a1b0988

Please sign in to comment.