Skip to content

Commit

Permalink
apply patch from fuweid
Browse files Browse the repository at this point in the history
Signed-off-by: Benjamin Wang <[email protected]>
  • Loading branch information
ahrtr committed Aug 23, 2023
1 parent fa21630 commit a9f1490
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 10 deletions.
292 changes: 292 additions & 0 deletions pkg/legacygwjsonpb/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
package legacygwjsonpb

import (
"bytes"
"encoding/json"
"fmt"
"io"
"reflect"

"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
gw "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
protoV2 "google.golang.org/protobuf/proto"
)

// JSONPb is a Marshaler which marshals/unmarshals into/from JSON
// with the "github.com/golang/protobuf/jsonpb".
// It supports fully functionality of protobuf unlike JSONBuiltin.
//
// The NewDecoder method returns a DecoderWrapper, so the underlying
// *json.Decoder methods can be used.
type JSONPb jsonpb.Marshaler

// ContentType always returns "application/json".
func (*JSONPb) ContentType(v interface{}) string {
return "application/json"
}

// Marshal marshals "v" into JSON.
func (j *JSONPb) Marshal(vv interface{}) (ret []byte, retErr error) {
var v interface{} = vv

// For unary api, the gateway always convert the messageV1 into V2.
// We should convert it back. And for the streaming api, the input is
// kind of map, we can't just call the proto.MessageV1 because it
// will panic :)
//
// REF: github.com/grpc-ecosystem/grpc-gateway/[email protected]/runtime/handler.goL75
if _, ok := vv.(protoV2.Message); ok {
v = proto.MessageV1(vv)
}

if _, ok := v.(proto.Message); !ok {
return j.marshalNonProtoField(v)
}

var buf bytes.Buffer
if err := j.marshalTo(&buf, v); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
buf, err := j.marshalNonProtoField(v)
if err != nil {
return err
}
_, err = w.Write(buf)
return err
}
return (*jsonpb.Marshaler)(j).Marshal(w, p)
}

var (
// protoMessageType is stored to prevent constant lookup of the same type at runtime.
protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
)

// marshalNonProto marshals a non-message field of a protobuf message.
// This function does not correctly marshals arbitrary data structure into JSON,
// but it is only capable of marshaling non-message field values of protobuf,
// i.e. primitive types, enums; pointers to primitives or enums; maps from
// integer/string types to primitives/enums/pointers to messages.
func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {
if v == nil {
return []byte("null"), nil
}
rv := reflect.ValueOf(v)
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
return []byte("null"), nil
}
rv = rv.Elem()
}

if rv.Kind() == reflect.Slice {
if rv.IsNil() {
if j.EmitDefaults {
return []byte("[]"), nil
}
return []byte("null"), nil
}

if rv.Type().Elem().Implements(protoMessageType) {
var buf bytes.Buffer
err := buf.WriteByte('[')
if err != nil {
return nil, err
}
for i := 0; i < rv.Len(); i++ {
if i != 0 {
err = buf.WriteByte(',')
if err != nil {
return nil, err
}
}
if err = (*jsonpb.Marshaler)(j).Marshal(&buf, rv.Index(i).Interface().(proto.Message)); err != nil {
return nil, err
}
}
err = buf.WriteByte(']')
if err != nil {
return nil, err
}

return buf.Bytes(), nil
}
}

if rv.Kind() == reflect.Map {
m := make(map[string]*json.RawMessage)
for _, k := range rv.MapKeys() {
buf, err := j.Marshal(rv.MapIndex(k).Interface())
if err != nil {
return nil, err
}
m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf)
}
if j.Indent != "" {
return json.MarshalIndent(m, "", j.Indent)
}
return json.Marshal(m)
}
if enum, ok := rv.Interface().(protoEnum); ok && !j.EnumsAsInts {
return json.Marshal(enum.String())
}
return json.Marshal(rv.Interface())
}

// Unmarshal unmarshals JSON "data" into "v"
func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
return unmarshalJSONPb(data, v)
}

// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONPb) NewDecoder(r io.Reader) gw.Decoder {
d := json.NewDecoder(r)
return DecoderWrapper{Decoder: d}
}

// DecoderWrapper is a wrapper around a *json.Decoder that adds
// support for protos to the Decode method.
type DecoderWrapper struct {
*json.Decoder
}

// Decode wraps the embedded decoder's Decode method to support
// protos using a jsonpb.Unmarshaler.
func (d DecoderWrapper) Decode(v interface{}) error {
return decodeJSONPb(d.Decoder, v)
}

// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONPb) NewEncoder(w io.Writer) gw.Encoder {
return gw.EncoderFunc(func(vv interface{}) error {
v := proto.MessageV1(vv)

if err := j.marshalTo(w, v); err != nil {
return err
}
// mimic json.Encoder by adding a newline (makes output
// easier to read when it contains multiple encoded items)
_, err := w.Write(j.Delimiter())
return err
})
}

func unmarshalJSONPb(data []byte, v interface{}) error {
d := json.NewDecoder(bytes.NewReader(data))
return decodeJSONPb(d, v)
}

func decodeJSONPb(d *json.Decoder, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
return decodeNonProtoField(d, v)
}
unmarshaler := &jsonpb.Unmarshaler{AllowUnknownFields: allowUnknownFields}
return unmarshaler.UnmarshalNext(d, p)
}

func decodeNonProtoField(d *json.Decoder, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("%T is not a pointer", v)
}
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
if rv.Type().ConvertibleTo(typeProtoMessage) {
unmarshaler := &jsonpb.Unmarshaler{AllowUnknownFields: allowUnknownFields}
return unmarshaler.UnmarshalNext(d, rv.Interface().(proto.Message))
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Map {
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
conv, ok := convFromType[rv.Type().Key().Kind()]
if !ok {
return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key())
}

m := make(map[string]*json.RawMessage)
if err := d.Decode(&m); err != nil {
return err
}
for k, v := range m {
result := conv.Call([]reflect.Value{reflect.ValueOf(k)})
if err := result[1].Interface(); err != nil {
return err.(error)
}
bk := result[0]
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(*v), bv.Interface()); err != nil {
return err
}
rv.SetMapIndex(bk, bv.Elem())
}
return nil
}
if _, ok := rv.Interface().(protoEnum); ok {
var repr interface{}
if err := d.Decode(&repr); err != nil {
return err
}
switch repr.(type) {
case string:
// TODO(yugui) Should use proto.StructProperties?
return fmt.Errorf("unmarshaling of symbolic enum %q not supported: %T", repr, rv.Interface())
case float64:
rv.Set(reflect.ValueOf(int32(repr.(float64))).Convert(rv.Type()))
return nil
default:
return fmt.Errorf("cannot assign %#v into Go type %T", repr, rv.Interface())
}
}
return d.Decode(v)
}

type protoEnum interface {
fmt.Stringer
EnumDescriptor() ([]byte, []int)
}

var typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()

// Delimiter for newline encoded JSON streams.
func (j *JSONPb) Delimiter() []byte {
return []byte("\n")
}

// allowUnknownFields helps not to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
var allowUnknownFields = true

// DisallowUnknownFields enables option in decoder (unmarshaller) to
// return an error when it finds an unknown field. This function must be
// called before using the JSON marshaller.
func DisallowUnknownFields() {
allowUnknownFields = false
}

