Skip to content

Commit

Permalink
Make remoteDataConverter implement PayloadCodec (#1303)
Browse files Browse the repository at this point in the history
* Add NewRemotePayloadCodec and RemotePayloadCodecOptions API

* Rearrange code
  • Loading branch information
dandavison authored Dec 4, 2023
1 parent eb05747 commit 987379d
Showing 1 changed file with 85 additions and 66 deletions.
151 changes: 85 additions & 66 deletions converter/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,27 +347,100 @@ func NewPayloadCodecHTTPHandler(e ...PayloadCodec) http.Handler {
return &codecHTTPHandler{codecs: e}
}

// RemoteDataConverterOptions are options for NewRemoteDataConverter.
// RemotePayloadCodecOptions are options for RemotePayloadCodec.
// Client is optional.
type RemotePayloadCodecOptions struct {
Endpoint string
ModifyRequest func(*http.Request) error
Client http.Client
}

type remotePayloadCodec struct {
options RemotePayloadCodecOptions
}

// NewRemotePayloadCodec creates a PayloadCodec using the remote endpoint configured by RemotePayloadCodecOptions.
func NewRemotePayloadCodec(options RemotePayloadCodecOptions) PayloadCodec {
return &remotePayloadCodec{options}
}

// Encode uses the remote payload codec endpoint to encode payloads.
func (pc *remotePayloadCodec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
return pc.encodeOrDecode(pc.options.Endpoint+remotePayloadCodecEncodePath, payloads)
}

// Decode uses the remote payload codec endpoint to decode payloads.
func (pc *remotePayloadCodec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
return pc.encodeOrDecode(pc.options.Endpoint+remotePayloadCodecDecodePath, payloads)
}

func (pc *remotePayloadCodec) encodeOrDecode(endpoint string, payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
requestPayloads, err := json.Marshal(commonpb.Payloads{Payloads: payloads})
if err != nil {
return payloads, fmt.Errorf("unable to marshal payloads: %w", err)
}

req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(requestPayloads))
if err != nil {
return payloads, fmt.Errorf("unable to build request: %w", err)
}

req.Header.Set("Content-Type", "application/json")

if pc.options.ModifyRequest != nil {
err = pc.options.ModifyRequest(req)
if err != nil {
return payloads, err
}
}

response, err := pc.options.Client.Do(req)
if err != nil {
return payloads, err
}
defer func() { _ = response.Body.Close() }()

if response.StatusCode == 200 {
bs, err := io.ReadAll(response.Body)
if err != nil {
return payloads, fmt.Errorf("failed to read response body: %w", err)
}
var resultPayloads commonpb.Payloads
err = protojson.Unmarshal(bs, &resultPayloads)
if err != nil {
return payloads, fmt.Errorf("unable to unmarshal payloads: %w", err)
}
if len(payloads) != len(resultPayloads.Payloads) {
return payloads, fmt.Errorf("received %d payloads from remote codec, expected %d", len(resultPayloads.Payloads), len(payloads))
}
return resultPayloads.Payloads, nil
}

message, _ := io.ReadAll(response.Body)
return payloads, fmt.Errorf("%s: %s", http.StatusText(response.StatusCode), message)
}

// Fields Endpoint, ModifyRequest, Client of RemotePayloadCodecOptions are also
// exposed here in RemoteDataConverterOptions for backwards compatibility.

// RemoteDataConverterOptions are options for NewRemoteDataConverter.
type RemoteDataConverterOptions struct {
Endpoint string
ModifyRequest func(*http.Request) error
Client http.Client
}

// remoteDataConverter is a DataConverter that wraps an underlying data
// converter and uses a remote codec to handle encoding/decoding.
type remoteDataConverter struct {
parent DataConverter
options RemoteDataConverterOptions
parent DataConverter
payloadCodec PayloadCodec
}

// NewRemoteDataConverter wraps the given parent DataConverter and performs
// encoding/decoding on the payload via the remote endpoint.
func NewRemoteDataConverter(parent DataConverter, options RemoteDataConverterOptions) DataConverter {
options.Endpoint = strings.TrimSuffix(options.Endpoint, "/")

return &remoteDataConverter{parent, options}
payloadCodec := NewRemotePayloadCodec(RemotePayloadCodecOptions(options))
return &remoteDataConverter{parent, payloadCodec}
}

