Skip to content

Commit

Permalink
feat(codec): Unknown Service Handler (#1321)
Browse files Browse the repository at this point in the history
  • Loading branch information
lokistars committed Aug 19, 2024
1 parent 0824d3c commit e13869d
Show file tree
Hide file tree
Showing 6 changed files with 469 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pkg/remote/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"net"
"time"

"github.com/cloudwego/kitex/pkg/unknownservice/service"

"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/profiler"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
Expand Down Expand Up @@ -113,6 +115,8 @@ type ServerOption struct {

GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error

UnknownServiceHandler service.UnknownServiceHandler

Option

// invoking chain with recv/send middlewares for streaming APIs
Expand Down
85 changes: 85 additions & 0 deletions pkg/unknownservice/service/unknown_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2021 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package service

import (
"context"

"github.com/cloudwego/kitex/pkg/serviceinfo"
)

const (
// UnknownService name
UnknownService = "$UnknownService" // private as "$"
// UnknownMethod name
UnknownMethod = "$UnknownMethod"
)

type Args struct {
Request []byte
Method string
ServiceName string
}

type Result struct {
Success []byte
Method string
ServiceName string
}

type UnknownServiceHandler interface {
UnknownServiceHandler(ctx context.Context, serviceName, method string, request []byte) ([]byte, error)
}

// NewServiceInfo create serviceInfo
func NewServiceInfo(pcType serviceinfo.PayloadCodec, service, method string) *serviceinfo.ServiceInfo {
methods := map[string]serviceinfo.MethodInfo{
method: serviceinfo.NewMethodInfo(callHandler, newServiceArgs, newServiceResult, false),
}
handlerType := (*UnknownServiceHandler)(nil)

svcInfo := &serviceinfo.ServiceInfo{
ServiceName: service,
HandlerType: handlerType,
Methods: methods,
PayloadCodec: pcType,
Extra: make(map[string]interface{}),
}

return svcInfo
}

func callHandler(ctx context.Context, handler, arg, result interface{}) error {
realArg := arg.(*Args)
realResult := result.(*Result)
realResult.Method = realArg.Method
realResult.ServiceName = realArg.ServiceName
success, err := handler.(UnknownServiceHandler).UnknownServiceHandler(ctx, realArg.ServiceName, realArg.Method, realArg.Request)
if err != nil {
return err
}
realResult.Success = success
return nil
}

func newServiceArgs() interface{} {
return &Args{}
}

func newServiceResult() interface{} {
return &Result{}
}
236 changes: 236 additions & 0 deletions pkg/unknownservice/unknown.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
/*
* Copyright 2021 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package unknownservice

import (
"context"
"encoding/binary"
"errors"
"fmt"

"github.com/cloudwego/kitex/pkg/protocol/bthrift"
thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache"

"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service"
)

// UnknownCodec implements PayloadCodec
type unknownCodec struct {
Codec remote.PayloadCodec
}

// NewUnknownServiceCodec creates the unknown binary codec.
func NewUnknownServiceCodec(code remote.PayloadCodec) remote.PayloadCodec {
return &unknownCodec{code}
}

// Marshal implements the remote.PayloadCodec interface.
func (c unknownCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
ink := msg.RPCInfo().Invocation()
data := msg.Data()

res, ok := data.(*unknownservice.Result)
if !ok {
return c.Codec.Marshal(ctx, msg, out)
}
if len(res.Success) == 0 {
return errors.New("unknown messages cannot be empty")
}
if msg.MessageType() == remote.Exception {
return c.Codec.Marshal(ctx, msg, out)
}
if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
ink.SetMethodName(res.Method)
ink.SetServiceName(res.ServiceName)
} else {
return errors.New("the interface Invocation doesn't implement InvocationSetter")
}
if err := encode(res, msg, out); err != nil {
return c.Codec.Marshal(ctx, msg, out)
}
return nil
}

// Unmarshal implements the remote.PayloadCodec interface.
func (c unknownCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
ink := message.RPCInfo().Invocation()
magicAndMsgType, err := codec.PeekUint32(in)
if err != nil {
return err
}
msgType := magicAndMsgType & codec.FrontMask
if msgType == uint32(remote.Exception) {
return c.Codec.Unmarshal(ctx, message, in)
}
if err = codec.UpdateMsgType(msgType, message); err != nil {
return err
}
service, method, err := readDecode(message, in)
if err != nil {
return err
}
err = codec.SetOrCheckMethodName(method, message)
var te *remote.TransError
if errors.As(err, &te) && (te.TypeID() == remote.UnknownMethod || te.TypeID() == remote.UnknownService) {
svcInfo, err := message.SpecifyServiceInfo(unknownservice.UnknownService, unknownservice.UnknownMethod)
if err != nil {
return err
}

if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
ink.SetMethodName(unknownservice.UnknownMethod)
ink.SetPackageName(svcInfo.GetPackageName())
ink.SetServiceName(unknownservice.UnknownService)
} else {
return errors.New("the interface Invocation doesn't implement InvocationSetter")
}
if err = codec.NewDataIfNeeded(unknownservice.UnknownMethod, message); err != nil {
return err
}

data := message.Data()

if data, ok := data.(*unknownservice.Args); ok {
data.Method = method
data.ServiceName = service
buf, err := in.Next(in.ReadableLen())
if err != nil {
return err
}
data.Request = buf
}
return nil
}

return c.Codec.Unmarshal(ctx, message, in)
}

// Name implements the remote.PayloadCodec interface.
func (c unknownCodec) Name() string {
return "unknownMethodCodec"
}

func write(dst, src []byte) {
copy(dst, src)
}

func readDecode(message remote.Message, in remote.ByteBuffer) (string, string, error) {
code := message.ProtocolInfo().CodecType
if code == serviceinfo.Thrift || code == serviceinfo.Protobuf {
method, size, err := peekMethod(in)
if err != nil {
return "", "", err
}

seqID, err := peekSeqID(in, size)
if err != nil {
return "", "", err
}
if err = codec.SetOrCheckSeqID(seqID, message); err != nil {
return "", "", err
}
return message.RPCInfo().Invocation().ServiceName(), method, nil
}
return "", "", nil
}

func peekMethod(in remote.ByteBuffer) (string, int32, error) {
buf, err := in.Peek(8)
if err != nil {
return "", 0, err
}
buf = buf[4:]
size := int32(binary.BigEndian.Uint32(buf))
buf, err = in.Peek(int(size + 8))
if err != nil {
return "", 0, perrors.NewProtocolError(err)
}
buf = buf[8:]
method := string(buf)
return method, size + 8, nil
}

func peekSeqID(in remote.ByteBuffer, size int32) (int32, error) {
buf, err := in.Peek(int(size + 4))
if err != nil {
return 0, perrors.NewProtocolError(err)
}
buf = buf[size:]
seqID := int32(binary.BigEndian.Uint32(buf))
return seqID, nil
}

func encode(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error {
if msg.ProtocolInfo().CodecType == serviceinfo.Thrift {
return encodeThrift(res, msg, out)
}
if msg.ProtocolInfo().CodecType == serviceinfo.Protobuf {
return encodeKitexProtobuf(res, msg, out)
}
return nil
}

// encodeThrift Thrift encoder
func encodeThrift(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error {
nw, _ := out.(remote.NocopyWrite)
msgType := msg.MessageType()
ink := msg.RPCInfo().Invocation()
msgBeginLen := bthrift.Binary.MessageBeginLength(res.Method, thrift.TMessageType(msgType), ink.SeqID())
msgEndLen := bthrift.Binary.MessageEndLength()

buf, err := out.Malloc(msgBeginLen + len(res.Success) + msgEndLen)
if err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error()))
}
offset := bthrift.Binary.WriteMessageBegin(buf, res.Method, thrift.TMessageType(msgType), ink.SeqID())
write(buf[offset:], res.Success)
bthrift.Binary.WriteMessageEnd(buf[offset:])
if nw == nil {
// if nw is nil, FastWrite will act in Copy mode.
return nil
}
return nw.MallocAck(out.MallocLen())
}

// encodeKitexProtobuf Kitex Protobuf encoder
func encodeKitexProtobuf(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error {
ink := msg.RPCInfo().Invocation()
// 3.1 magic && msgType
if err := codec.WriteUint32(codec.ProtobufV1Magic+uint32(msg.MessageType()), out); err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write meta info failed: %s", err.Error()))
}
// 3.2 methodName
if _, err := codec.WriteString(res.Method, out); err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write method name failed: %s", err.Error()))
}
// 3.3 seqID
if err := codec.WriteUint32(uint32(ink.SeqID()), out); err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write seqID failed: %s", err.Error()))
}
dataLen := len(res.Success)
buf, err := out.Malloc(dataLen)
if err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf malloc size %d failed: %s", dataLen, err.Error()))
}
write(buf, res.Success)
return nil
}
Loading

0 comments on commit e13869d

Please sign in to comment.