var (
convFromType = map[reflect.Kind]reflect.Value{
reflect.String: reflect.ValueOf(gw.String),
reflect.Bool: reflect.ValueOf(gw.Bool),
reflect.Float64: reflect.ValueOf(gw.Float64),
reflect.Float32: reflect.ValueOf(gw.Float32),
reflect.Int64: reflect.ValueOf(gw.Int64),
reflect.Int32: reflect.ValueOf(gw.Int32),
reflect.Uint64: reflect.ValueOf(gw.Uint64),
reflect.Uint32: reflect.ValueOf(gw.Uint32),
reflect.Slice: reflect.ValueOf(gw.Bytes),
}
)
4 changes: 2 additions & 2 deletions scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ function grpcproxy_pass {

function grpcproxy_integration_pass {
# shellcheck disable=SC2068
run_for_module "tests" go_test "./integration/..." "fail_fast" : -timeout=30m -tags cluster_proxy ${COMMON_TEST_FLAGS[@]:-} "$@"
run_for_module "tests" go_test "./integration/..." "fail_fast" : -timeout=30m -tags cluster_proxy ${COMMON_TEST_FLAGS[@]:-} ${RUN_ARG[@]:-} "$@"
}

function grpcproxy_e2e_pass {
# shellcheck disable=SC2068
run_for_module "tests" go_test "./e2e" "fail_fast" : -timeout=30m -tags cluster_proxy ${COMMON_TEST_FLAGS[@]:-} "$@"
run_for_module "tests" go_test "./e2e" "fail_fast" : -timeout=30m -tags cluster_proxy ${COMMON_TEST_FLAGS[@]:-} ${RUN_ARG[@]:-} "$@"
}

################# COVERAGE #####################################################
Expand Down
8 changes: 7 additions & 1 deletion server/embed/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"go.etcd.io/etcd/client/pkg/v3/transport"
"go.etcd.io/etcd/pkg/v3/debugutil"
"go.etcd.io/etcd/pkg/v3/httputil"
"go.etcd.io/etcd/pkg/v3/legacygwjsonpb"
"go.etcd.io/etcd/server/v3/config"
"go.etcd.io/etcd/server/v3/etcdserver"
"go.etcd.io/etcd/server/v3/etcdserver/api/v3client"
Expand Down Expand Up @@ -301,7 +302,12 @@ func (sctx *serveCtx) registerGateway(dial func(ctx context.Context) (*grpc.Clie
if err != nil {
return nil, err
}
gwmux := gw.NewServeMux()

gwmux := gw.NewServeMux(
gw.WithMarshalerOption(gw.MIMEWildcard, &gw.HTTPBodyMarshaler{
Marshaler: &legacygwjsonpb.JSONPb{OrigName: true},
}),
)

handlers := []registerHandlerFunc{
etcdservergw.RegisterKVHandler,
Expand Down
13 changes: 6 additions & 7 deletions tests/e2e/v3_curl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ import (
pb "go.etcd.io/etcd/api/v3/etcdserverpb"
"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
"go.etcd.io/etcd/client/pkg/v3/testutil"
"go.etcd.io/etcd/pkg/v3/legacygwjsonpb"
epb "go.etcd.io/etcd/server/v3/etcdserver/api/v3election/v3electionpb"
"go.etcd.io/etcd/tests/v3/framework/e2e"

"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
)

var apiPrefix = []string{"/v3"}
Expand Down Expand Up @@ -168,7 +167,7 @@ func testV3CurlTxn(cx ctlCtx) {
},
},
}
m := &runtime.JSONPb{}
m := &legacygwjsonpb.JSONPb{}
jsonDat, jerr := m.Marshal(txn)
if jerr != nil {
cx.t.Fatal(jerr)
Expand All @@ -181,7 +180,7 @@ func testV3CurlTxn(cx ctlCtx) {

// was crashing etcd server
malformed := `{"compare":[{"result":0,"target":1,"key":"Zm9v","TargetUnion":null}],"success":[{"Request":{"RequestPut":{"key":"Zm9v","value":"YmFy"}}}]}`
if err := e2e.CURLPost(cx.epc, e2e.CURLReq{Endpoint: path.Join(p, "/kv/txn"), Value: malformed, Expected: "error"}); err != nil {
if err := e2e.CURLPost(cx.epc, e2e.CURLReq{Endpoint: path.Join(p, "/kv/txn"), Value: malformed, Expected: `"code":3,"message":"etcdserver: key not found"`}); err != nil {
cx.t.Fatalf("failed testV3CurlTxn put with curl using prefix (%s) (%v)", p, err)
}

Expand Down Expand Up @@ -232,7 +231,7 @@ func testV3CurlAuth(cx ctlCtx) {
testutil.AssertNil(cx.t, err)

// fail put no auth
if err = e2e.CURLPost(cx.epc, e2e.CURLReq{Endpoint: path.Join(p, "/kv/put"), Value: string(putreq), Expected: "error"}); err != nil {
if err = e2e.CURLPost(cx.epc, e2e.CURLReq{Endpoint: path.Join(p, "/kv/put"), Value: string(putreq), Expected: `"code":3,"message":"etcdserver: user name is empty"`}); err != nil {
cx.t.Fatalf("failed testV3CurlAuth no auth put with curl using prefix (%s) (%v)", p, err)
}

Expand Down Expand Up @@ -347,7 +346,7 @@ func testV3CurlProclaimMissiongLeaderKey(cx ctlCtx) {
if err = e2e.CURLPost(cx.epc, e2e.CURLReq{
Endpoint: path.Join(cx.apiPrefix, "/election/proclaim"),
Value: string(pdata),
Expected: `{"error":"\"leader\" field must be provided","code":2,"message":"\"leader\" field must be provided"}`,
Expected: `{"code":2,"message":"\"leader\" field must be provided"}`,
}); err != nil {
cx.t.Fatalf("failed post proclaim request (%s) (%v)", cx.apiPrefix, err)
}
Expand All @@ -363,7 +362,7 @@ func testV3CurlResignMissiongLeaderKey(cx ctlCtx) {
if err := e2e.CURLPost(cx.epc, e2e.CURLReq{
Endpoint: path.Join(cx.apiPrefix, "/election/resign"),
Value: `{}`,
Expected: `{"error":"\"leader\" field must be provided","code":2,"message":"\"leader\" field must be provided"}`,
Expected: `{"code":2,"message":"\"leader\" field must be provided"}`,
}); err != nil {
cx.t.Fatalf("failed post resign request (%s) (%v)", cx.apiPrefix, err)
}
Expand Down

0 comments on commit a9f1490

Please sign in to comment.