diff --git a/converter/codec.go b/converter/codec.go index b1d402c8c..35152306e 100644 --- a/converter/codec.go +++ b/converter/codec.go @@ -347,27 +347,100 @@ func NewPayloadCodecHTTPHandler(e ...PayloadCodec) http.Handler { return &codecHTTPHandler{codecs: e} } -// RemoteDataConverterOptions are options for NewRemoteDataConverter. +// RemotePayloadCodecOptions are options for RemotePayloadCodec. // Client is optional. +type RemotePayloadCodecOptions struct { + Endpoint string + ModifyRequest func(*http.Request) error + Client http.Client +} + +type remotePayloadCodec struct { + options RemotePayloadCodecOptions +} + +// NewRemotePayloadCodec creates a PayloadCodec using the remote endpoint configured by RemotePayloadCodecOptions. +func NewRemotePayloadCodec(options RemotePayloadCodecOptions) PayloadCodec { + return &remotePayloadCodec{options} +} + +// Encode uses the remote payload codec endpoint to encode payloads. +func (pc *remotePayloadCodec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) { + return pc.encodeOrDecode(pc.options.Endpoint+remotePayloadCodecEncodePath, payloads) +} + +// Decode uses the remote payload codec endpoint to decode payloads. +func (pc *remotePayloadCodec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) { + return pc.encodeOrDecode(pc.options.Endpoint+remotePayloadCodecDecodePath, payloads) +} + +func (pc *remotePayloadCodec) encodeOrDecode(endpoint string, payloads []*commonpb.Payload) ([]*commonpb.Payload, error) { + requestPayloads, err := json.Marshal(commonpb.Payloads{Payloads: payloads}) + if err != nil { + return payloads, fmt.Errorf("unable to marshal payloads: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(requestPayloads)) + if err != nil { + return payloads, fmt.Errorf("unable to build request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + if pc.options.ModifyRequest != nil { + err = pc.options.ModifyRequest(req) + if err != nil { + return payloads, err + } + } + + response, err := pc.options.Client.Do(req) + if err != nil { + return payloads, err + } + defer func() { _ = response.Body.Close() }() + + if response.StatusCode == 200 { + bs, err := io.ReadAll(response.Body) + if err != nil { + return payloads, fmt.Errorf("failed to read response body: %w", err) + } + var resultPayloads commonpb.Payloads + err = protojson.Unmarshal(bs, &resultPayloads) + if err != nil { + return payloads, fmt.Errorf("unable to unmarshal payloads: %w", err) + } + if len(payloads) != len(resultPayloads.Payloads) { + return payloads, fmt.Errorf("received %d payloads from remote codec, expected %d", len(resultPayloads.Payloads), len(payloads)) + } + return resultPayloads.Payloads, nil + } + + message, _ := io.ReadAll(response.Body) + return payloads, fmt.Errorf("%s: %s", http.StatusText(response.StatusCode), message) +} + +// Fields Endpoint, ModifyRequest, Client of RemotePayloadCodecOptions are also +// exposed here in RemoteDataConverterOptions for backwards compatibility. + +// RemoteDataConverterOptions are options for NewRemoteDataConverter. type RemoteDataConverterOptions struct { Endpoint string ModifyRequest func(*http.Request) error Client http.Client } -// remoteDataConverter is a DataConverter that wraps an underlying data -// converter and uses a remote codec to handle encoding/decoding. type remoteDataConverter struct { - parent DataConverter - options RemoteDataConverterOptions + parent DataConverter + payloadCodec PayloadCodec } // NewRemoteDataConverter wraps the given parent DataConverter and performs // encoding/decoding on the payload via the remote endpoint. func NewRemoteDataConverter(parent DataConverter, options RemoteDataConverterOptions) DataConverter { options.Endpoint = strings.TrimSuffix(options.Endpoint, "/") - - return &remoteDataConverter{parent, options} + payloadCodec := NewRemotePayloadCodec(RemotePayloadCodecOptions(options)) + return &remoteDataConverter{parent, payloadCodec} } // ToPayload implements DataConverter.ToPayload performing remote encoding on the @@ -377,7 +450,7 @@ func (rdc *remoteDataConverter) ToPayload(value interface{}) (*commonpb.Payload, if payload == nil || err != nil { return payload, err } - encodedPayloads, err := rdc.encodePayloads([]*commonpb.Payload{payload}) + encodedPayloads, err := rdc.payloadCodec.Encode([]*commonpb.Payload{payload}) if err != nil { return payload, err } @@ -391,14 +464,14 @@ func (rdc *remoteDataConverter) ToPayloads(value ...interface{}) (*commonpb.Payl if payloads == nil || err != nil { return payloads, err } - encodedPayloads, err := rdc.encodePayloads(payloads.Payloads) + encodedPayloads, err := rdc.payloadCodec.Encode(payloads.Payloads) return &commonpb.Payloads{Payloads: encodedPayloads}, err } // FromPayload implements DataConverter.FromPayload performing remote decoding on the // given payload before sending to the parent FromPayload. func (rdc *remoteDataConverter) FromPayload(payload *commonpb.Payload, valuePtr interface{}) error { - decodedPayloads, err := rdc.decodePayloads([]*commonpb.Payload{payload}) + decodedPayloads, err := rdc.payloadCodec.Decode([]*commonpb.Payload{payload}) if err != nil { return err } @@ -412,7 +485,7 @@ func (rdc *remoteDataConverter) FromPayloads(payloads *commonpb.Payloads, valueP return rdc.parent.FromPayloads(payloads, valuePtrs...) } - decodedPayloads, err := rdc.decodePayloads(payloads.Payloads) + decodedPayloads, err := rdc.payloadCodec.Decode(payloads.Payloads) if err != nil { return err } @@ -426,7 +499,7 @@ func (rdc *remoteDataConverter) ToString(payload *commonpb.Payload) string { return rdc.parent.ToString(payload) } - decodedPayloads, err := rdc.decodePayloads([]*commonpb.Payload{payload}) + decodedPayloads, err := rdc.payloadCodec.Decode([]*commonpb.Payload{payload}) if err != nil { return err.Error() } @@ -446,57 +519,3 @@ func (rdc *remoteDataConverter) ToStrings(payloads *commonpb.Payloads) []string } return strs } - -func (rdc *remoteDataConverter) encodePayloads(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) { - return rdc.encodeOrDecodePayloads(rdc.options.Endpoint+remotePayloadCodecEncodePath, payloads) -} - -func (rdc *remoteDataConverter) decodePayloads(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) { - return rdc.encodeOrDecodePayloads(rdc.options.Endpoint+remotePayloadCodecDecodePath, payloads) -} - -func (rdc *remoteDataConverter) encodeOrDecodePayloads(endpoint string, payloads []*commonpb.Payload) ([]*commonpb.Payload, error) { - requestPayloads, err := json.Marshal(commonpb.Payloads{Payloads: payloads}) - if err != nil { - return payloads, fmt.Errorf("unable to marshal payloads: %w", err) - } - - req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(requestPayloads)) - if err != nil { - return payloads, fmt.Errorf("unable to build request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - - if rdc.options.ModifyRequest != nil { - err = rdc.options.ModifyRequest(req) - if err != nil { - return payloads, err - } - } - - response, err := rdc.options.Client.Do(req) - if err != nil { - return payloads, err - } - defer func() { _ = response.Body.Close() }() - - if response.StatusCode == 200 { - bs, err := io.ReadAll(response.Body) - if err != nil { - return payloads, fmt.Errorf("failed to read response body: %w", err) - } - var resultPayloads commonpb.Payloads - err = protojson.Unmarshal(bs, &resultPayloads) - if err != nil { - return payloads, fmt.Errorf("unable to unmarshal payloads: %w", err) - } - if len(payloads) != len(resultPayloads.Payloads) { - return payloads, fmt.Errorf("received %d payloads from remote codec, expected %d", len(resultPayloads.Payloads), len(payloads)) - } - return resultPayloads.Payloads, nil - } - - message, _ := io.ReadAll(response.Body) - return payloads, fmt.Errorf("%s: %s", http.StatusText(response.StatusCode), message) -}