// ToPayload implements DataConverter.ToPayload performing remote encoding on the
Expand All @@ -377,7 +450,7 @@ func (rdc *remoteDataConverter) ToPayload(value interface{}) (*commonpb.Payload,
if payload == nil || err != nil {
return payload, err
}
encodedPayloads, err := rdc.encodePayloads([]*commonpb.Payload{payload})
encodedPayloads, err := rdc.payloadCodec.Encode([]*commonpb.Payload{payload})
if err != nil {
return payload, err
}
Expand All @@ -391,14 +464,14 @@ func (rdc *remoteDataConverter) ToPayloads(value ...interface{}) (*commonpb.Payl
if payloads == nil || err != nil {
return payloads, err
}
encodedPayloads, err := rdc.encodePayloads(payloads.Payloads)
encodedPayloads, err := rdc.payloadCodec.Encode(payloads.Payloads)
return &commonpb.Payloads{Payloads: encodedPayloads}, err
}

// FromPayload implements DataConverter.FromPayload performing remote decoding on the
// given payload before sending to the parent FromPayload.
func (rdc *remoteDataConverter) FromPayload(payload *commonpb.Payload, valuePtr interface{}) error {
decodedPayloads, err := rdc.decodePayloads([]*commonpb.Payload{payload})
decodedPayloads, err := rdc.payloadCodec.Decode([]*commonpb.Payload{payload})
if err != nil {
return err
}
Expand All @@ -412,7 +485,7 @@ func (rdc *remoteDataConverter) FromPayloads(payloads *commonpb.Payloads, valueP
return rdc.parent.FromPayloads(payloads, valuePtrs...)
}

decodedPayloads, err := rdc.decodePayloads(payloads.Payloads)
decodedPayloads, err := rdc.payloadCodec.Decode(payloads.Payloads)
if err != nil {
return err
}
Expand All @@ -426,7 +499,7 @@ func (rdc *remoteDataConverter) ToString(payload *commonpb.Payload) string {
return rdc.parent.ToString(payload)
}

decodedPayloads, err := rdc.decodePayloads([]*commonpb.Payload{payload})
decodedPayloads, err := rdc.payloadCodec.Decode([]*commonpb.Payload{payload})
if err != nil {
return err.Error()
}
Expand All @@ -446,57 +519,3 @@ func (rdc *remoteDataConverter) ToStrings(payloads *commonpb.Payloads) []string
}
return strs
}

func (rdc *remoteDataConverter) encodePayloads(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
return rdc.encodeOrDecodePayloads(rdc.options.Endpoint+remotePayloadCodecEncodePath, payloads)
}

func (rdc *remoteDataConverter) decodePayloads(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
return rdc.encodeOrDecodePayloads(rdc.options.Endpoint+remotePayloadCodecDecodePath, payloads)
}

func (rdc *remoteDataConverter) encodeOrDecodePayloads(endpoint string, payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
requestPayloads, err := json.Marshal(commonpb.Payloads{Payloads: payloads})
if err != nil {
return payloads, fmt.Errorf("unable to marshal payloads: %w", err)
}

req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(requestPayloads))
if err != nil {
return payloads, fmt.Errorf("unable to build request: %w", err)
}

req.Header.Set("Content-Type", "application/json")

if rdc.options.ModifyRequest != nil {
err = rdc.options.ModifyRequest(req)
if err != nil {
return payloads, err
}
}

response, err := rdc.options.Client.Do(req)
if err != nil {
return payloads, err
}
defer func() { _ = response.Body.Close() }()

if response.StatusCode == 200 {
bs, err := io.ReadAll(response.Body)
if err != nil {
return payloads, fmt.Errorf("failed to read response body: %w", err)
}
var resultPayloads commonpb.Payloads
err = protojson.Unmarshal(bs, &resultPayloads)
if err != nil {
return payloads, fmt.Errorf("unable to unmarshal payloads: %w", err)
}
if len(payloads) != len(resultPayloads.Payloads) {
return payloads, fmt.Errorf("received %d payloads from remote codec, expected %d", len(resultPayloads.Payloads), len(payloads))
}
return resultPayloads.Payloads, nil
}

message, _ := io.ReadAll(response.Body)
return payloads, fmt.Errorf("%s: %s", http.StatusText(response.StatusCode), message)
}

0 comments on commit 987379d

Please sign in to comment.