From 50e9b17d9faa8769289ad1236bc5b439f0377a99 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Wed, 12 Jun 2024 16:05:57 +0800 Subject: [PATCH 01/70] refactor: move apache code to separated pkg (#1381) Co-authored-by: QihengZhou Co-authored-by: Yi Duan --- pkg/generic/binary_test/generic_init.go | 3 +- pkg/generic/binary_test/generic_test.go | 3 +- pkg/generic/binarythrift_codec_test.go | 3 +- pkg/generic/descriptor/type.go | 2 +- pkg/generic/generic_service.go | 3 +- pkg/generic/reflect_test/reflect_test.go | 2 +- pkg/generic/thrift/base.go | 2 +- pkg/generic/thrift/http.go | 2 +- pkg/generic/thrift/http_fallback.go | 2 +- pkg/generic/thrift/http_go116plus_amd64.go | 2 +- pkg/generic/thrift/http_pb.go | 2 +- pkg/generic/thrift/json.go | 2 +- pkg/generic/thrift/json_fallback.go | 2 +- pkg/generic/thrift/json_go116plus_amd64.go | 2 +- pkg/generic/thrift/read.go | 2 +- pkg/generic/thrift/read_test.go | 2 +- pkg/generic/thrift/struct.go | 3 +- pkg/generic/thrift/thrift.go | 2 +- pkg/generic/thrift/write.go | 2 +- pkg/generic/thrift/write_test.go | 2 +- pkg/protocol/bthrift/apache/apache.go | 37 ++++++ .../bthrift/apache/application_exception.go | 39 ++++++ .../bthrift/apache/binary_protocol.go | 23 ++++ pkg/protocol/bthrift/apache/exception.go | 28 ++++ pkg/protocol/bthrift/apache/memory_buffer.go | 57 ++++++++ pkg/protocol/bthrift/apache/messagetype.go | 32 +++++ pkg/protocol/bthrift/apache/protocol.go | 33 +++++ .../bthrift/apache/protocol_exception.go | 33 +++++ pkg/protocol/bthrift/apache/serializer.go | 24 ++++ pkg/protocol/bthrift/apache/transport.go | 23 ++++ pkg/protocol/bthrift/apache/type.go | 43 ++++++ pkg/protocol/bthrift/binary.go | 11 +- pkg/protocol/bthrift/binary_test.go | 14 +- pkg/protocol/bthrift/exception.go | 122 ++++++++++++++++++ pkg/protocol/bthrift/exception_test.go | 47 +++++++ pkg/protocol/bthrift/interface.go | 2 +- pkg/protocol/bthrift/unknown.go | 3 +- pkg/protocol/bthrift/utils.go | 38 ++++++ pkg/remote/codec/thrift/binary_protocol.go | 3 +- .../codec/thrift/binary_protocol_test.go | 3 +- pkg/remote/codec/thrift/skip_decoder.go | 3 +- pkg/remote/codec/thrift/skip_decoder_test.go | 3 +- pkg/remote/codec/thrift/thrift.go | 3 +- pkg/remote/codec/thrift/thrift_data.go | 3 +- pkg/remote/codec/thrift/thrift_data_test.go | 3 +- pkg/remote/codec/thrift/thrift_frugal.go | 2 +- pkg/remote/codec/thrift/thrift_test.go | 2 +- pkg/remote/trans/netpollmux/control_frame.go | 2 +- pkg/utils/thrift.go | 61 ++++----- pkg/utils/thrift_test.go | 6 +- 50 files changed, 651 insertions(+), 97 deletions(-) create mode 100644 pkg/protocol/bthrift/apache/apache.go create mode 100644 pkg/protocol/bthrift/apache/application_exception.go create mode 100644 pkg/protocol/bthrift/apache/binary_protocol.go create mode 100644 pkg/protocol/bthrift/apache/exception.go create mode 100644 pkg/protocol/bthrift/apache/memory_buffer.go create mode 100644 pkg/protocol/bthrift/apache/messagetype.go create mode 100644 pkg/protocol/bthrift/apache/protocol.go create mode 100644 pkg/protocol/bthrift/apache/protocol_exception.go create mode 100644 pkg/protocol/bthrift/apache/serializer.go create mode 100644 pkg/protocol/bthrift/apache/transport.go create mode 100644 pkg/protocol/bthrift/apache/type.go create mode 100644 pkg/protocol/bthrift/exception.go create mode 100644 pkg/protocol/bthrift/exception_test.go create mode 100644 pkg/protocol/bthrift/utils.go diff --git a/pkg/generic/binary_test/generic_init.go b/pkg/generic/binary_test/generic_init.go index 37ddc9b064..d7f5bd4418 100644 --- a/pkg/generic/binary_test/generic_init.go +++ b/pkg/generic/binary_test/generic_init.go @@ -25,14 +25,13 @@ import ( "net" "time" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/genericclient" kt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/kerrors" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/pkg/utils" diff --git a/pkg/generic/binary_test/generic_test.go b/pkg/generic/binary_test/generic_test.go index c5952291d2..f5d7613b64 100644 --- a/pkg/generic/binary_test/generic_test.go +++ b/pkg/generic/binary_test/generic_test.go @@ -26,8 +26,6 @@ import ( "testing" "time" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/callopt" "github.com/cloudwego/kitex/client/genericclient" @@ -35,6 +33,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/kerrors" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/kitex/server" ) diff --git a/pkg/generic/binarythrift_codec_test.go b/pkg/generic/binarythrift_codec_test.go index f67d82cec0..5393b16efc 100644 --- a/pkg/generic/binarythrift_codec_test.go +++ b/pkg/generic/binarythrift_codec_test.go @@ -20,10 +20,9 @@ import ( "context" "testing" - "github.com/apache/thrift/lib/go/thrift" - kt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" diff --git a/pkg/generic/descriptor/type.go b/pkg/generic/descriptor/type.go index e375d8989b..28a49a6bef 100644 --- a/pkg/generic/descriptor/type.go +++ b/pkg/generic/descriptor/type.go @@ -16,7 +16,7 @@ package descriptor -import "github.com/apache/thrift/lib/go/thrift" +import thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" // Type constants in the Thrift protocol type Type byte diff --git a/pkg/generic/generic_service.go b/pkg/generic/generic_service.go index 2ff40e6182..5655fd3fad 100644 --- a/pkg/generic/generic_service.go +++ b/pkg/generic/generic_service.go @@ -20,10 +20,9 @@ import ( "context" "fmt" - "github.com/apache/thrift/lib/go/thrift" - gproto "github.com/cloudwego/kitex/pkg/generic/proto" gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/reflect_test/reflect_test.go b/pkg/generic/reflect_test/reflect_test.go index 0e9b1b103b..36ec671475 100644 --- a/pkg/generic/reflect_test/reflect_test.go +++ b/pkg/generic/reflect_test/reflect_test.go @@ -34,10 +34,10 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/klog" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/genericserver" - "github.com/apache/thrift/lib/go/thrift" dt "github.com/cloudwego/dynamicgo/thrift" dg "github.com/cloudwego/dynamicgo/thrift/generic" ) diff --git a/pkg/generic/thrift/base.go b/pkg/generic/thrift/base.go index 8d5b55b346..8a3f05febb 100644 --- a/pkg/generic/thrift/base.go +++ b/pkg/generic/thrift/base.go @@ -19,7 +19,7 @@ package thrift import ( "fmt" - "github.com/apache/thrift/lib/go/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) type TrafficEnv struct { diff --git a/pkg/generic/thrift/http.go b/pkg/generic/thrift/http.go index 276c69a5d7..9d7b66e8ca 100644 --- a/pkg/generic/thrift/http.go +++ b/pkg/generic/thrift/http.go @@ -20,7 +20,6 @@ import ( "context" "fmt" - "github.com/apache/thrift/lib/go/thrift" "github.com/bytedance/gopkg/lang/dirtmake" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" @@ -29,6 +28,7 @@ import ( "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" diff --git a/pkg/generic/thrift/http_fallback.go b/pkg/generic/thrift/http_fallback.go index c322ad47b0..0d4ae7a2e3 100644 --- a/pkg/generic/thrift/http_fallback.go +++ b/pkg/generic/thrift/http_fallback.go @@ -22,7 +22,7 @@ package thrift import ( "context" - "github.com/apache/thrift/lib/go/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // Write ... diff --git a/pkg/generic/thrift/http_go116plus_amd64.go b/pkg/generic/thrift/http_go116plus_amd64.go index a81bf08eca..3b6a82683e 100644 --- a/pkg/generic/thrift/http_go116plus_amd64.go +++ b/pkg/generic/thrift/http_go116plus_amd64.go @@ -23,13 +23,13 @@ import ( "context" "unsafe" - "github.com/apache/thrift/lib/go/thrift" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/j2t" "github.com/cloudwego/dynamicgo/thrift/base" "github.com/cloudwego/kitex/pkg/generic/descriptor" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" ) diff --git a/pkg/generic/thrift/http_pb.go b/pkg/generic/thrift/http_pb.go index d4051731a0..f0b7418088 100644 --- a/pkg/generic/thrift/http_pb.go +++ b/pkg/generic/thrift/http_pb.go @@ -21,12 +21,12 @@ import ( "errors" "fmt" - "github.com/apache/thrift/lib/go/thrift" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/dynamic" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/generic/proto" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // WriteHTTPPbRequest implement of MessageWriter diff --git a/pkg/generic/thrift/json.go b/pkg/generic/thrift/json.go index e769d3f76a..bcb9b83471 100644 --- a/pkg/generic/thrift/json.go +++ b/pkg/generic/thrift/json.go @@ -21,7 +21,6 @@ import ( "fmt" "strconv" - "github.com/apache/thrift/lib/go/thrift" "github.com/bytedance/gopkg/lang/dirtmake" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" @@ -31,6 +30,7 @@ import ( "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" diff --git a/pkg/generic/thrift/json_fallback.go b/pkg/generic/thrift/json_fallback.go index 5b333435ec..c05c548993 100644 --- a/pkg/generic/thrift/json_fallback.go +++ b/pkg/generic/thrift/json_fallback.go @@ -22,7 +22,7 @@ package thrift import ( "context" - "github.com/apache/thrift/lib/go/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // Write write json string to out thrift.TProtocol diff --git a/pkg/generic/thrift/json_go116plus_amd64.go b/pkg/generic/thrift/json_go116plus_amd64.go index 7b28c5784f..a8515ba5f5 100644 --- a/pkg/generic/thrift/json_go116plus_amd64.go +++ b/pkg/generic/thrift/json_go116plus_amd64.go @@ -23,7 +23,6 @@ import ( "context" "unsafe" - "github.com/apache/thrift/lib/go/thrift" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/j2t" @@ -31,6 +30,7 @@ import ( "github.com/cloudwego/dynamicgo/thrift/base" "github.com/cloudwego/kitex/pkg/generic/descriptor" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/utils" diff --git a/pkg/generic/thrift/read.go b/pkg/generic/thrift/read.go index 6b88881826..649b349416 100644 --- a/pkg/generic/thrift/read.go +++ b/pkg/generic/thrift/read.go @@ -22,11 +22,11 @@ import ( "fmt" "reflect" - "github.com/apache/thrift/lib/go/thrift" "github.com/jhump/protoreflect/desc" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/generic/proto" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) var emptyPbDsc = &desc.MessageDescriptor{} diff --git a/pkg/generic/thrift/read_test.go b/pkg/generic/thrift/read_test.go index b996970c4b..52f5e249db 100644 --- a/pkg/generic/thrift/read_test.go +++ b/pkg/generic/thrift/read_test.go @@ -24,13 +24,13 @@ import ( "reflect" "testing" - "github.com/apache/thrift/lib/go/thrift" "github.com/jhump/protoreflect/desc/protoparse" "github.com/stretchr/testify/require" "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/generic/proto" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) var ( diff --git a/pkg/generic/thrift/struct.go b/pkg/generic/thrift/struct.go index 63f0daf2a0..d3815cc690 100644 --- a/pkg/generic/thrift/struct.go +++ b/pkg/generic/thrift/struct.go @@ -19,9 +19,8 @@ package thrift import ( "context" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/pkg/generic/descriptor" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // NewWriteStruct ... diff --git a/pkg/generic/thrift/thrift.go b/pkg/generic/thrift/thrift.go index 7bbc9402c1..aebf71d23e 100644 --- a/pkg/generic/thrift/thrift.go +++ b/pkg/generic/thrift/thrift.go @@ -20,7 +20,7 @@ package thrift import ( "context" - "github.com/apache/thrift/lib/go/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) const ( diff --git a/pkg/generic/thrift/write.go b/pkg/generic/thrift/write.go index bdab9b6bbd..e7bcdc9bd0 100644 --- a/pkg/generic/thrift/write.go +++ b/pkg/generic/thrift/write.go @@ -22,11 +22,11 @@ import ( "encoding/json" "fmt" - "github.com/apache/thrift/lib/go/thrift" "github.com/tidwall/gjson" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/generic/proto" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) diff --git a/pkg/generic/thrift/write_test.go b/pkg/generic/thrift/write_test.go index b3aa230085..d71e40c444 100644 --- a/pkg/generic/thrift/write_test.go +++ b/pkg/generic/thrift/write_test.go @@ -26,7 +26,6 @@ import ( "reflect" "testing" - "github.com/apache/thrift/lib/go/thrift" "github.com/jhump/protoreflect/desc/protoparse" "github.com/tidwall/gjson" @@ -34,6 +33,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/generic/proto" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) func Test_nextWriter(t *testing.T) { diff --git a/pkg/protocol/bthrift/apache/apache.go b/pkg/protocol/bthrift/apache/apache.go new file mode 100644 index 0000000000..ca7ee5d96c --- /dev/null +++ b/pkg/protocol/bthrift/apache/apache.go @@ -0,0 +1,37 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package apache contains codes originally from https://github.com/apache/thrift. +// +// we're planning to get rid of the pkg, here are steps we're going to work on: +// 1. Remove unnecessary dependencies of apache from kitex +// 2. Move all apache dependencies to this pkg, mainly types, interfaces and consts +// - We may use type alias at the beginning for better compatibility +// - Mark interfaces as `Deprecated` since we no longer use it in the future, and we have better implementation. +// 3. For internal dependencies of apache, new alternative implementation will be in: +// - pkg/protocol/bthrift -> low level encoding or decoding bytes +// - pkg/remote/codec/thrift -> high level interfaces +// 4. Change necessary dependencies to this file, including code generator +// (After a period of time) +// 5. Remove apache support of code generator (mainly interfaces) +// 6. Remove type alias and move definition to this file. This may causes compatible issues which are expected. +// - legacy code generator should use legacy version of kitex, then should not have compatibility issue. +// (After a period of time) +// 7. Remove interfaces like thrift.TProtocol from this file +// 8. Done +// +// Now we're working on step 1 - 4. +package apache diff --git a/pkg/protocol/bthrift/apache/application_exception.go b/pkg/protocol/bthrift/apache/application_exception.go new file mode 100644 index 0000000000..ac02e24cef --- /dev/null +++ b/pkg/protocol/bthrift/apache/application_exception.go @@ -0,0 +1,39 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/application_exception.go + +const ( + UNKNOWN_APPLICATION_EXCEPTION = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE_EXCEPTION = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 +) + +type TApplicationException = thrift.TApplicationException + +var NewTApplicationException = thrift.NewTApplicationException diff --git a/pkg/protocol/bthrift/apache/binary_protocol.go b/pkg/protocol/bthrift/apache/binary_protocol.go new file mode 100644 index 0000000000..2a2a4538b2 --- /dev/null +++ b/pkg/protocol/bthrift/apache/binary_protocol.go @@ -0,0 +1,23 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/binary_protocol.go + +var NewTBinaryProtocol = thrift.NewTBinaryProtocol diff --git a/pkg/protocol/bthrift/apache/exception.go b/pkg/protocol/bthrift/apache/exception.go new file mode 100644 index 0000000000..2a0a1f67ff --- /dev/null +++ b/pkg/protocol/bthrift/apache/exception.go @@ -0,0 +1,28 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/exception.go + +// Generic Thrift exception +type TException interface { + error +} + +var PrependError = thrift.PrependError diff --git a/pkg/protocol/bthrift/apache/memory_buffer.go b/pkg/protocol/bthrift/apache/memory_buffer.go new file mode 100644 index 0000000000..10a0af751f --- /dev/null +++ b/pkg/protocol/bthrift/apache/memory_buffer.go @@ -0,0 +1,57 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import ( + "bytes" + "context" +) + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/memory_buffer.go + +// Memory buffer-based implementation of the TTransport interface. +type TMemoryBuffer struct { + *bytes.Buffer + size int +} + +func NewTMemoryBufferLen(size int) *TMemoryBuffer { + buf := make([]byte, 0, size) + return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size} +} + +func (p *TMemoryBuffer) IsOpen() bool { + return true +} + +func (p *TMemoryBuffer) Open() error { + return nil +} + +func (p *TMemoryBuffer) Close() error { + p.Buffer.Reset() + return nil +} + +// Flushing a memory buffer is a no-op +func (p *TMemoryBuffer) Flush(ctx context.Context) error { + return nil +} + +func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) { + return uint64(p.Buffer.Len()) +} diff --git a/pkg/protocol/bthrift/apache/messagetype.go b/pkg/protocol/bthrift/apache/messagetype.go new file mode 100644 index 0000000000..1885144aee --- /dev/null +++ b/pkg/protocol/bthrift/apache/messagetype.go @@ -0,0 +1,32 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/messagetype.go + +// Message type constants in the Thrift protocol. +type TMessageType = thrift.TMessageType + +const ( + INVALID_TMESSAGE_TYPE TMessageType = 0 + CALL TMessageType = 1 + REPLY TMessageType = 2 + EXCEPTION TMessageType = 3 + ONEWAY TMessageType = 4 +) diff --git a/pkg/protocol/bthrift/apache/protocol.go b/pkg/protocol/bthrift/apache/protocol.go new file mode 100644 index 0000000000..9d0a991d96 --- /dev/null +++ b/pkg/protocol/bthrift/apache/protocol.go @@ -0,0 +1,33 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol.go + +const ( + VERSION_MASK = 0xffff0000 + VERSION_1 = 0x80010000 +) + +type TProtocol = thrift.TProtocol + +// The maximum recursive depth the skip() function will traverse +const DEFAULT_RECURSION_DEPTH = 64 + +var SkipDefaultDepth = thrift.SkipDefaultDepth diff --git a/pkg/protocol/bthrift/apache/protocol_exception.go b/pkg/protocol/bthrift/apache/protocol_exception.go new file mode 100644 index 0000000000..7b020797f5 --- /dev/null +++ b/pkg/protocol/bthrift/apache/protocol_exception.go @@ -0,0 +1,33 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol_exception.go + +var NewTProtocolExceptionWithType = thrift.NewTProtocolExceptionWithType + +const ( + UNKNOWN_PROTOCOL_EXCEPTION = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 +) diff --git a/pkg/protocol/bthrift/apache/serializer.go b/pkg/protocol/bthrift/apache/serializer.go new file mode 100644 index 0000000000..c255250301 --- /dev/null +++ b/pkg/protocol/bthrift/apache/serializer.go @@ -0,0 +1,24 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/serializer.go + +type TStruct interface { + Write(p TProtocol) error + Read(p TProtocol) error +} diff --git a/pkg/protocol/bthrift/apache/transport.go b/pkg/protocol/bthrift/apache/transport.go new file mode 100644 index 0000000000..25a752ae52 --- /dev/null +++ b/pkg/protocol/bthrift/apache/transport.go @@ -0,0 +1,23 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/transport.go + +type TTransport = thrift.TTransport diff --git a/pkg/protocol/bthrift/apache/type.go b/pkg/protocol/bthrift/apache/type.go new file mode 100644 index 0000000000..42533b085e --- /dev/null +++ b/pkg/protocol/bthrift/apache/type.go @@ -0,0 +1,43 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/type.go + +type TType = thrift.TType + +const ( + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 +) diff --git a/pkg/protocol/bthrift/binary.go b/pkg/protocol/bthrift/binary.go index 7f54a253e7..4eaba2b900 100644 --- a/pkg/protocol/bthrift/binary.go +++ b/pkg/protocol/bthrift/binary.go @@ -23,11 +23,9 @@ import ( "fmt" "math" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/pkg/mem" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" - "github.com/cloudwego/kitex/pkg/utils" ) var ( @@ -161,7 +159,7 @@ func (binaryProtocol) WriteBinary(buf, value []byte) int { } func (binaryProtocol) WriteStringNocopy(buf []byte, binaryWriter BinaryWriter, value string) int { - return Binary.WriteBinaryNocopy(buf, binaryWriter, utils.StringToSliceByte(value)) + return Binary.WriteBinaryNocopy(buf, binaryWriter, stringToSliceByte(value)) } func (binaryProtocol) WriteBinaryNocopy(buf []byte, binaryWriter BinaryWriter, value []byte) int { @@ -267,7 +265,7 @@ func (binaryProtocol) BinaryLength(value []byte) int { } func (binaryProtocol) StringLengthNocopy(value string) int { - return Binary.BinaryLengthNocopy(utils.StringToSliceByte(value)) + return Binary.BinaryLengthNocopy(stringToSliceByte(value)) } func (binaryProtocol) BinaryLengthNocopy(value []byte) int { @@ -484,8 +482,7 @@ func (binaryProtocol) ReadString(buf []byte) (value string, length int, err erro } alloc := allocator if alloc != nil { - data := alloc.Copy(buf[length : length+int(size)]) - value = utils.SliceByteToString(data) + value = sliceByteToString(alloc.Copy(buf[length : length+int(size)])) } else { value = string(buf[length : length+int(size)]) } diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index 43df81b165..8e395bb5cd 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -21,12 +21,8 @@ import ( "fmt" "testing" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/netpoll" - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/remote" - internalnetpoll "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // TestWriteMessageEnd test binary WriteMessageEnd function @@ -358,9 +354,7 @@ func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) exceptWs := "0000000c6d657373616765426567696e" exceptSize := 16 - out := internalnetpoll.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(0)) - nw, _ := out.(remote.NocopyWrite) - wn := Binary.WriteStringNocopy(buf, nw, "messageBegin") + wn := Binary.WriteStringNocopy(buf, nil, "messageBegin") ws := fmt.Sprintf("%x", buf[:wn]) test.Assert(t, wn == exceptSize, wn, exceptSize) test.Assert(t, ws == exceptWs, ws, exceptWs) @@ -371,9 +365,7 @@ func TestWriteBinaryNocopy(t *testing.T) { buf := make([]byte, 128) exceptWs := "0000000c6d657373616765426567696e" exceptSize := 16 - out := internalnetpoll.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(0)) - nw, _ := out.(remote.NocopyWrite) - wn := Binary.WriteBinaryNocopy(buf, nw, []byte("messageBegin")) + wn := Binary.WriteBinaryNocopy(buf, nil, []byte("messageBegin")) ws := fmt.Sprintf("%x", buf[:wn]) test.Assert(t, wn == exceptSize, wn, exceptSize) test.Assert(t, ws == exceptWs, ws, exceptWs) diff --git a/pkg/protocol/bthrift/exception.go b/pkg/protocol/bthrift/exception.go new file mode 100644 index 0000000000..d229dd9eb0 --- /dev/null +++ b/pkg/protocol/bthrift/exception.go @@ -0,0 +1,122 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bthrift + +import ( + "fmt" + + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" +) + +// ApplicationException represents the application exception decoder for replacing apache.TApplicationException +// it implements ThriftMsgFastCodec interface. +type ApplicationException struct { + t int32 + m string +} + +// NewApplicationException creates an ApplicationException instance +func NewApplicationException(t int32, msg string) *ApplicationException { + return &ApplicationException{t: t, m: msg} +} + +// Msg ... +func (e *ApplicationException) Msg() string { return e.m } + +// TypeID ... +func (e *ApplicationException) TypeID() int32 { return e.t } + +// BLength returns the len of encoded buffer. +func (e *ApplicationException) BLength() int { + // Msg Field: 1 (type) + 2 (id) + 4(strlen) + len(m) + // Type Field: 1 (type) + 2 (id) + 4(ex type) + // STOP: 1 byte + return (1 + 2 + 4 + len(e.m)) + (1 + 2 + 4) + 1 +} + +// Read ... +func (e *ApplicationException) FastRead(b []byte) (off int, err error) { + for i := 0; i < 2; i++ { + _, tp, id, l, err := Binary.ReadFieldBegin(b[off:]) + if err != nil { + return 0, err + } + off += l + switch { + case id == 1 && tp == thrift.STRING: // Msg + e.m, l, err = Binary.ReadString(b[off:]) + case id == 2 && tp == thrift.I32: // TypeID + e.t, l, err = Binary.ReadI32(b[off:]) + default: + l, err = Binary.Skip(b, tp) + } + if err != nil { + return 0, err + } + off += l + } + v, l, err := Binary.ReadByte(b[off:]) + if err != nil { + return 0, err + } + if v != thrift.STOP { + return 0, fmt.Errorf("expects thrift.STOP, found: %d", v) + } + off += l + return off, nil +} + +// Write ... +func (e *ApplicationException) FastWrite(b []byte) (off int) { + off += Binary.WriteFieldBegin(b[off:], "", thrift.STRING, 1) + off += Binary.WriteString(b[off:], e.m) + off += Binary.WriteFieldBegin(b[off:], "", thrift.I32, 2) + off += Binary.WriteI32(b[off:], e.t) + off += Binary.WriteByte(b[off:], thrift.STOP) + return off +} + +// FastWriteNocopy ... XXX: we deprecated XXXNocopy, simply using FastWrite is OK. +func (e *ApplicationException) FastWriteNocopy(b []byte, binaryWriter BinaryWriter) int { + return e.FastWrite(b) +} + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/exception.go +var defaultApplicationExceptionMessage = map[int32]string{ + thrift.UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception", + thrift.UNKNOWN_METHOD: "unknown method", + thrift.INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type", + thrift.WRONG_METHOD_NAME: "wrong method name", + thrift.BAD_SEQUENCE_ID: "bad sequence ID", + thrift.MISSING_RESULT: "missing result", + thrift.INTERNAL_ERROR: "unknown internal error", + thrift.PROTOCOL_ERROR: "unknown protocol error", + thrift.INVALID_TRANSFORM: "Invalid transform", + thrift.INVALID_PROTOCOL: "Invalid protocol", + thrift.UNSUPPORTED_CLIENT_TYPE: "Unsupported client type", +} + +// Error implements apache.Exception +func (e *ApplicationException) Error() string { + if e.m != "" { + return e.m + } + if m, ok := defaultApplicationExceptionMessage[e.t]; ok { + return m + } + return fmt.Sprintf("unknown exception type [%d]", e.t) +} diff --git a/pkg/protocol/bthrift/exception_test.go b/pkg/protocol/bthrift/exception_test.go new file mode 100644 index 0000000000..25c25e4cbe --- /dev/null +++ b/pkg/protocol/bthrift/exception_test.go @@ -0,0 +1,47 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bthrift + +import ( + "bytes" + "testing" + + "github.com/cloudwego/kitex/internal/test" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" +) + +func TestApplicationException(t *testing.T) { + ex1 := NewApplicationException(1, "t1") + b := make([]byte, ex1.BLength()) + n := ex1.FastWrite(b) + test.Assert(t, n == len(b)) + + ex2 := NewApplicationException(0, "") + n, err := ex2.FastRead(b) + test.Assert(t, err == nil, err) + test.Assert(t, n == len(b), n) + test.Assert(t, ex2.TypeID() == 1) + test.Assert(t, ex2.Msg() == "t1") + + // compatibility test only, can be removed in the future + trans := thrift.NewTMemoryBufferLen(100) + proto := thrift.NewTBinaryProtocol(trans, true, true) + ex0 := thrift.NewTApplicationException(1, "t1") + err = ex0.Write(proto) + test.Assert(t, err == nil, err) + test.Assert(t, bytes.Equal(b, trans.Bytes())) +} diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index 01890b93ba..f3922b1e64 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -18,7 +18,7 @@ package bthrift import ( - "github.com/apache/thrift/lib/go/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // BinaryWriter . diff --git a/pkg/protocol/bthrift/unknown.go b/pkg/protocol/bthrift/unknown.go index cea1de81c8..ccc133bb6b 100644 --- a/pkg/protocol/bthrift/unknown.go +++ b/pkg/protocol/bthrift/unknown.go @@ -21,8 +21,9 @@ import ( "fmt" "reflect" - "github.com/apache/thrift/lib/go/thrift" "github.com/cloudwego/thriftgo/generator/golang/extension/unknown" + + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // UnknownField is used to describe an unknown field. diff --git a/pkg/protocol/bthrift/utils.go b/pkg/protocol/bthrift/utils.go new file mode 100644 index 0000000000..2fd5b8f527 --- /dev/null +++ b/pkg/protocol/bthrift/utils.go @@ -0,0 +1,38 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bthrift + +import ( + "reflect" + "unsafe" +) + +// from utils.SliceByteToString for fixing cyclic import +func sliceByteToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// from utils.StringToSliceByte for fixing cyclic import +func stringToSliceByte(s string) []byte { + p := unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s)).Data) + var b []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + hdr.Data = uintptr(p) + hdr.Cap = len(s) + hdr.Len = len(s) + return b +} diff --git a/pkg/remote/codec/thrift/binary_protocol.go b/pkg/remote/codec/thrift/binary_protocol.go index fadffbbdb5..eae7608faa 100644 --- a/pkg/remote/codec/thrift/binary_protocol.go +++ b/pkg/remote/codec/thrift/binary_protocol.go @@ -22,8 +22,7 @@ import ( "math" "sync" - "github.com/apache/thrift/lib/go/thrift" - + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) diff --git a/pkg/remote/codec/thrift/binary_protocol_test.go b/pkg/remote/codec/thrift/binary_protocol_test.go index f4a6845ce5..3b57615903 100644 --- a/pkg/remote/codec/thrift/binary_protocol_test.go +++ b/pkg/remote/codec/thrift/binary_protocol_test.go @@ -21,9 +21,8 @@ import ( "encoding/binary" "testing" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/internal/test" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" ) diff --git a/pkg/remote/codec/thrift/skip_decoder.go b/pkg/remote/codec/thrift/skip_decoder.go index d209ae18a6..13bb045c2b 100644 --- a/pkg/remote/codec/thrift/skip_decoder.go +++ b/pkg/remote/codec/thrift/skip_decoder.go @@ -21,8 +21,7 @@ import ( "errors" "fmt" - "github.com/apache/thrift/lib/go/thrift" - + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) diff --git a/pkg/remote/codec/thrift/skip_decoder_test.go b/pkg/remote/codec/thrift/skip_decoder_test.go index 3da2bf6388..0b8b8dc2ca 100644 --- a/pkg/remote/codec/thrift/skip_decoder_test.go +++ b/pkg/remote/codec/thrift/skip_decoder_test.go @@ -19,9 +19,8 @@ package thrift import ( "testing" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/internal/test" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" ) diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index c6d2ca163e..3d6d5c44ee 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -21,9 +21,8 @@ import ( "errors" "fmt" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index b350d6191f..c17340b22c 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -20,10 +20,10 @@ import ( "context" "fmt" - "github.com/apache/thrift/lib/go/thrift" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) @@ -68,6 +68,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ return nil, err } + // TODO: Remove the fallback code after skip decoder is stable // fallback to old thrift way (slow) transport := thrift.NewTMemoryBufferLen(marshalThriftBufferSize) tProt := thrift.NewTBinaryProtocol(transport, true, true) diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index 8ceebb148a..2426e84e32 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -22,10 +22,9 @@ import ( "strings" "testing" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/internal/mocks/thrift/fast" "github.com/cloudwego/kitex/internal/test" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" ) diff --git a/pkg/remote/codec/thrift/thrift_frugal.go b/pkg/remote/codec/thrift/thrift_frugal.go index 235f2d30c1..042cfd2467 100644 --- a/pkg/remote/codec/thrift/thrift_frugal.go +++ b/pkg/remote/codec/thrift/thrift_frugal.go @@ -27,11 +27,11 @@ import ( "fmt" "reflect" - "github.com/apache/thrift/lib/go/thrift" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/frugal" "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index 14372ce6e0..23554ea39a 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -21,12 +21,12 @@ import ( "errors" "testing" - "github.com/apache/thrift/lib/go/thrift" "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/internal/mocks" mt "github.com/cloudwego/kitex/internal/mocks/thrift/fast" "github.com/cloudwego/kitex/internal/test" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/rpcinfo" diff --git a/pkg/remote/trans/netpollmux/control_frame.go b/pkg/remote/trans/netpollmux/control_frame.go index dcdf73c7b4..9c50813fbe 100644 --- a/pkg/remote/trans/netpollmux/control_frame.go +++ b/pkg/remote/trans/netpollmux/control_frame.go @@ -25,7 +25,7 @@ package netpollmux import ( "fmt" - "github.com/apache/thrift/lib/go/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) type ControlFrame struct{} diff --git a/pkg/utils/thrift.go b/pkg/utils/thrift.go index 27741dab12..8741aea9dc 100644 --- a/pkg/utils/thrift.go +++ b/pkg/utils/thrift.go @@ -17,12 +17,11 @@ package utils import ( - "bytes" - "context" "errors" "fmt" - "github.com/apache/thrift/lib/go/thrift" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // ThriftMessageCodec is used to codec thrift messages. @@ -33,6 +32,7 @@ type ThriftMessageCodec struct { // NewThriftMessageCodec creates a new ThriftMessageCodec. func NewThriftMessageCodec() *ThriftMessageCodec { + // TODO: use remote.ByteBuffer & remote/codec/thrift.BinaryProtocol transport := thrift.NewTMemoryBufferLen(1024) tProt := thrift.NewTBinaryProtocol(transport, true, true) @@ -119,38 +119,33 @@ func (t *ThriftMessageCodec) Deserialize(msg thrift.TStruct, b []byte) (err erro // MarshalError convert go error to thrift exception, and encode exception over buffered binary transport. func MarshalError(method string, err error) []byte { - e := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, err.Error()) - var buf bytes.Buffer - trans := thrift.NewStreamTransportRW(&buf) - proto := thrift.NewTBinaryProtocol(trans, true, true) - if err := proto.WriteMessageBegin(method, thrift.EXCEPTION, 0); err != nil { - return nil - } - if err := e.Write(proto); err != nil { - return nil - } - if err := proto.WriteMessageEnd(); err != nil { - return nil - } - if err := proto.Flush(context.Background()); err != nil { - return nil - } - return buf.Bytes() + ex := bthrift.NewApplicationException(thrift.INTERNAL_ERROR, err.Error()) + n := bthrift.Binary.MessageBeginLength(method, thrift.EXCEPTION, 0) + n += ex.BLength() + b := make([]byte, n) + // Write message header + off := bthrift.Binary.WriteMessageBegin(b, method, thrift.EXCEPTION, 0) + // Write Ex body + off += ex.FastWrite(b[off:]) + return b[:off] } // UnmarshalError decode binary and return error message func UnmarshalError(b []byte) error { - trans := thrift.NewStreamTransportR(bytes.NewReader(b)) - proto := thrift.NewTBinaryProtocolTransport(trans) - if _, _, _, err := proto.ReadMessageBegin(); err != nil { - return fmt.Errorf("read message begin error: %w", err) - } - e := thrift.NewTApplicationException(0, "") - if err := e.Read(proto); err != nil { - return fmt.Errorf("read exception error: %w", err) - } - if err := proto.ReadMessageEnd(); err != nil { - return fmt.Errorf("read message end error: %w", err) - } - return e + // Read message header + _, tp, _, l, err := bthrift.Binary.ReadMessageBegin(b) + if err != nil { + return err + } + if tp != thrift.EXCEPTION { + return fmt.Errorf("expects thrift.EXCEPTION, found: %d", tp) + } + // Read Ex body + off := l + ex := bthrift.NewApplicationException(thrift.INTERNAL_ERROR, "") + if _, err := ex.FastRead(b[off:]); err != nil { + return err + } + // XXX: for compatibility, consider to remove it in the future + return thrift.NewTApplicationException(ex.TypeID(), ex.Msg()) } diff --git a/pkg/utils/thrift_test.go b/pkg/utils/thrift_test.go index 4f42439d08..d0139f75b5 100644 --- a/pkg/utils/thrift_test.go +++ b/pkg/utils/thrift_test.go @@ -20,10 +20,9 @@ import ( "errors" "testing" - "github.com/apache/thrift/lib/go/thrift" - mt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) func TestRPCCodec(t *testing.T) { @@ -113,5 +112,6 @@ func TestSerializer(t *testing.T) { func TestException(t *testing.T) { errMsg := "my error" b := MarshalError("some method", errors.New(errMsg)) - test.Assert(t, UnmarshalError(b).Error() == errMsg) + err := UnmarshalError(b) + test.Assert(t, err.Error() == errMsg, err) } From ff676c6acb2bcc5fefc4df829903312ed0d93ec1 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:40:58 +0800 Subject: [PATCH 02/70] refactor(generic): refactor existing generic to have new ServiceInfo which has the generic's reader and writer info directly (#1408) --- client/genericclient/client.go | 14 +- internal/mocks/generic/thrift.go | 4 +- pkg/generic/generic.go | 103 ++++--- pkg/generic/generic_service.go | 101 +++++-- pkg/generic/generic_service_test.go | 25 +- pkg/generic/generic_test.go | 22 +- pkg/generic/httppbthrift_codec.go | 75 ++++-- pkg/generic/httppbthrift_codec_test.go | 39 +++ pkg/generic/httpthrift_codec.go | 73 +++-- pkg/generic/httpthrift_codec_test.go | 96 ++----- pkg/generic/jsonpb_codec.go | 80 +++--- pkg/generic/jsonpb_codec_test.go | 53 +--- pkg/generic/jsonthrift_codec.go | 100 ++++--- pkg/generic/jsonthrift_codec_test.go | 251 ++++-------------- pkg/generic/mapthrift_codec.go | 102 ++++--- pkg/generic/mapthrift_codec_test.go | 163 ++---------- pkg/generic/proto/json.go | 54 ++-- pkg/generic/proto/json_test.go | 12 +- pkg/generic/proto/protobuf.go | 4 +- pkg/generic/thrift/http.go | 31 +-- pkg/generic/thrift/http_fallback.go | 2 +- pkg/generic/thrift/http_go116plus_amd64.go | 15 +- pkg/generic/thrift/http_pb.go | 19 +- pkg/generic/thrift/json.go | 115 ++++---- pkg/generic/thrift/json_fallback.go | 4 +- pkg/generic/thrift/json_go116plus_amd64.go | 35 ++- pkg/generic/thrift/struct.go | 67 ++--- pkg/generic/thrift/thrift.go | 4 +- pkg/remote/codec/protobuf/protobuf.go | 10 +- pkg/remote/codec/thrift/thrift.go | 16 +- pkg/remote/codec/thrift/thrift_data.go | 20 +- pkg/remote/codec/thrift/thrift_data_test.go | 14 +- pkg/remote/codec/thrift/thrift_frugal_test.go | 4 +- pkg/remote/codec/thrift/thrift_test.go | 18 +- pkg/remote/message.go | 2 +- pkg/serviceinfo/serviceinfo.go | 2 +- server/genericserver/server.go | 2 +- server/genericserver/server_test.go | 2 +- 38 files changed, 804 insertions(+), 949 deletions(-) diff --git a/client/genericclient/client.go b/client/genericclient/client.go index a5f9fd12d2..c40d69ebaa 100644 --- a/client/genericclient/client.go +++ b/client/genericclient/client.go @@ -31,7 +31,7 @@ var _ Client = &genericServiceClient{} // NewClient create a generic client func NewClient(destService string, g generic.Generic, opts ...client.Option) (Client, error) { - svcInfo := generic.ServiceInfo(g.PayloadCodecType()) + svcInfo := generic.ServiceInfoWithCodec(g) return NewClientWithServiceInfo(destService, g, svcInfo, opts...) } @@ -47,6 +47,7 @@ func NewClientWithServiceInfo(destService string, g generic.Generic, svcInfo *se return nil, err } cli := &genericServiceClient{ + svcInfo: svcInfo, kClient: kc, g: g, } @@ -86,24 +87,27 @@ type Client interface { } type genericServiceClient struct { + svcInfo *serviceinfo.ServiceInfo kClient client.Client g generic.Generic } func (gc *genericServiceClient) GenericCall(ctx context.Context, method string, request interface{}, callOptions ...callopt.Option) (response interface{}, err error) { ctx = client.NewCtxWithCallOptions(ctx, callOptions) - var _args generic.Args + _args := gc.svcInfo.MethodInfo(method).NewArgs().(*generic.Args) _args.Method = method _args.Request = request + mt, err := gc.g.GetMethod(request, method) if err != nil { return nil, err } if mt.Oneway { - return nil, gc.kClient.Call(ctx, mt.Name, &_args, nil) + return nil, gc.kClient.Call(ctx, mt.Name, _args, nil) } - var _result generic.Result - if err = gc.kClient.Call(ctx, mt.Name, &_args, &_result); err != nil { + + _result := gc.svcInfo.MethodInfo(method).NewResult().(*generic.Result) + if err = gc.kClient.Call(ctx, mt.Name, _args, _result); err != nil { return } return _result.GetSuccess(), nil diff --git a/internal/mocks/generic/thrift.go b/internal/mocks/generic/thrift.go index 8aaa0ca722..05da3f79c0 100644 --- a/internal/mocks/generic/thrift.go +++ b/internal/mocks/generic/thrift.go @@ -54,7 +54,7 @@ func (m *MockMessageReader) EXPECT() *MockMessageReaderMockRecorder { } // Read mocks base method. -func (m *MockMessageReader) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) { +func (m *MockMessageReader) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Read", ctx, method, in) ret0, _ := ret[0].(interface{}) @@ -92,7 +92,7 @@ func (m *MockMessageWriter) EXPECT() *MockMessageWriterMockRecorder { } // Write mocks base method. -func (m *MockMessageWriter) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *thrift0.Base) error { +func (m *MockMessageWriter) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *thrift0.Base) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Write", ctx, out, msg, requestBase) ret0, _ := ret[0].(error) diff --git a/pkg/generic/generic.go b/pkg/generic/generic.go index 3f86fb7643..93eb8869e4 100644 --- a/pkg/generic/generic.go +++ b/pkg/generic/generic.go @@ -32,13 +32,19 @@ import ( type Generic interface { Closer // PayloadCodec return codec implement + // this is used for generic which does not need IDL PayloadCodec() remote.PayloadCodec // PayloadCodecType return the type of codec PayloadCodecType() serviceinfo.PayloadCodec // RawThriftBinaryGeneric must be framed Framed() bool - // GetMethod to get method name if need + // GetMethod is to get method name if needed GetMethod(req interface{}, method string) (*Method, error) + // IDLServiceName returns idl service name + IDLServiceName() string + // MessageReaderWriter returns reader and writer + // this is used for generic which needs IDL + MessageReaderWriter() interface{} } // Method information @@ -64,22 +70,14 @@ func BinaryThriftGeneric() Generic { // // SetBinaryWithByteSlice(g, true) func MapThriftGeneric(p DescriptorProvider) (Generic, error) { - codec, err := newMapThriftCodec(p, thriftCodec) - if err != nil { - return nil, err - } return &mapThriftGeneric{ - codec: codec, + codec: newMapThriftCodec(p), }, nil } func MapThriftGenericForJSON(p DescriptorProvider) (Generic, error) { - codec, err := newMapThriftCodecForJSON(p, thriftCodec) - if err != nil { - return nil, err - } return &mapThriftGeneric{ - codec: codec, + codec: newMapThriftCodecForJSON(p), }, nil } @@ -92,20 +90,12 @@ func MapThriftGenericForJSON(p DescriptorProvider) (Generic, error) { func HTTPThriftGeneric(p DescriptorProvider, opts ...Option) (Generic, error) { gOpts := &Options{dynamicgoConvOpts: DefaultHTTPDynamicGoConvOpts} gOpts.apply(opts) - codec, err := newHTTPThriftCodec(p, thriftCodec, gOpts) - if err != nil { - return nil, err - } - return &httpThriftGeneric{codec: codec}, nil + return &httpThriftGeneric{codec: newHTTPThriftCodec(p, gOpts)}, nil } func HTTPPbThriftGeneric(p DescriptorProvider, pbp PbDescriptorProvider) (Generic, error) { - codec, err := newHTTPPbThriftCodec(p, pbp, thriftCodec) - if err != nil { - return nil, err - } return &httpPbThriftGeneric{ - codec: codec, + codec: newHTTPPbThriftCodec(p, pbp), }, nil } @@ -118,11 +108,7 @@ func HTTPPbThriftGeneric(p DescriptorProvider, pbp PbDescriptorProvider) (Generi func JSONThriftGeneric(p DescriptorProvider, opts ...Option) (Generic, error) { gOpts := &Options{dynamicgoConvOpts: DefaultJSONDynamicGoConvOpts} gOpts.apply(opts) - codec, err := newJsonThriftCodec(p, thriftCodec, gOpts) - if err != nil { - return nil, err - } - return &jsonThriftGeneric{codec: codec}, nil + return &jsonThriftGeneric{codec: newJsonThriftCodec(p, gOpts)}, nil } // JSONPbGeneric json mapping generic. @@ -130,12 +116,7 @@ func JSONThriftGeneric(p DescriptorProvider, opts ...Option) (Generic, error) { func JSONPbGeneric(p PbDescriptorProviderDynamicGo, opts ...Option) (Generic, error) { gOpts := &Options{dynamicgoConvOpts: conv.Options{}} gOpts.apply(opts) - - codec, err := newJsonPbCodec(p, pbCodec, gOpts) - if err != nil { - return nil, err - } - return &jsonPbGeneric{codec: codec}, nil + return &jsonPbGeneric{codec: newJsonPbCodec(p, gOpts)}, nil } // SetBinaryWithBase64 enable/disable Base64 codec for binary field. @@ -243,6 +224,14 @@ func (g *binaryThriftGeneric) Close() error { return nil } +func (g *binaryThriftGeneric) IDLServiceName() string { + return "" +} + +func (g *binaryThriftGeneric) MessageReaderWriter() interface{} { + return nil +} + type mapThriftGeneric struct { codec *mapThriftCodec } @@ -256,7 +245,7 @@ func (g *mapThriftGeneric) PayloadCodecType() serviceinfo.PayloadCodec { } func (g *mapThriftGeneric) PayloadCodec() remote.PayloadCodec { - return g.codec + return nil } func (g *mapThriftGeneric) GetMethod(req interface{}, method string) (*Method, error) { @@ -267,6 +256,14 @@ func (g *mapThriftGeneric) Close() error { return g.codec.Close() } +func (g *mapThriftGeneric) IDLServiceName() string { + return g.codec.svcName +} + +func (g *mapThriftGeneric) MessageReaderWriter() interface{} { + return g.codec.getMessageReaderWriter() +} + type jsonThriftGeneric struct { codec *jsonThriftCodec } @@ -280,7 +277,7 @@ func (g *jsonThriftGeneric) PayloadCodecType() serviceinfo.PayloadCodec { } func (g *jsonThriftGeneric) PayloadCodec() remote.PayloadCodec { - return g.codec + return nil } func (g *jsonThriftGeneric) GetMethod(req interface{}, method string) (*Method, error) { @@ -291,6 +288,14 @@ func (g *jsonThriftGeneric) Close() error { return g.codec.Close() } +func (g *jsonThriftGeneric) IDLServiceName() string { + return g.codec.svcName +} + +func (g *jsonThriftGeneric) MessageReaderWriter() interface{} { + return g.codec.getMessageReaderWriter() +} + type jsonPbGeneric struct { codec *jsonPbCodec } @@ -304,7 +309,7 @@ func (g *jsonPbGeneric) PayloadCodecType() serviceinfo.PayloadCodec { } func (g *jsonPbGeneric) PayloadCodec() remote.PayloadCodec { - return g.codec + return nil } func (g *jsonPbGeneric) GetMethod(req interface{}, method string) (*Method, error) { @@ -315,6 +320,14 @@ func (g *jsonPbGeneric) Close() error { return g.codec.Close() } +func (g *jsonPbGeneric) IDLServiceName() string { + return g.codec.svcName +} + +func (g *jsonPbGeneric) MessageReaderWriter() interface{} { + return g.codec.getMessageReaderWriter() +} + type httpThriftGeneric struct { codec *httpThriftCodec } @@ -328,7 +341,7 @@ func (g *httpThriftGeneric) PayloadCodecType() serviceinfo.PayloadCodec { } func (g *httpThriftGeneric) PayloadCodec() remote.PayloadCodec { - return g.codec + return nil } func (g *httpThriftGeneric) GetMethod(req interface{}, method string) (*Method, error) { @@ -339,6 +352,14 @@ func (g *httpThriftGeneric) Close() error { return g.codec.Close() } +func (g *httpThriftGeneric) IDLServiceName() string { + return g.codec.svcName +} + +func (g *httpThriftGeneric) MessageReaderWriter() interface{} { + return g.codec.getMessageReaderWriter() +} + type httpPbThriftGeneric struct { codec *httpPbThriftCodec } @@ -352,7 +373,7 @@ func (g *httpPbThriftGeneric) PayloadCodecType() serviceinfo.PayloadCodec { } func (g *httpPbThriftGeneric) PayloadCodec() remote.PayloadCodec { - return g.codec + return nil } func (g *httpPbThriftGeneric) GetMethod(req interface{}, method string) (*Method, error) { @@ -362,3 +383,11 @@ func (g *httpPbThriftGeneric) GetMethod(req interface{}, method string) (*Method func (g *httpPbThriftGeneric) Close() error { return g.codec.Close() } + +func (g *httpPbThriftGeneric) IDLServiceName() string { + return g.codec.svcName +} + +func (g *httpPbThriftGeneric) MessageReaderWriter() interface{} { + return g.codec.getMessageReaderWriter() +} diff --git a/pkg/generic/generic_service.go b/pkg/generic/generic_service.go index 5655fd3fad..b24b44a401 100644 --- a/pkg/generic/generic_service.go +++ b/pkg/generic/generic_service.go @@ -23,6 +23,7 @@ import ( gproto "github.com/cloudwego/kitex/pkg/generic/proto" gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + codecProto "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) @@ -33,19 +34,43 @@ type Service interface { GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) } +// ServiceInfoWithCodec create a generic ServiceInfo with CodecInfo +func ServiceInfoWithCodec(g Generic) *serviceinfo.ServiceInfo { + return newServiceInfo(g.PayloadCodecType(), g.MessageReaderWriter(), g.IDLServiceName()) +} + +// Deprecated: it's not used by kitex anymore. // ServiceInfo create a generic ServiceInfo func ServiceInfo(pcType serviceinfo.PayloadCodec) *serviceinfo.ServiceInfo { - return newServiceInfo(pcType) + return newServiceInfo(pcType, nil, "") } -func newServiceInfo(pcType serviceinfo.PayloadCodec) *serviceinfo.ServiceInfo { - serviceName := serviceinfo.GenericService +func newServiceInfo(pcType serviceinfo.PayloadCodec, messageReaderWriter interface{}, serviceName string) *serviceinfo.ServiceInfo { handlerType := (*Service)(nil) - methods := map[string]serviceinfo.MethodInfo{ - serviceinfo.GenericMethod: serviceinfo.NewMethodInfo(callHandler, newGenericServiceCallArgs, newGenericServiceCallResult, false), + + var methods map[string]serviceinfo.MethodInfo + var svcName string + + if messageReaderWriter == nil { + // note: binary generic cannot be used with multi-service feature + svcName = serviceinfo.GenericService + methods = map[string]serviceinfo.MethodInfo{ + serviceinfo.GenericMethod: serviceinfo.NewMethodInfo(callHandler, newGenericServiceCallArgs, newGenericServiceCallResult, false), + } + } else { + svcName = serviceName + methods = map[string]serviceinfo.MethodInfo{ + serviceinfo.GenericMethod: serviceinfo.NewMethodInfo( + callHandler, + func() interface{} { return &Args{inner: messageReaderWriter} }, + func() interface{} { return &Result{inner: messageReaderWriter} }, + false, + ), + } } + svcInfo := &serviceinfo.ServiceInfo{ - ServiceName: serviceName, + ServiceName: svcName, HandlerType: handlerType, Methods: methods, PayloadCodec: pcType, @@ -89,10 +114,13 @@ type Args struct { var ( _ codecThrift.MessageReaderWithMethodWithContext = (*Args)(nil) - _ codecThrift.MessageWriterWithContext = (*Args)(nil) + _ codecThrift.MessageWriterWithMethodWithContext = (*Args)(nil) + _ codecProto.MessageWriterWithContext = (*Args)(nil) + _ codecProto.MessageReaderWithMethodWithContext = (*Args)(nil) _ WithCodec = (*Args)(nil) ) +// Deprecated: it's not used by kitex anymore. // SetCodec ... func (g *Args) SetCodec(inner interface{}) { g.inner = inner @@ -106,36 +134,48 @@ func (g *Args) GetOrSetBase() interface{} { } // Write ... -func (g *Args) Write(ctx context.Context, out thrift.TProtocol) error { +func (g *Args) Write(ctx context.Context, method string, out thrift.TProtocol) error { + if err, ok := g.inner.(error); ok { + return err + } if w, ok := g.inner.(gthrift.MessageWriter); ok { - return w.Write(ctx, out, g.Request, g.base) + return w.Write(ctx, out, g.Request, method, true, g.base) } return fmt.Errorf("unexpected Args writer type: %T", g.inner) } -func (g *Args) WritePb(ctx context.Context) (interface{}, error) { +func (g *Args) WritePb(ctx context.Context, method string) (interface{}, error) { + if err, ok := g.inner.(error); ok { + return nil, err + } if w, ok := g.inner.(gproto.MessageWriter); ok { - return w.Write(ctx, g.Request) + return w.Write(ctx, g.Request, method, true) } return nil, fmt.Errorf("unexpected Args writer type: %T", g.inner) } // Read ... -func (g *Args) Read(ctx context.Context, method string, in thrift.TProtocol) error { - if w, ok := g.inner.(gthrift.MessageReader); ok { +func (g *Args) Read(ctx context.Context, method string, dataLen int, in thrift.TProtocol) error { + if err, ok := g.inner.(error); ok { + return err + } + if rw, ok := g.inner.(gthrift.MessageReader); ok { g.Method = method var err error - g.Request, err = w.Read(ctx, method, in) + g.Request, err = rw.Read(ctx, method, false, dataLen, in) return err } return fmt.Errorf("unexpected Args reader type: %T", g.inner) } func (g *Args) ReadPb(ctx context.Context, method string, in []byte) error { + if err, ok := g.inner.(error); ok { + return err + } if w, ok := g.inner.(gproto.MessageReader); ok { g.Method = method var err error - g.Request, err = w.Read(ctx, method, in) + g.Request, err = w.Read(ctx, method, false, in) return err } return fmt.Errorf("unexpected Args reader type: %T", g.inner) @@ -154,44 +194,59 @@ type Result struct { var ( _ codecThrift.MessageReaderWithMethodWithContext = (*Result)(nil) - _ codecThrift.MessageWriterWithContext = (*Result)(nil) + _ codecThrift.MessageWriterWithMethodWithContext = (*Result)(nil) + _ codecProto.MessageWriterWithContext = (*Result)(nil) + _ codecProto.MessageReaderWithMethodWithContext = (*Result)(nil) _ WithCodec = (*Result)(nil) ) +// Deprecated: it's not used by kitex anymore. // SetCodec ... func (r *Result) SetCodec(inner interface{}) { r.inner = inner } // Write ... -func (r *Result) Write(ctx context.Context, out thrift.TProtocol) error { +func (r *Result) Write(ctx context.Context, method string, out thrift.TProtocol) error { + if err, ok := r.inner.(error); ok { + return err + } if w, ok := r.inner.(gthrift.MessageWriter); ok { - return w.Write(ctx, out, r.Success, nil) + return w.Write(ctx, out, r.Success, method, false, nil) } return fmt.Errorf("unexpected Result writer type: %T", r.inner) } -func (r *Result) WritePb(ctx context.Context) (interface{}, error) { +func (r *Result) WritePb(ctx context.Context, method string) (interface{}, error) { + if err, ok := r.inner.(error); ok { + return nil, err + } if w, ok := r.inner.(gproto.MessageWriter); ok { - return w.Write(ctx, r.Success) + return w.Write(ctx, r.Success, method, false) } return nil, fmt.Errorf("unexpected Result writer type: %T", r.inner) } // Read ... -func (r *Result) Read(ctx context.Context, method string, in thrift.TProtocol) error { +func (r *Result) Read(ctx context.Context, method string, dataLen int, in thrift.TProtocol) error { + if err, ok := r.inner.(error); ok { + return err + } if w, ok := r.inner.(gthrift.MessageReader); ok { var err error - r.Success, err = w.Read(ctx, method, in) + r.Success, err = w.Read(ctx, method, true, dataLen, in) return err } return fmt.Errorf("unexpected Result reader type: %T", r.inner) } func (r *Result) ReadPb(ctx context.Context, method string, in []byte) error { + if err, ok := r.inner.(error); ok { + return err + } if w, ok := r.inner.(gproto.MessageReader); ok { var err error - r.Success, err = w.Read(ctx, method, in) + r.Success, err = w.Read(ctx, method, true, in) return err } return fmt.Errorf("unexpected Result reader type: %T", r.inner) diff --git a/pkg/generic/generic_service_test.go b/pkg/generic/generic_service_test.go index 95eba59baf..6b6b288c30 100644 --- a/pkg/generic/generic_service_test.go +++ b/pkg/generic/generic_service_test.go @@ -56,21 +56,21 @@ func TestGenericService(t *testing.T) { test.Assert(t, base != nil) a.SetCodec(struct{}{}) // write not ok - err := a.Write(ctx, tProto) + err := a.Write(ctx, method, tProto) test.Assert(t, err.Error() == "unexpected Args writer type: struct {}") // Write expect argWriteInner.EXPECT().Write(ctx, tProto, a.Request, a.GetOrSetBase()).Return(nil) a.SetCodec(argWriteInner) // write ok - err = a.Write(ctx, tProto) - test.Assert(t, err == nil) + err = a.Write(ctx, method, tProto) + test.Assert(t, err == nil, err) // read not ok - err = a.Read(ctx, method, tProto) + err = a.Read(ctx, method, 0, tProto) test.Assert(t, strings.Contains(err.Error(), "unexpected Args reader type")) // read ok a.SetCodec(rInner) - err = a.Read(ctx, method, tProto) + err = a.Read(ctx, method, 0, tProto) test.Assert(t, err == nil) // Result... @@ -79,20 +79,20 @@ func TestGenericService(t *testing.T) { test.Assert(t, ok == true) // write not ok - err = r.Write(ctx, tProto) + err = r.Write(ctx, method, tProto) test.Assert(t, err.Error() == "unexpected Result writer type: ") // Write expect resultWriteInner.EXPECT().Write(ctx, tProto, r.Success, (*gthrift.Base)(nil)).Return(nil).AnyTimes() r.SetCodec(resultWriteInner) // write ok - err = r.Write(ctx, tProto) + err = r.Write(ctx, method, tProto) test.Assert(t, err == nil) // read not ok - err = r.Read(ctx, method, tProto) + err = r.Read(ctx, method, 0, tProto) test.Assert(t, strings.Contains(err.Error(), "unexpected Result reader type")) // read ok r.SetCodec(rInner) - err = r.Read(ctx, method, tProto) + err = r.Read(ctx, method, 0, tProto) test.Assert(t, err == nil) r.SetSuccess(nil) @@ -124,6 +124,13 @@ func TestGenericService(t *testing.T) { func TestServiceInfo(t *testing.T) { s := ServiceInfo(serviceinfo.Thrift) test.Assert(t, s.ServiceName == "$GenericService") + + p, err := NewThriftFileProvider("./json_test/idl/mock.thrift") + test.Assert(t, err == nil) + g, err := JSONThriftGeneric(p) + test.Assert(t, err == nil) + s = ServiceInfoWithCodec(g) + test.Assert(t, s.ServiceName == "Mock") } func TestArgsResult(t *testing.T) { diff --git a/pkg/generic/generic_test.go b/pkg/generic/generic_test.go index 7ae95e07fb..816199906c 100644 --- a/pkg/generic/generic_test.go +++ b/pkg/generic/generic_test.go @@ -35,6 +35,7 @@ func TestBinaryThriftGeneric(t *testing.T) { test.Assert(t, g.Framed() == true) test.Assert(t, g.PayloadCodec().Name() == "RawThriftBinary") test.Assert(t, g.PayloadCodecType() == serviceinfo.Thrift) + test.Assert(t, g.MessageReaderWriter() == nil) method, err := g.GetMethod(nil, "Test") test.Assert(t, err == nil) @@ -53,7 +54,8 @@ func TestMapThriftGeneric(t *testing.T) { mg, ok := g.(*mapThriftGeneric) test.Assert(t, ok) - test.Assert(t, g.PayloadCodec().Name() == "MapThrift") + test.Assert(t, g.PayloadCodec() == nil) + test.Assert(t, g.IDLServiceName() == "Mock") err = SetBinaryWithBase64(g, true) test.Assert(t, err == nil) @@ -83,7 +85,8 @@ func TestMapThriftGenericForJSON(t *testing.T) { mg, ok := g.(*mapThriftGeneric) test.Assert(t, ok) - test.Assert(t, g.PayloadCodec().Name() == "MapThrift") + test.Assert(t, g.PayloadCodec() == nil) + test.Assert(t, g.IDLServiceName() == "Mock") err = SetBinaryWithBase64(g, true) test.Assert(t, err == nil) @@ -110,7 +113,8 @@ func TestHTTPThriftGeneric(t *testing.T) { hg, ok := g.(*httpThriftGeneric) test.Assert(t, ok) - test.Assert(t, g.PayloadCodec().Name() == "HttpThrift") + test.Assert(t, g.PayloadCodec() == nil) + test.Assert(t, g.IDLServiceName() == "ExampleService") test.Assert(t, !hg.codec.dynamicgoEnabled) test.Assert(t, hg.codec.useRawBodyForHTTPResp) @@ -156,7 +160,8 @@ func TestHTTPThriftGenericWithDynamicGo(t *testing.T) { hg, ok := g.(*httpThriftGeneric) test.Assert(t, ok) - test.Assert(t, g.PayloadCodec().Name() == "HttpThrift") + test.Assert(t, g.PayloadCodec() == nil) + test.Assert(t, g.IDLServiceName() == "ExampleService") test.Assert(t, hg.codec.dynamicgoEnabled) test.Assert(t, !hg.codec.useRawBodyForHTTPResp) @@ -202,7 +207,8 @@ func TestJSONThriftGeneric(t *testing.T) { jg, ok := g.(*jsonThriftGeneric) test.Assert(t, ok) - test.Assert(t, g.PayloadCodec().Name() == "JSONThrift") + test.Assert(t, g.PayloadCodec() == nil) + test.Assert(t, g.IDLServiceName() == "Mock") test.Assert(t, !jg.codec.dynamicgoEnabled) test.Assert(t, jg.codec.binaryWithBase64) @@ -235,7 +241,8 @@ func TestJSONThriftGenericWithDynamicGo(t *testing.T) { jg, ok := g.(*jsonThriftGeneric) test.Assert(t, ok) - test.Assert(t, g.PayloadCodec().Name() == "JSONThrift") + test.Assert(t, g.PayloadCodec() == nil) + test.Assert(t, g.IDLServiceName() == "Mock") test.Assert(t, jg.codec.dynamicgoEnabled) test.Assert(t, jg.codec.binaryWithBase64) @@ -271,7 +278,8 @@ func TestJSONPbGeneric(t *testing.T) { test.Assert(t, err == nil) defer g.Close() - test.Assert(t, g.PayloadCodec().Name() == "JSONPb") + test.Assert(t, g.PayloadCodec() == nil) + test.Assert(t, g.IDLServiceName() == "Echo") test.Assert(t, g.PayloadCodecType() == serviceinfo.Protobuf) diff --git a/pkg/generic/httppbthrift_codec.go b/pkg/generic/httppbthrift_codec.go index 978acc3ffa..cde66eec43 100644 --- a/pkg/generic/httppbthrift_codec.go +++ b/pkg/generic/httppbthrift_codec.go @@ -19,7 +19,6 @@ package generic import ( "context" "errors" - "fmt" "io" "io/ioutil" "net/http" @@ -36,22 +35,24 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) +var _ Closer = &httpPbThriftCodec{} + type httpPbThriftCodec struct { svcDsc atomic.Value // *idl pbSvcDsc atomic.Value // *pbIdl provider DescriptorProvider pbProvider PbDescriptorProvider - codec remote.PayloadCodec + svcName string } -func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider, codec remote.PayloadCodec) (*httpPbThriftCodec, error) { +func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider) *httpPbThriftCodec { svc := <-p.Provide() pbSvc := <-pbp.Provide() - c := &httpPbThriftCodec{codec: codec, provider: p, pbProvider: pbp} + c := &httpPbThriftCodec{provider: p, pbProvider: pbp, svcName: svc.Name} c.svcDsc.Store(svc) c.pbSvcDsc.Store(pbSvc) go c.update() - return c, nil + return c } func (c *httpPbThriftCodec) update() { @@ -66,6 +67,7 @@ func (c *httpPbThriftCodec) update() { return } + c.svcName = svc.Name c.svcDsc.Store(svc) c.pbSvcDsc.Store(pbSvc) } @@ -87,37 +89,17 @@ func (c *httpPbThriftCodec) getMethod(req interface{}) (*Method, error) { return &Method{function.Name, function.Oneway}, nil } -func (c *httpPbThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { +func (c *httpPbThriftCodec) getMessageReaderWriter() interface{} { svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { - return fmt.Errorf("get parser ServiceDescriptor failed") + return errors.New("get parser ServiceDescriptor failed") } pbSvcDsc, ok := c.pbSvcDsc.Load().(*desc.ServiceDescriptor) if !ok { - return fmt.Errorf("get parser PbServiceDescriptor failed") + return errors.New("get parser PbServiceDescriptor failed") } - inner := thrift.NewWriteHTTPPbRequest(svcDsc, pbSvcDsc) - msg.Data().(WithCodec).SetCodec(inner) - return c.codec.Marshal(ctx, msg, out) -} - -func (c *httpPbThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { - if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { - return err - } - svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) - if !ok { - return fmt.Errorf("get parser ServiceDescriptor failed") - } - pbSvcDsc, ok := c.pbSvcDsc.Load().(proto.ServiceDescriptor) - if !ok { - return fmt.Errorf("get parser PbServiceDescriptor failed") - } - - inner := thrift.NewReadHTTPPbResponse(svcDsc, pbSvcDsc) - msg.Data().(WithCodec).SetCodec(inner) - return c.codec.Unmarshal(ctx, msg, in) + return thrift.NewHTTPPbReaderWriter(svcDsc, pbSvcDsc) } func (c *httpPbThriftCodec) Name() string { @@ -140,6 +122,41 @@ func (c *httpPbThriftCodec) Close() error { } } +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +func (c *httpPbThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { + svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) + if !ok { + return errors.New("get parser ServiceDescriptor failed") + } + pbSvcDsc, ok := c.pbSvcDsc.Load().(*desc.ServiceDescriptor) + if !ok { + return errors.New("get parser PbServiceDescriptor failed") + } + + inner := thrift.NewWriteHTTPPbRequest(svcDsc, pbSvcDsc) + msg.Data().(WithCodec).SetCodec(inner) + return thriftCodec.Marshal(ctx, msg, out) +} + +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +func (c *httpPbThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { + return err + } + svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) + if !ok { + return errors.New("get parser ServiceDescriptor failed") + } + pbSvcDsc, ok := c.pbSvcDsc.Load().(proto.ServiceDescriptor) + if !ok { + return errors.New("get parser PbServiceDescriptor failed") + } + + inner := thrift.NewReadHTTPPbResponse(svcDsc, pbSvcDsc) + msg.Data().(WithCodec).SetCodec(inner) + return thriftCodec.Unmarshal(ctx, msg, in) +} + // FromHTTPPbRequest parse HTTPRequest from http.Request func FromHTTPPbRequest(req *http.Request) (*HTTPRequest, error) { customReq := &HTTPRequest{ diff --git a/pkg/generic/httppbthrift_codec_test.go b/pkg/generic/httppbthrift_codec_test.go index 7a36cdd594..7871fac1cf 100644 --- a/pkg/generic/httppbthrift_codec_test.go +++ b/pkg/generic/httppbthrift_codec_test.go @@ -18,11 +18,14 @@ package generic import ( "bytes" + "io" "net/http" + "os" "reflect" "testing" "github.com/cloudwego/kitex/internal/test" + gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" ) func TestFromHTTPPbRequest(t *testing.T) { @@ -34,3 +37,39 @@ func TestFromHTTPPbRequest(t *testing.T) { test.Assert(t, hreq.GetMethod() == "POST") test.Assert(t, hreq.GetPath() == "/far/boo") } + +func TestHTTPPbThriftCodec(t *testing.T) { + p, err := NewThriftFileProvider("./httppb_test/idl/echo.thrift") + test.Assert(t, err == nil) + pbIdl := "./httppb_test/idl/echo.proto" + pbf, err := os.Open(pbIdl) + test.Assert(t, err == nil) + pbContent, err := io.ReadAll(pbf) + test.Assert(t, err == nil) + pbf.Close() + pbp, err := NewPbContentProvider(pbIdl, map[string]string{pbIdl: string(pbContent)}) + test.Assert(t, err == nil) + + htc := newHTTPPbThriftCodec(p, pbp) + defer htc.Close() + test.Assert(t, htc.Name() == "HttpPbThrift") + + req, err := http.NewRequest("GET", "/Echo", bytes.NewBuffer([]byte("321"))) + test.Assert(t, err == nil) + hreq, err := FromHTTPPbRequest(req) + test.Assert(t, err == nil) + // wrong + method, err := htc.getMethod("test") + test.Assert(t, err.Error() == "req is invalid, need descriptor.HTTPRequest" && method == nil) + // right + method, err = htc.getMethod(hreq) + test.Assert(t, err == nil, err) + test.Assert(t, method.Name == "Echo") + test.Assert(t, htc.svcName == "ExampleService") + + rw := htc.getMessageReaderWriter() + _, ok := rw.(gthrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(gthrift.MessageReader) + test.Assert(t, ok) +} diff --git a/pkg/generic/httpthrift_codec.go b/pkg/generic/httpthrift_codec.go index e932ac99f8..8e40a7ba8e 100644 --- a/pkg/generic/httpthrift_codec.go +++ b/pkg/generic/httpthrift_codec.go @@ -19,7 +19,6 @@ package generic import ( "context" "errors" - "fmt" "io" "io/ioutil" "net/http" @@ -34,10 +33,7 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) -var ( - _ remote.PayloadCodec = &httpThriftCodec{} - _ Closer = &httpThriftCodec{} -) +var _ Closer = &httpThriftCodec{} // HTTPRequest alias of descriptor HTTPRequest type HTTPRequest = descriptor.HTTPRequest @@ -48,17 +44,17 @@ type HTTPResponse = descriptor.HTTPResponse type httpThriftCodec struct { svcDsc atomic.Value // *idl provider DescriptorProvider - codec remote.PayloadCodec binaryWithBase64 bool convOpts conv.Options // used for dynamicgo conversion convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on dynamicgoEnabled bool useRawBodyForHTTPResp bool + svcName string } -func newHTTPThriftCodec(p DescriptorProvider, codec remote.PayloadCodec, opts *Options) (*httpThriftCodec, error) { +func newHTTPThriftCodec(p DescriptorProvider, opts *Options) *httpThriftCodec { svc := <-p.Provide() - c := &httpThriftCodec{codec: codec, provider: p, binaryWithBase64: false, dynamicgoEnabled: false, useRawBodyForHTTPResp: opts.useRawBodyForHTTPResp} + c := &httpThriftCodec{provider: p, binaryWithBase64: false, dynamicgoEnabled: false, useRawBodyForHTTPResp: opts.useRawBodyForHTTPResp, svcName: svc.Name} if dp, ok := p.(GetProviderOption); ok && dp.Option().DynamicGoEnabled { c.dynamicgoEnabled = true @@ -70,7 +66,7 @@ func newHTTPThriftCodec(p DescriptorProvider, codec remote.PayloadCodec, opts *O } c.svcDsc.Store(svc) go c.update() - return c, nil + return c } func (c *httpThriftCodec) update() { @@ -79,10 +75,37 @@ func (c *httpThriftCodec) update() { if !ok { return } + c.svcName = svc.Name c.svcDsc.Store(svc) } } +func (c *httpThriftCodec) getMessageReaderWriter() interface{} { + svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) + if !ok { + return errors.New("get parser ServiceDescriptor failed") + } + rw := thrift.NewHTTPReaderWriter(svcDsc) + c.configureHTTPRequestWriter(rw.WriteHTTPRequest) + c.configureHTTPResponseReader(rw.ReadHTTPResponse) + return rw +} + +func (c *httpThriftCodec) configureHTTPRequestWriter(writer *thrift.WriteHTTPRequest) { + writer.SetBinaryWithBase64(c.binaryWithBase64) + if c.dynamicgoEnabled { + writer.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase) + } +} + +func (c *httpThriftCodec) configureHTTPResponseReader(reader *thrift.ReadHTTPResponse) { + reader.SetBase64Binary(c.binaryWithBase64) + reader.SetUseRawBodyForHTTPResp(c.useRawBodyForHTTPResp) + if c.dynamicgoEnabled && c.useRawBodyForHTTPResp { + reader.SetDynamicGo(&c.convOptsWithThriftBase) + } +} + func (c *httpThriftCodec) getMethod(req interface{}) (*Method, error) { svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { @@ -99,50 +122,50 @@ func (c *httpThriftCodec) getMethod(req interface{}) (*Method, error) { return &Method{function.Name, function.Oneway}, nil } +func (c *httpThriftCodec) Name() string { + return "HttpThrift" +} + +func (c *httpThriftCodec) Close() error { + return c.provider.Close() +} + +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter func (c *httpThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { - return fmt.Errorf("get parser ServiceDescriptor failed") + return errors.New("get parser ServiceDescriptor failed") } inner := thrift.NewWriteHTTPRequest(svcDsc) inner.SetBinaryWithBase64(c.binaryWithBase64) if c.dynamicgoEnabled { - if err := inner.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase, msg.RPCInfo().Invocation().MethodName()); err != nil { - return err - } + inner.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase) } msg.Data().(WithCodec).SetCodec(inner) - return c.codec.Marshal(ctx, msg, out) + return thriftCodec.Marshal(ctx, msg, out) } +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter func (c *httpThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err } svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { - return fmt.Errorf("get parser ServiceDescriptor failed") + return errors.New("get parser ServiceDescriptor failed") } inner := thrift.NewReadHTTPResponse(svcDsc) inner.SetBase64Binary(c.binaryWithBase64) inner.SetUseRawBodyForHTTPResp(c.useRawBodyForHTTPResp) if c.dynamicgoEnabled && c.useRawBodyForHTTPResp && msg.PayloadLen() != 0 { - inner.SetDynamicGo(&c.convOpts, msg) + inner.SetDynamicGo(&c.convOpts) } msg.Data().(WithCodec).SetCodec(inner) - return c.codec.Unmarshal(ctx, msg, in) -} - -func (c *httpThriftCodec) Name() string { - return "HttpThrift" -} - -func (c *httpThriftCodec) Close() error { - return c.provider.Close() + return thriftCodec.Unmarshal(ctx, msg, in) } // FromHTTPRequest parse HTTPRequest from http.Request diff --git a/pkg/generic/httpthrift_codec_test.go b/pkg/generic/httpthrift_codec_test.go index 6c603fd23c..9186478224 100644 --- a/pkg/generic/httpthrift_codec_test.go +++ b/pkg/generic/httpthrift_codec_test.go @@ -18,20 +18,14 @@ package generic import ( "bytes" - "context" - "io/ioutil" "net/http" "testing" "github.com/bytedance/sonic" "github.com/cloudwego/dynamicgo/conv" - "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/transport" + gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" ) var customJson = sonic.Config{ @@ -53,8 +47,7 @@ func TestHttpThriftCodec(t *testing.T) { p, err := NewThriftFileProvider("./http_test/idl/binary_echo.thrift") test.Assert(t, err == nil) gOpts := &Options{dynamicgoConvOpts: DefaultHTTPDynamicGoConvOpts} - htc, err := newHTTPThriftCodec(p, thriftCodec, gOpts) - test.Assert(t, err == nil) + htc := newHTTPThriftCodec(p, gOpts) test.Assert(t, !htc.dynamicgoEnabled) test.Assert(t, !htc.useRawBodyForHTTPResp) test.DeepEqual(t, htc.convOpts, conv.Options{}) @@ -69,23 +62,17 @@ func TestHttpThriftCodec(t *testing.T) { // right method, err = htc.getMethod(req) test.Assert(t, err == nil && method.Name == "BinaryEcho") + test.Assert(t, htc.svcName == "ExampleService") - ctx := context.Background() - sendMsg := initHttpSendMsg() - - // Marshal side - out := remote.NewWriterBuffer(256) - err = htc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil) + rw := htc.getMessageReaderWriter() + _, ok := rw.(error) + test.Assert(t, !ok) - // Unmarshal side - recvMsg := initHttpRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = htc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil) + rw = htc.getMessageReaderWriter() + _, ok = rw.(gthrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(gthrift.MessageReader) + test.Assert(t, ok) } func TestHttpThriftCodecWithDynamicGo(t *testing.T) { @@ -93,8 +80,7 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) { p, err := NewThriftFileProviderWithDynamicGo("./http_test/idl/binary_echo.thrift") test.Assert(t, err == nil) gOpts := &Options{dynamicgoConvOpts: DefaultHTTPDynamicGoConvOpts, useRawBodyForHTTPResp: true} - htc, err := newHTTPThriftCodec(p, thriftCodec, gOpts) - test.Assert(t, err == nil) + htc := newHTTPThriftCodec(p, gOpts) test.Assert(t, htc.dynamicgoEnabled) test.Assert(t, htc.useRawBodyForHTTPResp) test.DeepEqual(t, htc.convOpts, DefaultHTTPDynamicGoConvOpts) @@ -111,59 +97,13 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) { // right method, err = htc.getMethod(req) test.Assert(t, err == nil && method.Name == "BinaryEcho") + test.Assert(t, htc.svcName == "ExampleService") - ctx := context.Background() - sendMsg := initHttpSendMsg() - - // Marshal side - out := remote.NewWriterBuffer(256) - err = htc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil) - - // Unmarshal side - recvMsg := initHttpRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = htc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil) -} - -func initHttpSendMsg() remote.Message { - stdReq := getStdHttpRequest() - b, err := stdReq.GetBody() - if err != nil { - panic(err) - } - rawBody, err := ioutil.ReadAll(b) - if err != nil { - panic(err) - } - req := &Args{ - Request: &descriptor.HTTPRequest{ - Request: stdReq, - RawBody: rawBody, - }, - Method: "BinaryEcho", - } - svcInfo := mocks.ServiceInfo() - ink := rpcinfo.NewInvocation("", "BinaryEcho") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client) - msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, svcInfo.PayloadCodec)) - return msg -} - -func initHttpRecvMsg() remote.Message { - req := &Args{ - Request: "Test", - Method: "BinaryEcho", - } - ink := rpcinfo.NewInvocation("", "BinaryEcho") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server) - return msg + rw := htc.getMessageReaderWriter() + _, ok := rw.(gthrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(gthrift.MessageReader) + test.Assert(t, ok) } func getStdHttpRequest() *http.Request { diff --git a/pkg/generic/jsonpb_codec.go b/pkg/generic/jsonpb_codec.go index 44fed0db4f..4914ebba03 100644 --- a/pkg/generic/jsonpb_codec.go +++ b/pkg/generic/jsonpb_codec.go @@ -18,6 +18,7 @@ package generic import ( "context" + "errors" "fmt" "sync/atomic" @@ -31,29 +32,26 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) -var ( - _ remote.PayloadCodec = &jsonPbCodec{} - _ Closer = &jsonPbCodec{} -) +var _ Closer = &jsonPbCodec{} type jsonPbCodec struct { svcDsc atomic.Value // *idl provider PbDescriptorProviderDynamicGo - codec remote.PayloadCodec opts *Options convOpts conv.Options // used for dynamicgo conversion dynamicgoEnabled bool // currently set to true by default + svcName string } -func newJsonPbCodec(p PbDescriptorProviderDynamicGo, codec remote.PayloadCodec, opts *Options) (*jsonPbCodec, error) { +func newJsonPbCodec(p PbDescriptorProviderDynamicGo, opts *Options) *jsonPbCodec { svc := <-p.Provide() - c := &jsonPbCodec{codec: codec, provider: p, opts: opts, dynamicgoEnabled: true} + c := &jsonPbCodec{provider: p, opts: opts, dynamicgoEnabled: true, svcName: svc.Name()} convOpts := opts.dynamicgoConvOpts c.convOpts = convOpts c.svcDsc.Store(svc) go c.update() - return c, nil + return c } func (c *jsonPbCodec) update() { @@ -62,31 +60,57 @@ func (c *jsonPbCodec) update() { if !ok { return } + c.svcName = svc.Name() c.svcDsc.Store(svc) } } +func (c *jsonPbCodec) getMessageReaderWriter() interface{} { + pbSvc, ok := c.svcDsc.Load().(*dproto.ServiceDescriptor) + if !ok { + return errors.New("get parser dynamicgo ServiceDescriptor failed") + } + + return proto.NewJsonReaderWriter(pbSvc, &c.convOpts) +} + +func (c *jsonPbCodec) getMethod(req interface{}, method string) (*Method, error) { + fnSvc := c.svcDsc.Load().(*dproto.ServiceDescriptor).LookupMethodByName(method) + if fnSvc == nil { + return nil, fmt.Errorf("missing method: %s in service", method) + } + + // protobuf does not have oneway methods + return &Method{method, false}, nil +} + +func (c *jsonPbCodec) Name() string { + return "JSONPb" +} + +func (c *jsonPbCodec) Close() error { + return c.provider.Close() +} + +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter func (c *jsonPbCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { method := msg.RPCInfo().Invocation().MethodName() if method == "" { return perrors.NewProtocolErrorWithMsg("empty methodName in protobuf Marshal") } if msg.MessageType() == remote.Exception { - return c.codec.Marshal(ctx, msg, out) + return pbCodec.Marshal(ctx, msg, out) } pbSvc := c.svcDsc.Load().(*dproto.ServiceDescriptor) - wm, err := proto.NewWriteJSON(pbSvc, method, msg.RPCRole() == remote.Client, &c.convOpts) - if err != nil { - return err - } - + wm := proto.NewWriteJSON(pbSvc, &c.convOpts) msg.Data().(WithCodec).SetCodec(wm) - return c.codec.Marshal(ctx, msg, out) + return pbCodec.Marshal(ctx, msg, out) } +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter func (c *jsonPbCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err @@ -94,30 +118,8 @@ func (c *jsonPbCodec) Unmarshal(ctx context.Context, msg remote.Message, in remo pbSvc := c.svcDsc.Load().(*dproto.ServiceDescriptor) - wm, err := proto.NewReadJSON(pbSvc, msg.RPCRole() == remote.Client, &c.convOpts) - if err != nil { - return err - } - + wm := proto.NewReadJSON(pbSvc, &c.convOpts) msg.Data().(WithCodec).SetCodec(wm) - return c.codec.Unmarshal(ctx, msg, in) -} - -func (c *jsonPbCodec) getMethod(req interface{}, method string) (*Method, error) { - fnSvc := c.svcDsc.Load().(*dproto.ServiceDescriptor).LookupMethodByName(method) - if fnSvc == nil { - return nil, fmt.Errorf("missing method: %s in service", method) - } - - // protobuf does not have oneway methods - return &Method{method, false}, nil -} - -func (c *jsonPbCodec) Name() string { - return "JSONPb" -} - -func (c *jsonPbCodec) Close() error { - return c.provider.Close() + return pbCodec.Unmarshal(ctx, msg, in) } diff --git a/pkg/generic/jsonpb_codec_test.go b/pkg/generic/jsonpb_codec_test.go index cc11e19c08..a1b823508d 100644 --- a/pkg/generic/jsonpb_codec_test.go +++ b/pkg/generic/jsonpb_codec_test.go @@ -23,11 +23,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" dproto "github.com/cloudwego/dynamicgo/proto" - "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/transport" + gproto "github.com/cloudwego/kitex/pkg/generic/proto" ) var echoIDLPath = "./jsonpb_test/idl/echo.proto" @@ -42,54 +39,18 @@ func TestJsonPbCodec(t *testing.T) { p, err := NewPbFileProviderWithDynamicGo(echoIDLPath, context.Background(), opts) test.Assert(t, err == nil) - jpc, err := newJsonPbCodec(p, pbCodec, gOpts) - test.Assert(t, err == nil) - + jpc := newJsonPbCodec(p, gOpts) defer jpc.Close() test.Assert(t, jpc.Name() == "JSONPb") method, err := jpc.getMethod(nil, "Echo") test.Assert(t, err == nil) test.Assert(t, method.Name == "Echo") + test.Assert(t, jpc.svcName == "Echo") - ctx := context.Background() - sendMsg := initJsonPbSendMsg(transport.TTHeaderFramed) - - // Marshal side - out := remote.NewWriterBuffer(256) - err = jpc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil) - - // UnMarshal side - recvMsg := initJsonPbRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = jpc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil) - args, ok := recvMsg.Data().(*Args) + rw := jpc.getMessageReaderWriter() + _, ok := rw.(gproto.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(gproto.MessageReader) test.Assert(t, ok) - test.Assert(t, args.Request == `{"message":"hello world!"}`) -} - -func initJsonPbSendMsg(tp transport.Protocol) remote.Message { - req := &Args{ - Request: `{"message":"hello world!"}`, - Method: "Echo", - } - svcInfo := mocks.ServiceInfo() - ink := rpcinfo.NewInvocation("", "Echo") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client) - msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) - return msg -} - -func initJsonPbRecvMsg() remote.Message { - resp := &Args{} - ink := rpcinfo.NewInvocation("", "Echo") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(resp, mocks.ServiceInfo(), ri, remote.Call, remote.Server) - return msg } diff --git a/pkg/generic/jsonthrift_codec.go b/pkg/generic/jsonthrift_codec.go index 214cd0d0a6..d418e9844c 100644 --- a/pkg/generic/jsonthrift_codec.go +++ b/pkg/generic/jsonthrift_codec.go @@ -18,6 +18,7 @@ package generic import ( "context" + "errors" "sync/atomic" "github.com/cloudwego/dynamicgo/conv" @@ -30,29 +31,22 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) -var ( - _ remote.PayloadCodec = &jsonThriftCodec{} - _ Closer = &jsonThriftCodec{} -) - -// JSONRequest alias of string -type JSONRequest = string +var _ Closer = &jsonThriftCodec{} type jsonThriftCodec struct { svcDsc atomic.Value // *idl provider DescriptorProvider - codec remote.PayloadCodec binaryWithBase64 bool - opts *Options + dynamicgoEnabled bool convOpts conv.Options // used for dynamicgo conversion convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on convOptsWithException conv.Options // used for dynamicgo conversion with ConvertException turned on - dynamicgoEnabled bool + svcName string } -func newJsonThriftCodec(p DescriptorProvider, codec remote.PayloadCodec, opts *Options) (*jsonThriftCodec, error) { +func newJsonThriftCodec(p DescriptorProvider, opts *Options) *jsonThriftCodec { svc := <-p.Provide() - c := &jsonThriftCodec{codec: codec, provider: p, binaryWithBase64: true, opts: opts, dynamicgoEnabled: false} + c := &jsonThriftCodec{provider: p, binaryWithBase64: true, dynamicgoEnabled: false, svcName: svc.Name} if dp, ok := p.(GetProviderOption); ok && dp.Option().DynamicGoEnabled { c.dynamicgoEnabled = true @@ -69,7 +63,7 @@ func newJsonThriftCodec(p DescriptorProvider, codec remote.PayloadCodec, opts *O } c.svcDsc.Store(svc) go c.update() - return c, nil + return c } func (c *jsonThriftCodec) update() { @@ -78,38 +72,78 @@ func (c *jsonThriftCodec) update() { if !ok { return } + c.svcName = svc.Name c.svcDsc.Store(svc) } } +func (c *jsonThriftCodec) getMessageReaderWriter() interface{} { + svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) + if !ok { + return errors.New("get parser ServiceDescriptor failed") + } + + rw := thrift.NewJsonReaderWriter(svcDsc) + c.configureJSONWriter(rw.WriteJSON) + c.configureJSONReader(rw.ReadJSON) + return rw +} + +func (c *jsonThriftCodec) configureJSONWriter(writer *thrift.WriteJSON) { + writer.SetBase64Binary(c.binaryWithBase64) + if c.dynamicgoEnabled { + writer.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase) + } +} + +func (c *jsonThriftCodec) configureJSONReader(reader *thrift.ReadJSON) { + reader.SetBinaryWithBase64(c.binaryWithBase64) + if c.dynamicgoEnabled { + reader.SetDynamicGo(&c.convOpts, &c.convOptsWithException) + } +} + +func (c *jsonThriftCodec) getMethod(req interface{}, method string) (*Method, error) { + fnSvc, err := c.svcDsc.Load().(*descriptor.ServiceDescriptor).LookupFunctionByMethod(method) + if err != nil { + return nil, err + } + return &Method{method, fnSvc.Oneway}, nil +} + +func (c *jsonThriftCodec) Name() string { + return "JSONThrift" +} + +func (c *jsonThriftCodec) Close() error { + return c.provider.Close() +} + +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter func (c *jsonThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { method := msg.RPCInfo().Invocation().MethodName() if method == "" { return perrors.NewProtocolErrorWithMsg("empty methodName in thrift Marshal") } if msg.MessageType() == remote.Exception { - return c.codec.Marshal(ctx, msg, out) + return thriftCodec.Marshal(ctx, msg, out) } svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { return perrors.NewProtocolErrorWithMsg("get parser ServiceDescriptor failed") } - wm, err := thrift.NewWriteJSON(svcDsc, method, msg.RPCRole() == remote.Client) - if err != nil { - return err - } + wm := thrift.NewWriteJSON(svcDsc) wm.SetBase64Binary(c.binaryWithBase64) if c.dynamicgoEnabled { - if err = wm.SetDynamicGo(svcDsc, method, &c.convOpts, &c.convOptsWithThriftBase); err != nil { - return err - } + wm.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase) } msg.Data().(WithCodec).SetCodec(wm) - return c.codec.Marshal(ctx, msg, out) + return thriftCodec.Marshal(ctx, msg, out) } +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter func (c *jsonThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err @@ -119,29 +153,13 @@ func (c *jsonThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in return perrors.NewProtocolErrorWithMsg("get parser ServiceDescriptor failed") } - rm := thrift.NewReadJSON(svcDsc, msg.RPCRole() == remote.Client) + rm := thrift.NewReadJSON(svcDsc) rm.SetBinaryWithBase64(c.binaryWithBase64) // Transport protocol should be TTHeader, Framed, or TTHeaderFramed to enable dynamicgo if c.dynamicgoEnabled && msg.PayloadLen() != 0 { - rm.SetDynamicGo(&c.convOpts, &c.convOptsWithException, msg) + rm.SetDynamicGo(&c.convOpts, &c.convOptsWithException) } msg.Data().(WithCodec).SetCodec(rm) - return c.codec.Unmarshal(ctx, msg, in) -} - -func (c *jsonThriftCodec) getMethod(req interface{}, method string) (*Method, error) { - fnSvc, err := c.svcDsc.Load().(*descriptor.ServiceDescriptor).LookupFunctionByMethod(method) - if err != nil { - return nil, err - } - return &Method{method, fnSvc.Oneway}, nil -} - -func (c *jsonThriftCodec) Name() string { - return "JSONThrift" -} - -func (c *jsonThriftCodec) Close() error { - return c.provider.Close() + return thriftCodec.Unmarshal(ctx, msg, in) } diff --git a/pkg/generic/jsonthrift_codec_test.go b/pkg/generic/jsonthrift_codec_test.go index 8aa3375681..ecdc351889 100644 --- a/pkg/generic/jsonthrift_codec_test.go +++ b/pkg/generic/jsonthrift_codec_test.go @@ -17,25 +17,19 @@ package generic import ( - "context" "testing" "github.com/cloudwego/dynamicgo/conv" - "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/transport" + "github.com/cloudwego/kitex/pkg/generic/thrift" ) func TestJsonThriftCodec(t *testing.T) { // without dynamicgo p, err := NewThriftFileProvider("./json_test/idl/mock.thrift") test.Assert(t, err == nil) - gOpts := &Options{dynamicgoConvOpts: DefaultJSONDynamicGoConvOpts} - jtc, err := newJsonThriftCodec(p, thriftCodec, gOpts) - test.Assert(t, err == nil) + jtc := newJsonThriftCodec(p, nil) test.Assert(t, !jtc.dynamicgoEnabled) test.DeepEqual(t, jtc.convOpts, conv.Options{}) test.DeepEqual(t, jtc.convOptsWithThriftBase, conv.Options{}) @@ -46,87 +40,13 @@ func TestJsonThriftCodec(t *testing.T) { method, err := jtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, jtc.svcName == "Mock") - ctx := context.Background() - sendMsg := initJsonSendMsg(transport.TTHeader) - - // Marshal side - out := remote.NewWriterBuffer(256) - err = jtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil, err) - - // Unmarshal side - recvMsg := initJsonRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = jtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil, err) -} - -func TestJsonThriftCodec_SelfRef_Old(t *testing.T) { - t.Run("old way", func(t *testing.T) { - p, err := NewThriftFileProvider("./json_test/idl/mock.thrift") - test.Assert(t, err == nil) - gOpts := &Options{dynamicgoConvOpts: DefaultJSONDynamicGoConvOpts} - jtc, err := newJsonThriftCodec(p, thriftCodec, gOpts) - test.Assert(t, err == nil) - defer jtc.Close() - test.Assert(t, jtc.Name() == "JSONThrift") - - method, err := jtc.getMethod(nil, "Test") - test.Assert(t, err == nil) - test.Assert(t, method.Name == "Test") - - ctx := context.Background() - sendMsg := initJsonSendMsg(transport.TTHeader) - - // Marshal side - out := remote.NewWriterBuffer(256) - err = jtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil, err) - - // Unmarshal side - recvMsg := initJsonRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = jtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil, err) - }) - - t.Run("old way", func(t *testing.T) { - p, err := NewThriftFileProviderWithDynamicGo("./json_test/idl/mock.thrift") - test.Assert(t, err == nil) - gOpts := &Options{dynamicgoConvOpts: DefaultJSONDynamicGoConvOpts} - jtc, err := newJsonThriftCodec(p, thriftCodec, gOpts) - test.Assert(t, err == nil) - defer jtc.Close() - test.Assert(t, jtc.Name() == "JSONThrift") - - method, err := jtc.getMethod(nil, "Test") - test.Assert(t, err == nil) - test.Assert(t, method.Name == "Test") - - ctx := context.Background() - sendMsg := initJsonSendMsg(transport.TTHeader) - - // Marshal side - out := remote.NewWriterBuffer(256) - err = jtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil, err) - - // Unmarshal side - recvMsg := initJsonRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = jtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil, err) - }) + rw := jtc.getMessageReaderWriter() + _, ok := rw.(thrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(thrift.MessageReader) + test.Assert(t, ok) } func TestJsonThriftCodecWithDynamicGo(t *testing.T) { @@ -134,8 +54,7 @@ func TestJsonThriftCodecWithDynamicGo(t *testing.T) { p, err := NewThriftFileProviderWithDynamicGo("./json_test/idl/mock.thrift") test.Assert(t, err == nil) gOpts := &Options{dynamicgoConvOpts: DefaultJSONDynamicGoConvOpts} - jtc, err := newJsonThriftCodec(p, thriftCodec, gOpts) - test.Assert(t, err == nil) + jtc := newJsonThriftCodec(p, gOpts) test.Assert(t, jtc.dynamicgoEnabled) test.DeepEqual(t, jtc.convOpts, DefaultJSONDynamicGoConvOpts) convOptsWithThriftBase := DefaultJSONDynamicGoConvOpts @@ -151,127 +70,49 @@ func TestJsonThriftCodecWithDynamicGo(t *testing.T) { test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") - ctx := context.Background() - sendMsg := initJsonSendMsg(transport.TTHeader) - - // Marshal side - out := remote.NewWriterBuffer(256) - err = jtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil, err) - - // Unmarshal side - recvMsg := initJsonRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = jtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil, err) - - // disable unmarshal With dynamicgo because the payload length is 0 - sendMsg = initJsonSendMsg(transport.PurePayload) - - // Marshal side - out = remote.NewWriterBuffer(256) - err = jtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil) - - // Unmarshal side - recvMsg = initJsonRecvMsg() - buf, err = out.Bytes() - test.Assert(t, err == nil) - in = remote.NewReaderBuffer(buf) - err = jtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil) + rw := jtc.getMessageReaderWriter() + _, ok := rw.(thrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(thrift.MessageReader) + test.Assert(t, ok) } func TestJsonThriftCodec_SelfRef(t *testing.T) { - p, err := NewThriftFileProvider("./json_test/idl/mock.thrift") - test.Assert(t, err == nil) - jtc, err := newJsonThriftCodec(p, thriftCodec, nil) - test.Assert(t, err == nil) - defer jtc.Close() - test.Assert(t, jtc.Name() == "JSONThrift") - - method, err := jtc.getMethod(nil, "Test") - test.Assert(t, err == nil) - test.Assert(t, method.Name == "Test") - - ctx := context.Background() - sendMsg := initJsonSendMsg(transport.TTHeader) - - // Marshal side - out := remote.NewWriterBuffer(256) - err = jtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil, err) - - // Unmarshal side - recvMsg := initJsonRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = jtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil, err) -} - -func TestJsonExceptionError(t *testing.T) { - // without dynamicgo - p, err := NewThriftFileProvider("./json_test/idl/mock.thrift") - test.Assert(t, err == nil) - gOpts := &Options{dynamicgoConvOpts: DefaultJSONDynamicGoConvOpts} - jtc, err := newJsonThriftCodec(p, thriftCodec, gOpts) - test.Assert(t, err == nil) + t.Run("without_dynamicgo", func(t *testing.T) { + p, err := NewThriftFileProvider("./json_test/idl/self_ref.thrift") + test.Assert(t, err == nil) + jtc := newJsonThriftCodec(p, nil) + defer jtc.Close() + test.Assert(t, jtc.Name() == "JSONThrift") - ctx := context.Background() - out := remote.NewWriterBuffer(256) - // empty method test - emptyMethodInk := rpcinfo.NewInvocation("", "") - emptyMethodRi := rpcinfo.NewRPCInfo(nil, nil, emptyMethodInk, nil, nil) - emptyMethodMsg := remote.NewMessage(nil, nil, emptyMethodRi, remote.Exception, remote.Client) - err = jtc.Marshal(ctx, emptyMethodMsg, out) - test.Assert(t, err.Error() == "empty methodName in thrift Marshal") + method, err := jtc.getMethod(nil, "Test") + test.Assert(t, err == nil) + test.Assert(t, method.Name == "Test") - // Exception MsgType test - exceptionMsgTypeInk := rpcinfo.NewInvocation("", "Test") - exceptionMsgTypeRi := rpcinfo.NewRPCInfo(nil, nil, exceptionMsgTypeInk, nil, nil) - exceptionMsgTypeMsg := remote.NewMessage(&remote.TransError{}, nil, exceptionMsgTypeRi, remote.Exception, remote.Client) - err = jtc.Marshal(ctx, exceptionMsgTypeMsg, out) - test.Assert(t, err == nil) + rw := jtc.getMessageReaderWriter() + _, ok := rw.(thrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(thrift.MessageReader) + test.Assert(t, ok) + }) - // with dynamicgo - p, err = NewThriftFileProviderWithDynamicGo("./json_test/idl/mock.thrift") - test.Assert(t, err == nil) - jtc, err = newJsonThriftCodec(p, thriftCodec, gOpts) - test.Assert(t, err == nil) - // empty method test - err = jtc.Marshal(ctx, emptyMethodMsg, out) - test.Assert(t, err.Error() == "empty methodName in thrift Marshal") - // Exception MsgType test - err = jtc.Marshal(ctx, exceptionMsgTypeMsg, out) - test.Assert(t, err == nil) -} + t.Run("with_dynamicgo", func(t *testing.T) { + p, err := NewThriftFileProviderWithDynamicGo("./json_test/idl/self_ref.thrift") + test.Assert(t, err == nil) + gOpts := &Options{dynamicgoConvOpts: DefaultJSONDynamicGoConvOpts} + jtc := newJsonThriftCodec(p, gOpts) + defer jtc.Close() + test.Assert(t, jtc.Name() == "JSONThrift") -func initJsonSendMsg(tp transport.Protocol) remote.Message { - req := &Args{ - Request: `{"extra": "Hello"}`, - Method: "Test", - } - svcInfo := mocks.ServiceInfo() - ink := rpcinfo.NewInvocation("", "Test") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client) - msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) - return msg -} + method, err := jtc.getMethod(nil, "Test") + test.Assert(t, err == nil) + test.Assert(t, method.Name == "Test") + test.Assert(t, jtc.svcName == "Mock") -func initJsonRecvMsg() remote.Message { - req := &Args{ - Request: "Test", - Method: "Test", - } - ink := rpcinfo.NewInvocation("", "Test") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server) - return msg + rw := jtc.getMessageReaderWriter() + _, ok := rw.(thrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(thrift.MessageReader) + test.Assert(t, ok) + }) } diff --git a/pkg/generic/mapthrift_codec.go b/pkg/generic/mapthrift_codec.go index 92272febfe..953734349f 100644 --- a/pkg/generic/mapthrift_codec.go +++ b/pkg/generic/mapthrift_codec.go @@ -19,7 +19,6 @@ package generic import ( "context" "errors" - "fmt" "sync/atomic" "github.com/cloudwego/kitex/pkg/generic/descriptor" @@ -29,41 +28,35 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) -var ( - _ remote.PayloadCodec = &mapThriftCodec{} - _ Closer = &mapThriftCodec{} -) +var _ Closer = &mapThriftCodec{} type mapThriftCodec struct { svcDsc atomic.Value // *idl provider DescriptorProvider - codec remote.PayloadCodec forJSON bool binaryWithBase64 bool binaryWithByteSlice bool setFieldsForEmptyStruct uint8 + svcName string } -func newMapThriftCodec(p DescriptorProvider, codec remote.PayloadCodec) (*mapThriftCodec, error) { +func newMapThriftCodec(p DescriptorProvider) *mapThriftCodec { svc := <-p.Provide() c := &mapThriftCodec{ - codec: codec, provider: p, binaryWithBase64: false, binaryWithByteSlice: false, + svcName: svc.Name, } c.svcDsc.Store(svc) go c.update() - return c, nil + return c } -func newMapThriftCodecForJSON(p DescriptorProvider, codec remote.PayloadCodec) (*mapThriftCodec, error) { - c, err := newMapThriftCodec(p, codec) - if err != nil { - return nil, err - } +func newMapThriftCodecForJSON(p DescriptorProvider) *mapThriftCodec { + c := newMapThriftCodec(p) c.forJSON = true - return c, nil + return c } func (c *mapThriftCodec) update() { @@ -72,63 +65,88 @@ func (c *mapThriftCodec) update() { if !ok { return } + c.svcName = svc.Name c.svcDsc.Store(svc) } } +func (c *mapThriftCodec) getMessageReaderWriter() interface{} { + svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) + if !ok { + return errors.New("get parser ServiceDescriptor failed") + } + var rw *thrift.StructReaderWriter + if c.forJSON { + rw = thrift.NewStructReaderWriterForJSON(svcDsc) + } else { + rw = thrift.NewStructReaderWriter(svcDsc) + } + c.configureStructWriter(rw.WriteStruct) + c.configureStructReader(rw.ReadStruct) + return rw +} + +func (c *mapThriftCodec) configureStructWriter(writer *thrift.WriteStruct) { + writer.SetBinaryWithBase64(c.binaryWithBase64) +} + +func (c *mapThriftCodec) configureStructReader(reader *thrift.ReadStruct) { + reader.SetBinaryOption(c.binaryWithBase64, c.binaryWithByteSlice) + reader.SetSetFieldsForEmptyStruct(c.setFieldsForEmptyStruct) +} + +func (c *mapThriftCodec) getMethod(req interface{}, method string) (*Method, error) { + fnSvc, err := c.svcDsc.Load().(*descriptor.ServiceDescriptor).LookupFunctionByMethod(method) + if err != nil { + return nil, err + } + return &Method{method, fnSvc.Oneway}, nil +} + +func (c *mapThriftCodec) Name() string { + return "MapThrift" +} + +func (c *mapThriftCodec) Close() error { + return c.provider.Close() +} + +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter func (c *mapThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { method := msg.RPCInfo().Invocation().MethodName() if method == "" { return errors.New("empty methodName in thrift Marshal") } if msg.MessageType() == remote.Exception { - return c.codec.Marshal(ctx, msg, out) + return thriftCodec.Marshal(ctx, msg, out) } svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { - return fmt.Errorf("get parser ServiceDescriptor failed") - } - wm, err := thrift.NewWriteStruct(svcDsc, method, msg.RPCRole() == remote.Client) - if err != nil { - return err + return errors.New("get parser ServiceDescriptor failed") } + wm := thrift.NewWriteStruct(svcDsc) wm.SetBinaryWithBase64(c.binaryWithBase64) msg.Data().(WithCodec).SetCodec(wm) - return c.codec.Marshal(ctx, msg, out) + return thriftCodec.Marshal(ctx, msg, out) } +// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter func (c *mapThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err } svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { - return fmt.Errorf("get parser ServiceDescriptor failed") + return errors.New("get parser ServiceDescriptor failed") } var rm *thrift.ReadStruct if c.forJSON { - rm = thrift.NewReadStructForJSON(svcDsc, msg.RPCRole() == remote.Client) + rm = thrift.NewReadStructForJSON(svcDsc) } else { - rm = thrift.NewReadStruct(svcDsc, msg.RPCRole() == remote.Client) + rm = thrift.NewReadStruct(svcDsc) } rm.SetBinaryOption(c.binaryWithBase64, c.binaryWithByteSlice) rm.SetSetFieldsForEmptyStruct(c.setFieldsForEmptyStruct) msg.Data().(WithCodec).SetCodec(rm) - return c.codec.Unmarshal(ctx, msg, in) -} - -func (c *mapThriftCodec) getMethod(req interface{}, method string) (*Method, error) { - fnSvc, err := c.svcDsc.Load().(*descriptor.ServiceDescriptor).LookupFunctionByMethod(method) - if err != nil { - return nil, err - } - return &Method{method, fnSvc.Oneway}, nil -} - -func (c *mapThriftCodec) Name() string { - return "MapThrift" -} - -func (c *mapThriftCodec) Close() error { - return c.provider.Close() + return thriftCodec.Unmarshal(ctx, msg, in) } diff --git a/pkg/generic/mapthrift_codec_test.go b/pkg/generic/mapthrift_codec_test.go index 497bad52a6..4d9b7948e9 100644 --- a/pkg/generic/mapthrift_codec_test.go +++ b/pkg/generic/mapthrift_codec_test.go @@ -17,188 +17,65 @@ package generic import ( - "context" - "reflect" "testing" - "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/transport" + "github.com/cloudwego/kitex/pkg/generic/thrift" ) func TestMapThriftCodec(t *testing.T) { p, err := NewThriftFileProvider("./map_test/idl/mock.thrift") test.Assert(t, err == nil) - mtc, err := newMapThriftCodec(p, thriftCodec) - test.Assert(t, err == nil) + mtc := newMapThriftCodec(p) defer mtc.Close() test.Assert(t, mtc.Name() == "MapThrift") method, err := mtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, mtc.svcName == "Mock") - ctx := context.Background() - sendMsg := initMapSendMsg(transport.TTHeader) - - // Marshal side - out := remote.NewWriterBuffer(256) - err = mtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil) - - // UnMarshal side - recvMsg := initMapRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = mtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil) + rw := mtc.getMessageReaderWriter() + _, ok := rw.(thrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(thrift.MessageReader) + test.Assert(t, ok) } func TestMapThriftCodecSelfRef(t *testing.T) { p, err := NewThriftFileProvider("./map_test/idl/self_ref.thrift") test.Assert(t, err == nil) - mtc, err := newMapThriftCodec(p, thriftCodec) - test.Assert(t, err == nil) + mtc := newMapThriftCodec(p) defer mtc.Close() test.Assert(t, mtc.Name() == "MapThrift") method, err := mtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, mtc.svcName == "Mock") - ctx := context.Background() - sendMsg := initNilMapSendMsg(transport.TTHeader) - - // Marshal side - out := remote.NewWriterBuffer(0) - err = mtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil) - - // UnMarshal side - recvMsg := initMapRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = mtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil) - exp := map[string]interface{}{ - "self": map[string]interface{}{}, - "extra": "", - } - act := recvMsg.Data().(*Args).Request - test.Assert(t, reflect.DeepEqual(exp, act)) + rw := mtc.getMessageReaderWriter() + _, ok := rw.(thrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(thrift.MessageReader) + test.Assert(t, ok) } func TestMapThriftCodecForJSON(t *testing.T) { p, err := NewThriftFileProvider("./map_test/idl/mock.thrift") test.Assert(t, err == nil) - mtc, err := newMapThriftCodecForJSON(p, thriftCodec) - test.Assert(t, err == nil) + mtc := newMapThriftCodecForJSON(p) defer mtc.Close() test.Assert(t, mtc.Name() == "MapThrift") method, err := mtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, mtc.svcName == "Mock") - ctx := context.Background() - sendMsg := initMapSendMsg(transport.TTHeader) - - // Marshal side - out := remote.NewWriterBuffer(256) - err = mtc.Marshal(ctx, sendMsg, out) - test.Assert(t, err == nil) - - // UnMarshal side - recvMsg := initMapRecvMsg() - buf, err := out.Bytes() - test.Assert(t, err == nil) - recvMsg.SetPayloadLen(len(buf)) - in := remote.NewReaderBuffer(buf) - err = mtc.Unmarshal(ctx, recvMsg, in) - test.Assert(t, err == nil) - args, ok := recvMsg.Data().(*Args) + rw := mtc.getMessageReaderWriter() + _, ok := rw.(thrift.MessageWriter) test.Assert(t, ok) - fieldMap, ok := args.Request.(map[string]interface{}) + _, ok = rw.(thrift.MessageReader) test.Assert(t, ok) - _, ok = fieldMap["strMap"].(map[string]interface{}) - test.Assert(t, ok) -} - -func TestMapExceptionError(t *testing.T) { - p, err := NewThriftFileProvider("./map_test/idl/mock.thrift") - test.Assert(t, err == nil) - mtc, err := newMapThriftCodec(p, thriftCodec) - test.Assert(t, err == nil) - - ctx := context.Background() - out := remote.NewWriterBuffer(256) - // empty method test - emptyMethodInk := rpcinfo.NewInvocation("", "") - emptyMethodRi := rpcinfo.NewRPCInfo(nil, nil, emptyMethodInk, nil, nil) - emptyMethodMsg := remote.NewMessage(nil, nil, emptyMethodRi, remote.Exception, remote.Client) - // Marshal side - err = mtc.Marshal(ctx, emptyMethodMsg, out) - test.Assert(t, err.Error() == "empty methodName in thrift Marshal") - - // Exception MsgType test - exceptionMsgTypeInk := rpcinfo.NewInvocation("", "Test") - exceptionMsgTypeRi := rpcinfo.NewRPCInfo(nil, nil, exceptionMsgTypeInk, nil, nil) - exceptionMsgTypeMsg := remote.NewMessage(&remote.TransError{}, nil, exceptionMsgTypeRi, remote.Exception, remote.Client) - // Marshal side - err = mtc.Marshal(ctx, exceptionMsgTypeMsg, out) - test.Assert(t, err == nil) -} - -func initMapSendMsg(tp transport.Protocol) remote.Message { - req := &Args{ - Request: &descriptor.HTTPRequest{ - Body: map[string]interface{}{ - "strMap": map[string]interface{}{ - "Test1": "Test2", - }, - }, - }, - Method: "Test", - } - svcInfo := mocks.ServiceInfo() - ink := rpcinfo.NewInvocation("", "Test") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client) - msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) - return msg -} - -func initNilMapSendMsg(tp transport.Protocol) remote.Message { - req := &Args{ - Request: &descriptor.HTTPRequest{ - Body: map[string]interface{}{ - "self": nil, - }, - }, - Method: "Test", - } - svcInfo := mocks.ServiceInfo() - ink := rpcinfo.NewInvocation("", "Test") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client) - msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) - return msg -} - -func initMapRecvMsg() remote.Message { - req := &Args{ - Request: "Test", - Method: "Test", - } - ink := rpcinfo.NewInvocation("", "Test") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server) - return msg } diff --git a/pkg/generic/proto/json.go b/pkg/generic/proto/json.go index 3fecfc911c..738c788482 100644 --- a/pkg/generic/proto/json.go +++ b/pkg/generic/proto/json.go @@ -29,38 +29,33 @@ import ( "github.com/cloudwego/kitex/pkg/utils" ) -// NewWriteJSON build WriteJSON according to ServiceDescriptor -func NewWriteJSON(svc *dproto.ServiceDescriptor, method string, isClient bool, convOpts *conv.Options) (*WriteJSON, error) { - fnDsc := svc.LookupMethodByName(method) - if fnDsc == nil { - return nil, fmt.Errorf("missing method: %s in service: %s", method, svc.Name()) - } +type JSONReaderWriter struct { + *ReadJSON + *WriteJSON +} - // from the proto.ServiceDescriptor, get the TypeDescriptor - typeDescriptor := fnDsc.Input() - if !isClient { - typeDescriptor = fnDsc.Output() - } +func NewJsonReaderWriter(svc *dproto.ServiceDescriptor, convOpts *conv.Options) *JSONReaderWriter { + return &JSONReaderWriter{ReadJSON: NewReadJSON(svc, convOpts), WriteJSON: NewWriteJSON(svc, convOpts)} +} - ws := &WriteJSON{ +// NewWriteJSON build WriteJSON according to ServiceDescriptor +func NewWriteJSON(svc *dproto.ServiceDescriptor, convOpts *conv.Options) *WriteJSON { + return &WriteJSON{ + svcDsc: svc, dynamicgoConvOpts: convOpts, - dynamicgoTypeDsc: typeDescriptor, - isClient: isClient, } - return ws, nil } // WriteJSON implement of MessageWriter type WriteJSON struct { + svcDsc *dproto.ServiceDescriptor dynamicgoConvOpts *conv.Options - dynamicgoTypeDsc *dproto.TypeDescriptor - isClient bool } var _ MessageWriter = (*WriteJSON)(nil) // Write converts msg to protobuf wire format and returns an output bytebuffer -func (m *WriteJSON) Write(ctx context.Context, msg interface{}) (interface{}, error) { +func (m *WriteJSON) Write(ctx context.Context, msg interface{}, method string, isClient bool) (interface{}, error) { var s string if msg == nil { s = "{}" @@ -75,8 +70,19 @@ func (m *WriteJSON) Write(ctx context.Context, msg interface{}) (interface{}, er cv := dconvj2p.NewBinaryConv(*m.dynamicgoConvOpts) + fnDsc := m.svcDsc.LookupMethodByName(method) + if fnDsc == nil { + return nil, fmt.Errorf("missing method: %s in service: %s", method, m.svcDsc.Name()) + } + + // from the proto.ServiceDescriptor, get the TypeDescriptor + typeDsc := fnDsc.Input() + if !isClient { + typeDsc = fnDsc.Output() + } + // get protobuf-encode bytes - actualMsgBuf, err := cv.Do(ctx, m.dynamicgoTypeDsc, utils.StringToSliceByte(s)) + actualMsgBuf, err := cv.Do(ctx, typeDsc, utils.StringToSliceByte(s)) if err != nil { return nil, perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf marshal message failed: %s", err.Error())) } @@ -84,26 +90,24 @@ func (m *WriteJSON) Write(ctx context.Context, msg interface{}) (interface{}, er } // NewReadJSON build ReadJSON according to ServiceDescriptor -func NewReadJSON(svc *dproto.ServiceDescriptor, isClient bool, convOpts *conv.Options) (*ReadJSON, error) { +func NewReadJSON(svc *dproto.ServiceDescriptor, convOpts *conv.Options) *ReadJSON { // extract svc to be used to convert later return &ReadJSON{ dynamicgoConvOpts: convOpts, dynamicgoSvcDsc: svc, - isClient: isClient, - }, nil + } } // ReadJSON implement of MessageReaderWithMethod type ReadJSON struct { dynamicgoConvOpts *conv.Options dynamicgoSvcDsc *dproto.ServiceDescriptor - isClient bool } var _ MessageReader = (*ReadJSON)(nil) // Read reads data from actualMsgBuf and convert to json string -func (m *ReadJSON) Read(ctx context.Context, method string, actualMsgBuf []byte) (interface{}, error) { +func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, actualMsgBuf []byte) (interface{}, error) { // create dynamic message here, once method string has been extracted fnDsc := m.dynamicgoSvcDsc.LookupMethodByName(method) if fnDsc == nil { @@ -112,7 +116,7 @@ func (m *ReadJSON) Read(ctx context.Context, method string, actualMsgBuf []byte) // from the dproto.ServiceDescriptor, get the TypeDescriptor typeDescriptor := fnDsc.Output() - if !m.isClient { + if !isClient { typeDescriptor = fnDsc.Input() } diff --git a/pkg/generic/proto/json_test.go b/pkg/generic/proto/json_test.go index b357a98044..ccac9c4388 100644 --- a/pkg/generic/proto/json_test.go +++ b/pkg/generic/proto/json_test.go @@ -49,9 +49,6 @@ func TestWrite(t *testing.T) { svc, err := opts.NewDescriptorFromPath(context.Background(), example2IDLPath) test.Assert(t, err == nil) - wm, err := NewWriteJSON(svc, method, true, &conv.Options{}) - test.Assert(t, err == nil) - msg := getExampleReq() // get expected json struct @@ -59,7 +56,8 @@ func TestWrite(t *testing.T) { json.Unmarshal([]byte(msg), exp) // marshal json string into protobuf wire format using Write - out, err := wm.Write(context.Background(), msg) + wm := NewWriteJSON(svc, &conv.Options{}) + out, err := wm.Write(context.Background(), msg, method, true) test.Assert(t, err == nil) buf, ok := out.([]byte) test.Assert(t, ok) @@ -98,11 +96,9 @@ func TestRead(t *testing.T) { svc, err := opts.NewDescriptorFromPath(context.Background(), example2IDLPath) test.Assert(t, err == nil) - rm, err := NewReadJSON(svc, false, &conv.Options{}) - test.Assert(t, err == nil) - // unmarshal protobuf wire format into json string using Read - out, err := rm.Read(context.Background(), method, in) + rm := NewReadJSON(svc, &conv.Options{}) + out, err := rm.Read(context.Background(), method, false, in) test.Assert(t, err == nil) // get expected json struct diff --git a/pkg/generic/proto/protobuf.go b/pkg/generic/proto/protobuf.go index cd2941829c..dcadd86eed 100644 --- a/pkg/generic/proto/protobuf.go +++ b/pkg/generic/proto/protobuf.go @@ -22,10 +22,10 @@ import ( // MessageReader read from ActualMsgBuf with method and returns a string type MessageReader interface { - Read(ctx context.Context, method string, actualMsgBuf []byte) (interface{}, error) + Read(ctx context.Context, method string, isClient bool, actualMsgBuf []byte) (interface{}, error) } // MessageWriter writes to a converts json to protobufs wireformat and returns an output bytebuffer type MessageWriter interface { - Write(ctx context.Context, msg interface{}) (interface{}, error) + Write(ctx context.Context, msg interface{}, method string, isClient bool) (interface{}, error) } diff --git a/pkg/generic/thrift/http.go b/pkg/generic/thrift/http.go index 9d7b66e8ca..b14e674981 100644 --- a/pkg/generic/thrift/http.go +++ b/pkg/generic/thrift/http.go @@ -29,19 +29,25 @@ import ( "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" ) +type HTTPReaderWriter struct { + *ReadHTTPResponse + *WriteHTTPRequest +} + +func NewHTTPReaderWriter(svc *descriptor.ServiceDescriptor) *HTTPReaderWriter { + return &HTTPReaderWriter{ReadHTTPResponse: NewReadHTTPResponse(svc), WriteHTTPRequest: NewWriteHTTPRequest(svc)} +} + // WriteHTTPRequest implement of MessageWriter type WriteHTTPRequest struct { svc *descriptor.ServiceDescriptor - dynamicgoTypeDsc *dthrift.TypeDescriptor binaryWithBase64 bool convOpts conv.Options // used for dynamicgo conversion convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on - hasRequestBase bool dynamicgoEnabled bool } @@ -66,17 +72,10 @@ func (w *WriteHTTPRequest) SetBinaryWithBase64(enable bool) { } // SetDynamicGo ... -func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options, method string) error { +func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) { w.convOpts = *convOpts w.convOptsWithThriftBase = *convOptsWithThriftBase w.dynamicgoEnabled = true - fnDsc := w.svc.DynamicGoDsc.Functions()[method] - if fnDsc == nil { - return fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, w.svc.DynamicGoDsc.Name()) - } - w.hasRequestBase = fnDsc.HasRequestBase() - w.dynamicgoTypeDsc = fnDsc.Request() - return nil } // originalWrite ... @@ -101,7 +100,6 @@ func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out thrift.TProtoc type ReadHTTPResponse struct { svc *descriptor.ServiceDescriptor base64Binary bool - msg remote.Message dynamicgoEnabled bool useRawBodyForHTTPResp bool t2jBinaryConv t2j.BinaryConv // used for dynamicgo thrift to json conversion @@ -127,24 +125,21 @@ func (r *ReadHTTPResponse) SetUseRawBodyForHTTPResp(useRawBodyForHTTPResp bool) } // SetDynamicGo ... -func (r *ReadHTTPResponse) SetDynamicGo(convOpts *conv.Options, msg remote.Message) { +func (r *ReadHTTPResponse) SetDynamicGo(convOpts *conv.Options) { r.t2jBinaryConv = t2j.NewBinaryConv(*convOpts) - r.msg = msg r.dynamicgoEnabled = true } // Read ... -func (r *ReadHTTPResponse) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) { +func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { // fallback logic if !r.dynamicgoEnabled { return r.originalRead(ctx, method, in) } - tProt, ok := in.(*cthrift.BinaryProtocol) if !ok { return nil, perrors.NewProtocolErrorWithMsg("TProtocol should be BinaryProtocol") } - mBeginLen := bthrift.Binary.MessageBeginLength(method, thrift.TMessageType(r.msg.MessageType()), r.msg.RPCInfo().Invocation().SeqID()) sName, err := in.ReadStructBegin() if err != nil { return nil, err @@ -156,7 +151,7 @@ func (r *ReadHTTPResponse) Read(ctx context.Context, method string, in thrift.TP return nil, err } fBeginLen := bthrift.Binary.FieldBeginLength(fName, typeId, id) - transBuf, err := tProt.ByteBuffer().ReadBinary(r.msg.PayloadLen() - mBeginLen - sBeginLen - fBeginLen - bthrift.Binary.MessageEndLength()) + transBuf, err := tProt.ByteBuffer().ReadBinary(dataLen - sBeginLen - fBeginLen) if err != nil { return nil, err } diff --git a/pkg/generic/thrift/http_fallback.go b/pkg/generic/thrift/http_fallback.go index 0d4ae7a2e3..4bda30b81b 100644 --- a/pkg/generic/thrift/http_fallback.go +++ b/pkg/generic/thrift/http_fallback.go @@ -26,6 +26,6 @@ import ( ) // Write ... -func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error { +func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { return w.originalWrite(ctx, out, msg, requestBase) } diff --git a/pkg/generic/thrift/http_go116plus_amd64.go b/pkg/generic/thrift/http_go116plus_amd64.go index 3b6a82683e..e0701ad6b6 100644 --- a/pkg/generic/thrift/http_go116plus_amd64.go +++ b/pkg/generic/thrift/http_go116plus_amd64.go @@ -21,6 +21,7 @@ package thrift import ( "context" + "fmt" "unsafe" "github.com/bytedance/gopkg/lang/mcache" @@ -35,7 +36,7 @@ import ( ) // Write ... -func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error { +func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { // fallback logic if !w.dynamicgoEnabled { return w.originalWrite(ctx, out, msg, requestBase) @@ -44,8 +45,14 @@ func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg // dynamicgo logic req := msg.(*descriptor.HTTPRequest) + fnDsc := w.svc.DynamicGoDsc.Functions()[method] + if fnDsc == nil { + return fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, w.svc.DynamicGoDsc.Name()) + } + dynamicgoTypeDsc := fnDsc.Request() + var cv j2t.BinaryConv - if !w.hasRequestBase { + if !fnDsc.HasRequestBase() { requestBase = nil } if requestBase != nil { @@ -56,7 +63,7 @@ func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg cv = j2t.NewBinaryConv(w.convOpts) } - if err := out.WriteStructBegin(w.dynamicgoTypeDsc.Struct().Name()); err != nil { + if err := out.WriteStructBegin(dynamicgoTypeDsc.Struct().Name()); err != nil { return err } @@ -65,7 +72,7 @@ func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg dbuf := mcache.Malloc(len(body))[0:0] defer mcache.Free(dbuf) - for _, field := range w.dynamicgoTypeDsc.Struct().Fields() { + for _, field := range dynamicgoTypeDsc.Struct().Fields() { if err := out.WriteFieldBegin(field.Name(), field.Type().Type().ToThriftTType(), int16(field.ID())); err != nil { return err } diff --git a/pkg/generic/thrift/http_pb.go b/pkg/generic/thrift/http_pb.go index f0b7418088..23f5aaa5f5 100644 --- a/pkg/generic/thrift/http_pb.go +++ b/pkg/generic/thrift/http_pb.go @@ -29,6 +29,15 @@ import ( thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) +type HTTPPbReaderWriter struct { + *ReadHTTPPbResponse + *WriteHTTPPbRequest +} + +func NewHTTPPbReaderWriter(svc *descriptor.ServiceDescriptor, pbsvc proto.ServiceDescriptor) *HTTPPbReaderWriter { + return &HTTPPbReaderWriter{ReadHTTPPbResponse: NewReadHTTPPbResponse(svc, pbsvc), WriteHTTPPbRequest: NewWriteHTTPPbRequest(svc, pbsvc)} +} + // WriteHTTPPbRequest implement of MessageWriter type WriteHTTPPbRequest struct { svc *descriptor.ServiceDescriptor @@ -44,7 +53,7 @@ func NewWriteHTTPPbRequest(svc *descriptor.ServiceDescriptor, pbSvc *desc.Servic } // Write ... -func (w *WriteHTTPPbRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error { +func (w *WriteHTTPPbRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { req := msg.(*descriptor.HTTPRequest) fn, err := w.svc.Router.Lookup(req) if err != nil { @@ -69,22 +78,22 @@ func (w *WriteHTTPPbRequest) Write(ctx context.Context, out thrift.TProtocol, ms return wrapStructWriter(ctx, req, out, fn.Request, &writerOption{requestBase: requestBase}) } -// ReadHTTPResponse implement of MessageReaderWithMethod +// ReadHTTPPbResponse implement of MessageReaderWithMethod type ReadHTTPPbResponse struct { svc *descriptor.ServiceDescriptor pbSvc proto.ServiceDescriptor } -var _ MessageReader = (*ReadHTTPResponse)(nil) +var _ MessageReader = (*ReadHTTPPbResponse)(nil) -// NewReadHTTPResponse ... +// NewReadHTTPPbResponse ... // Base64 encoding for binary is enabled by default. func NewReadHTTPPbResponse(svc *descriptor.ServiceDescriptor, pbSvc proto.ServiceDescriptor) *ReadHTTPPbResponse { return &ReadHTTPPbResponse{svc, pbSvc} } // Read ... -func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) { +func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { fnDsc, err := r.svc.LookupFunctionByMethod(method) if err != nil { return nil, err diff --git a/pkg/generic/thrift/json.go b/pkg/generic/thrift/json.go index bcb9b83471..6c2be7fca8 100644 --- a/pkg/generic/thrift/json.go +++ b/pkg/generic/thrift/json.go @@ -29,32 +29,28 @@ import ( "github.com/tidwall/gjson" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/utils" ) +type JSONReaderWriter struct { + *ReadJSON + *WriteJSON +} + +func NewJsonReaderWriter(svc *descriptor.ServiceDescriptor) *JSONReaderWriter { + return &JSONReaderWriter{ReadJSON: NewReadJSON(svc), WriteJSON: NewWriteJSON(svc)} +} + // NewWriteJSON build WriteJSON according to ServiceDescriptor -func NewWriteJSON(svc *descriptor.ServiceDescriptor, method string, isClient bool) (*WriteJSON, error) { - fnDsc, err := svc.LookupFunctionByMethod(method) - if err != nil { - return nil, err - } - ty := fnDsc.Request - if !isClient { - ty = fnDsc.Response - } - ws := &WriteJSON{ - typeDsc: ty, - hasRequestBase: fnDsc.HasRequestBase && isClient, +func NewWriteJSON(svc *descriptor.ServiceDescriptor) *WriteJSON { + return &WriteJSON{ + svcDsc: svc, base64Binary: true, - isClient: isClient, dynamicgoEnabled: false, } - return ws, nil } const voidWholeLen = 5 @@ -63,11 +59,8 @@ var _ = wrapJSONWriter // WriteJSON implement of MessageWriter type WriteJSON struct { - typeDsc *descriptor.TypeDescriptor - dynamicgoTypeDsc *dthrift.TypeDescriptor - hasRequestBase bool + svcDsc *descriptor.ServiceDescriptor base64Binary bool - isClient bool convOpts conv.Options // used for dynamicgo conversion convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on dynamicgoEnabled bool @@ -82,30 +75,30 @@ func (m *WriteJSON) SetBase64Binary(enable bool) { } // SetDynamicGo ... -func (m *WriteJSON) SetDynamicGo(svc *descriptor.ServiceDescriptor, method string, convOpts, convOptsWithThriftBase *conv.Options) error { - fnDsc := svc.DynamicGoDsc.Functions()[method] - if fnDsc == nil { - return fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, svc.DynamicGoDsc.Name()) - } - if m.isClient { - m.dynamicgoTypeDsc = fnDsc.Request() - } else { - m.dynamicgoTypeDsc = fnDsc.Response() - } +func (m *WriteJSON) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) { m.convOpts = *convOpts m.convOptsWithThriftBase = *convOptsWithThriftBase m.dynamicgoEnabled = true - return nil } -func (m *WriteJSON) originalWrite(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error { - if !m.hasRequestBase { +func (m *WriteJSON) originalWrite(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { + fnDsc, err := m.svcDsc.LookupFunctionByMethod(method) + if err != nil { + return fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, m.svcDsc.DynamicGoDsc.Name()) + } + typeDsc := fnDsc.Request + if !isClient { + typeDsc = fnDsc.Response + } + + hasRequestBase := fnDsc.HasRequestBase && isClient + if !hasRequestBase { requestBase = nil } // msg is void or nil if _, ok := msg.(descriptor.Void); ok || msg == nil { - return wrapStructWriter(ctx, msg, out, m.typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}) + return wrapStructWriter(ctx, msg, out, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}) } // msg is string @@ -124,14 +117,13 @@ func (m *WriteJSON) originalWrite(ctx context.Context, out thrift.TProtocol, msg Index: 0, } } - return wrapJSONWriter(ctx, &body, out, m.typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}) + return wrapJSONWriter(ctx, &body, out, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}) } // NewReadJSON build ReadJSON according to ServiceDescriptor -func NewReadJSON(svc *descriptor.ServiceDescriptor, isClient bool) *ReadJSON { +func NewReadJSON(svc *descriptor.ServiceDescriptor) *ReadJSON { return &ReadJSON{ svc: svc, - isClient: isClient, binaryWithBase64: true, dynamicgoEnabled: false, } @@ -139,12 +131,11 @@ func NewReadJSON(svc *descriptor.ServiceDescriptor, isClient bool) *ReadJSON { // ReadJSON implement of MessageReaderWithMethod type ReadJSON struct { - svc *descriptor.ServiceDescriptor - isClient bool - binaryWithBase64 bool - msg remote.Message - t2jBinaryConv t2j.BinaryConv // used for dynamicgo thrift to json conversion - dynamicgoEnabled bool + svc *descriptor.ServiceDescriptor + binaryWithBase64 bool + convOpts conv.Options // used for dynamicgo conversion + convOptsWithException conv.Options // used for dynamicgo conversion which also handles an exception field + dynamicgoEnabled bool } var _ MessageReader = (*ReadJSON)(nil) @@ -156,22 +147,17 @@ func (m *ReadJSON) SetBinaryWithBase64(enable bool) { } // SetDynamicGo ... -func (m *ReadJSON) SetDynamicGo(convOpts, convOptsWithException *conv.Options, msg remote.Message) { - m.msg = msg +func (m *ReadJSON) SetDynamicGo(convOpts, convOptsWithException *conv.Options) { m.dynamicgoEnabled = true - if m.isClient { - // set binary conv to handle an exception field - m.t2jBinaryConv = t2j.NewBinaryConv(*convOptsWithException) - } else { - m.t2jBinaryConv = t2j.NewBinaryConv(*convOpts) - } + m.convOpts = *convOpts + m.convOptsWithException = *convOptsWithException } // Read read data from in thrift.TProtocol and convert to json string -func (m *ReadJSON) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) { +func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { // fallback logic - if !m.dynamicgoEnabled { - return m.originalRead(ctx, method, in) + if !m.dynamicgoEnabled || dataLen <= 0 { + return m.originalRead(ctx, method, isClient, in) } // dynamicgo logic @@ -184,10 +170,8 @@ func (m *ReadJSON) Read(ctx context.Context, method string, in thrift.TProtocol) if fnDsc == nil { return nil, fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, m.svc.DynamicGoDsc.Name()) } - var tyDsc *dthrift.TypeDescriptor - if m.msg.MessageType() == remote.Reply { - tyDsc = fnDsc.Response() - } else { + tyDsc := fnDsc.Response() + if !isClient { tyDsc = fnDsc.Request() } @@ -198,8 +182,7 @@ func (m *ReadJSON) Read(ctx context.Context, method string, in thrift.TProtocol) } resp = descriptor.Void{} } else { - msgBeginLen := bthrift.Binary.MessageBeginLength(method, thrift.TMessageType(m.msg.MessageType()), m.msg.RPCInfo().Invocation().SeqID()) - transBuff, err := tProt.ByteBuffer().ReadBinary(m.msg.PayloadLen() - msgBeginLen - bthrift.Binary.MessageEndLength()) + transBuff, err := tProt.ByteBuffer().ReadBinary(dataLen) if err != nil { return nil, err } @@ -207,7 +190,13 @@ func (m *ReadJSON) Read(ctx context.Context, method string, in thrift.TProtocol) // json size is usually 2 times larger than equivalent thrift data buf := dirtmake.Bytes(0, len(transBuff)*2) // thrift []byte to json []byte - if err := m.t2jBinaryConv.DoInto(ctx, tyDsc, transBuff, &buf); err != nil { + var t2jBinaryConv t2j.BinaryConv + if isClient { + t2jBinaryConv = t2j.NewBinaryConv(m.convOptsWithException) + } else { + t2jBinaryConv = t2j.NewBinaryConv(m.convOpts) + } + if err := t2jBinaryConv.DoInto(ctx, tyDsc, transBuff, &buf); err != nil { return nil, err } buf = removePrefixAndSuffix(buf) @@ -224,13 +213,13 @@ func (m *ReadJSON) Read(ctx context.Context, method string, in thrift.TProtocol) return resp, nil } -func (m *ReadJSON) originalRead(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) { +func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient bool, in thrift.TProtocol) (interface{}, error) { fnDsc, err := m.svc.LookupFunctionByMethod(method) if err != nil { return nil, err } fDsc := fnDsc.Response - if !m.isClient { + if !isClient { fDsc = fnDsc.Request } resp, err := skipStructReader(ctx, in, fDsc, &readerOption{forJSON: true, throwException: true, binaryWithBase64: m.binaryWithBase64}) diff --git a/pkg/generic/thrift/json_fallback.go b/pkg/generic/thrift/json_fallback.go index c05c548993..d4a6a72e13 100644 --- a/pkg/generic/thrift/json_fallback.go +++ b/pkg/generic/thrift/json_fallback.go @@ -26,6 +26,6 @@ import ( ) // Write write json string to out thrift.TProtocol -func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error { - return m.originalWrite(ctx, out, msg, requestBase) +func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { + return m.originalWrite(ctx, out, msg, method, isClient, requestBase) } diff --git a/pkg/generic/thrift/json_go116plus_amd64.go b/pkg/generic/thrift/json_go116plus_amd64.go index a8515ba5f5..0c1f7b9210 100644 --- a/pkg/generic/thrift/json_go116plus_amd64.go +++ b/pkg/generic/thrift/json_go116plus_amd64.go @@ -21,6 +21,7 @@ package thrift import ( "context" + "fmt" "unsafe" "github.com/bytedance/gopkg/lang/mcache" @@ -37,15 +38,25 @@ import ( ) // Write write json string to out thrift.TProtocol -func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error { +func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { // fallback logic if !m.dynamicgoEnabled { - return m.originalWrite(ctx, out, msg, requestBase) + return m.originalWrite(ctx, out, msg, method, isClient, requestBase) } // dynamicgo logic + fnDsc := m.svcDsc.DynamicGoDsc.Functions()[method] + if fnDsc == nil { + return fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, m.svcDsc.DynamicGoDsc.Name()) + } + dynamicgoTypeDsc := fnDsc.Request() + if !isClient { + dynamicgoTypeDsc = fnDsc.Response() + } + hasRequestBase := fnDsc.HasRequestBase() && isClient + var cv j2t.BinaryConv - if !m.hasRequestBase { + if !hasRequestBase { requestBase = nil } if requestBase != nil { @@ -58,10 +69,10 @@ func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interfa // msg is void or nil if _, ok := msg.(descriptor.Void); ok || msg == nil { - if err := m.writeHead(out); err != nil { + if err := m.writeHead(out, dynamicgoTypeDsc); err != nil { return err } - if err := m.writeFields(ctx, out, nil, nil); err != nil { + if err := m.writeFields(ctx, out, dynamicgoTypeDsc, nil, nil, isClient); err != nil { return err } return writeTail(out) @@ -74,10 +85,10 @@ func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interfa } transBuff := utils.StringToSliceByte(s) - if err := m.writeHead(out); err != nil { + if err := m.writeHead(out, dynamicgoTypeDsc); err != nil { return err } - if err := m.writeFields(ctx, out, &cv, transBuff); err != nil { + if err := m.writeFields(ctx, out, dynamicgoTypeDsc, &cv, transBuff, isClient); err != nil { return err } return writeTail(out) @@ -90,13 +101,13 @@ const ( String ) -func (m *WriteJSON) writeFields(ctx context.Context, out thrift.TProtocol, cv *j2t.BinaryConv, transBuff []byte) error { +func (m *WriteJSON) writeFields(ctx context.Context, out thrift.TProtocol, dynamicgoTypeDsc *dthrift.TypeDescriptor, cv *j2t.BinaryConv, transBuff []byte, isClient bool) error { dbuf := mcache.Malloc(len(transBuff))[0:0] defer mcache.Free(dbuf) - for _, field := range m.dynamicgoTypeDsc.Struct().Fields() { + for _, field := range dynamicgoTypeDsc.Struct().Fields() { // Exception field - if !m.isClient && field.ID() != 0 { + if !isClient && field.ID() != 0 { // generic server ignore the exception, because no description for exception // generic handler just return error continue @@ -136,8 +147,8 @@ func (m *WriteJSON) writeFields(ctx context.Context, out thrift.TProtocol, cv *j return nil } -func (m *WriteJSON) writeHead(out thrift.TProtocol) error { - if err := out.WriteStructBegin(m.dynamicgoTypeDsc.Struct().Name()); err != nil { +func (m *WriteJSON) writeHead(out thrift.TProtocol, dynamicgoTypeDsc *dthrift.TypeDescriptor) error { + if err := out.WriteStructBegin(dynamicgoTypeDsc.Struct().Name()); err != nil { return err } return nil diff --git a/pkg/generic/thrift/struct.go b/pkg/generic/thrift/struct.go index d3815cc690..d8870436f0 100644 --- a/pkg/generic/thrift/struct.go +++ b/pkg/generic/thrift/struct.go @@ -23,27 +23,27 @@ import ( thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) +type StructReaderWriter struct { + *ReadStruct + *WriteStruct +} + +func NewStructReaderWriter(svc *descriptor.ServiceDescriptor) *StructReaderWriter { + return &StructReaderWriter{ReadStruct: NewReadStruct(svc), WriteStruct: NewWriteStruct(svc)} +} + +func NewStructReaderWriterForJSON(svc *descriptor.ServiceDescriptor) *StructReaderWriter { + return &StructReaderWriter{ReadStruct: NewReadStructForJSON(svc), WriteStruct: NewWriteStruct(svc)} +} + // NewWriteStruct ... -func NewWriteStruct(svc *descriptor.ServiceDescriptor, method string, isClient bool) (*WriteStruct, error) { - fnDsc, err := svc.LookupFunctionByMethod(method) - if err != nil { - return nil, err - } - ty := fnDsc.Request - if !isClient { - ty = fnDsc.Response - } - ws := &WriteStruct{ - ty: ty, - hasRequestBase: fnDsc.HasRequestBase && isClient, - } - return ws, nil +func NewWriteStruct(svc *descriptor.ServiceDescriptor) *WriteStruct { + return &WriteStruct{svcDsc: svc} } // WriteStruct implement of MessageWriter type WriteStruct struct { - ty *descriptor.TypeDescriptor - hasRequestBase bool + svcDsc *descriptor.ServiceDescriptor binaryWithBase64 bool } @@ -56,33 +56,38 @@ func (m *WriteStruct) SetBinaryWithBase64(enable bool) { } // Write ... -func (m *WriteStruct) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error { - if !m.hasRequestBase { +func (m *WriteStruct) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { + fnDsc, err := m.svcDsc.LookupFunctionByMethod(method) + if err != nil { + return err + } + ty := fnDsc.Request + if !isClient { + ty = fnDsc.Response + } + hasRequestBase := fnDsc.HasRequestBase && isClient + + if !hasRequestBase { requestBase = nil } - return wrapStructWriter(ctx, msg, out, m.ty, &writerOption{requestBase: requestBase, binaryWithBase64: m.binaryWithBase64}) + return wrapStructWriter(ctx, msg, out, ty, &writerOption{requestBase: requestBase, binaryWithBase64: m.binaryWithBase64}) } // NewReadStruct ... -func NewReadStruct(svc *descriptor.ServiceDescriptor, isClient bool) *ReadStruct { - return &ReadStruct{ - svc: svc, - isClient: isClient, - } +func NewReadStruct(svc *descriptor.ServiceDescriptor) *ReadStruct { + return &ReadStruct{svc: svc} } -func NewReadStructForJSON(svc *descriptor.ServiceDescriptor, isClient bool) *ReadStruct { +func NewReadStructForJSON(svc *descriptor.ServiceDescriptor) *ReadStruct { return &ReadStruct{ - svc: svc, - isClient: isClient, - forJSON: true, + svc: svc, + forJSON: true, } } // ReadStruct implement of MessageReaderWithMethod type ReadStruct struct { svc *descriptor.ServiceDescriptor - isClient bool forJSON bool binaryWithBase64 bool binaryWithByteSlice bool @@ -108,13 +113,13 @@ func (m *ReadStruct) SetSetFieldsForEmptyStruct(mode uint8) { } // Read ... -func (m *ReadStruct) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) { +func (m *ReadStruct) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { fnDsc, err := m.svc.LookupFunctionByMethod(method) if err != nil { return nil, err } fDsc := fnDsc.Response - if !m.isClient { + if !isClient { fDsc = fnDsc.Request } return skipStructReader(ctx, in, fDsc, &readerOption{throwException: true, forJSON: m.forJSON, binaryWithBase64: m.binaryWithBase64, binaryWithByteSlice: m.binaryWithByteSlice, setFieldsForEmptyStruct: m.setFieldsForEmptyStruct}) diff --git a/pkg/generic/thrift/thrift.go b/pkg/generic/thrift/thrift.go index aebf71d23e..6fca7568d7 100644 --- a/pkg/generic/thrift/thrift.go +++ b/pkg/generic/thrift/thrift.go @@ -29,10 +29,10 @@ const ( // MessageReader read from thrift.TProtocol with method type MessageReader interface { - Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) + Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) } // MessageWriter write to thrift.TProtocol type MessageWriter interface { - Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error + Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error } diff --git a/pkg/remote/codec/protobuf/protobuf.go b/pkg/remote/codec/protobuf/protobuf.go index a5a11e0aff..04e4418630 100644 --- a/pkg/remote/codec/protobuf/protobuf.go +++ b/pkg/remote/codec/protobuf/protobuf.go @@ -89,12 +89,10 @@ func (c protobufCodec) Marshal(ctx context.Context, message remote.Message, out // 4. write actual message buf msg, ok := data.(ProtobufMsgCodec) if !ok { - // If Using Generics - // if data is a MessageWriterWithContext - // Do msg.WritePb(ctx context.Context, out remote.ByteBuffer) + // Generic Case genmsg, isgen := data.(MessageWriterWithContext) if isgen { - actualMsg, err := genmsg.WritePb(ctx) + actualMsg, err := genmsg.WritePb(ctx, methodName) if err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf marshal message failed: %s", err.Error())) } @@ -197,7 +195,7 @@ func (c protobufCodec) Unmarshal(ctx context.Context, message remote.Message, in } } - // JSONPB Generic Case + // Generic Case if msg, ok := data.(MessageReaderWithMethodWithContext); ok { err := msg.ReadPb(ctx, methodName, actualMsgBuf) if err != nil { @@ -222,7 +220,7 @@ func (c protobufCodec) Name() string { // MessageWriterWithContext writes to output bytebuffer type MessageWriterWithContext interface { - WritePb(ctx context.Context) (interface{}, error) + WritePb(ctx context.Context, method string) (interface{}, error) } // MessageReaderWithMethodWithContext read from ActualMsgBuf with method diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index 3d6d5c44ee..fe7ffe68c4 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -117,7 +117,7 @@ func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out re } // fallback to old thrift way (slow) - if err = encodeBasicThrift(out, ctx, methodName, msgType, seqID, data); err == nil || err != errEncodeMismatchMsgType { + if err = encodeBasicThrift(out, ctx, methodName, msgType, seqID, data, message.RPCRole()); err == nil || err != errEncodeMismatchMsgType { return err } @@ -154,7 +154,7 @@ func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.M } // encodeBasicThrift encode with the old thrift way (slow) -func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}) error { +func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}, rpcRole remote.RPCRole) error { if err := verifyMarshalBasicThriftDataType(data); err != nil { return err } @@ -162,7 +162,7 @@ func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string if err := tProt.WriteMessageBegin(method, thrift.TMessageType(msgType), seqID); err != nil { return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageBegin failed: %s", err.Error())) } - if err := marshalBasicThriftData(ctx, tProt, data); err != nil { + if err := marshalBasicThriftData(ctx, tProt, data, method, rpcRole); err != nil { return err } if err := tProt.WriteMessageEnd(); err != nil { @@ -203,7 +203,7 @@ func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in r ri := message.RPCInfo() rpcinfo.Record(ctx, ri, stats.WaitReadStart, nil) - err = c.unmarshalThriftData(ctx, tProt, methodName, data, dataLen) + err = c.unmarshalThriftData(ctx, tProt, methodName, data, message.RPCRole(), dataLen) rpcinfo.Record(ctx, ri, stats.WaitReadFinish, err) if err != nil { return err @@ -239,9 +239,9 @@ func (c thriftCodec) Name() string { return serviceinfo.Thrift.String() } -// MessageWriterWithContext write to thrift.TProtocol -type MessageWriterWithContext interface { - Write(ctx context.Context, oprot thrift.TProtocol) error +// MessageWriterWithMethodWithContext write to thrift.TProtocol +type MessageWriterWithMethodWithContext interface { + Write(ctx context.Context, method string, oprot thrift.TProtocol) error } // MessageWriter write to thrift.TProtocol @@ -256,7 +256,7 @@ type MessageReader interface { // MessageReaderWithMethodWithContext read from thrift.TProtocol with method type MessageReaderWithMethodWithContext interface { - Read(ctx context.Context, method string, oprot thrift.TProtocol) error + Read(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error } type ThriftMsgFastCodec interface { diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index c17340b22c..30b707f4da 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -72,7 +72,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ // fallback to old thrift way (slow) transport := thrift.NewTMemoryBufferLen(marshalThriftBufferSize) tProt := thrift.NewTBinaryProtocol(transport, true, true) - if err := marshalBasicThriftData(ctx, tProt, data); err != nil { + if err := marshalBasicThriftData(ctx, tProt, data, "", -1); err != nil { return nil, err } return transport.Bytes(), nil @@ -82,7 +82,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ func verifyMarshalBasicThriftDataType(data interface{}) error { switch data.(type) { case MessageWriter: - case MessageWriterWithContext: + case MessageWriterWithMethodWithContext: default: return errEncodeMismatchMsgType } @@ -91,14 +91,14 @@ func verifyMarshalBasicThriftDataType(data interface{}) error { // marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId) // It uses the old thrift way which is much slower than FastCodec and Frugal -func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data interface{}) error { +func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data interface{}, method string, rpcRole remote.RPCRole) error { switch msg := data.(type) { case MessageWriter: if err := msg.Write(tProt); err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) } - case MessageWriterWithContext: - if err := msg.Write(ctx, tProt); err != nil { + case MessageWriterWithMethodWithContext: + if err := msg.Write(ctx, method, tProt); err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) } default: @@ -131,7 +131,7 @@ func UnmarshalThriftData(ctx context.Context, codec remote.PayloadCodec, method c = defaultCodec } tProt := NewBinaryProtocol(remote.NewReaderBuffer(buf)) - err := c.unmarshalThriftData(ctx, tProt, method, data, len(buf)) + err := c.unmarshalThriftData(ctx, tProt, method, data, -1, len(buf)) if err == nil { tProt.Recycle() } @@ -176,7 +176,7 @@ func (c thriftCodec) fastUnmarshal(tProt *BinaryProtocol, data interface{}, data // unmarshalThriftData only decodes the data (after methodName, msgType and seqId) // method is only used for generic calls -func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProtocol, method string, data interface{}, dataLen int) error { +func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProtocol, method string, data interface{}, rpcRole remote.RPCRole, dataLen int) error { // decode with hyper unmarshal if c.hyperMessageUnmarshalEnabled() && c.hyperMessageUnmarshalAvailable(data, dataLen) { return c.hyperUnmarshal(tProt, data, dataLen) @@ -197,7 +197,7 @@ func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProto } // fallback to old thrift way (slow) - return decodeBasicThriftData(ctx, tProt, method, data) + return decodeBasicThriftData(ctx, tProt, method, rpcRole, dataLen, data) } func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error { @@ -234,7 +234,7 @@ func verifyUnmarshalBasicThriftDataType(data interface{}) error { } // decodeBasicThriftData decode thrift body the old way (slow) -func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method string, data interface{}) error { +func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method string, rpcRole remote.RPCRole, dataLen int, data interface{}) error { var err error switch t := data.(type) { case MessageReader: @@ -243,7 +243,7 @@ func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method s } case MessageReaderWithMethodWithContext: // methodName is necessary for generic calls to methodInfo from serviceInfo - if err = t.Read(ctx, method, tProt); err != nil { + if err = t.Read(ctx, method, dataLen, tProt); err != nil { return remote.NewTransError(remote.ProtocolError, err) } default: diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index 2426e84e32..b5fe3dcd43 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -42,13 +42,13 @@ var ( func TestMarshalBasicThriftData(t *testing.T) { t.Run("invalid-data", func(t *testing.T) { - err := marshalBasicThriftData(context.Background(), nil, 0) + err := marshalBasicThriftData(context.Background(), nil, 0, "", -1) test.Assert(t, err == errEncodeMismatchMsgType, err) }) t.Run("valid-data", func(t *testing.T) { transport := thrift.NewTMemoryBufferLen(1024) tProt := thrift.NewTBinaryProtocol(transport, true, true) - err := marshalBasicThriftData(context.Background(), tProt, mockReq) + err := marshalBasicThriftData(context.Background(), tProt, mockReq, "", -1) test.Assert(t, err == nil, err) result := transport.Bytes() test.Assert(t, reflect.DeepEqual(result, mockReqThrift), result) @@ -78,19 +78,19 @@ func Test_decodeBasicThriftData(t *testing.T) { t.Run("empty-input", func(t *testing.T) { req := &fast.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{})) - err := decodeBasicThriftData(context.Background(), tProt, "mock", req) + err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) test.Assert(t, err != nil, err) }) t.Run("invalid-input", func(t *testing.T) { req := &fast.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{0xff})) - err := decodeBasicThriftData(context.Background(), tProt, "mock", req) + err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) test.Assert(t, err != nil, err) }) t.Run("normal-input", func(t *testing.T) { req := &fast.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) - err := decodeBasicThriftData(context.Background(), tProt, "mock", req) + err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) checkDecodeResult(t, err, req) }) } @@ -128,7 +128,7 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, 0) + err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) checkDecodeResult(t, err, &fast.MockReq{ Msg: req.Msg, StrList: req.StrList, @@ -152,7 +152,7 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(faultMockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, 0) + err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), "caught in FastCodec using SkipDecoder Buffer")) }) diff --git a/pkg/remote/codec/thrift/thrift_frugal_test.go b/pkg/remote/codec/thrift/thrift_frugal_test.go index 1a17185ee1..f2546da553 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_test.go @@ -240,7 +240,7 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, 0) + err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) checkDecodeResult(t, err, &fast.MockReq{ Msg: req.Msg, StrList: req.StrList, @@ -264,7 +264,7 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(faultMockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, 0) + err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), "caught in Frugal using SkipDecoder Buffer")) }) diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index 23554ea39a..5a12d7b0df 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -62,20 +62,20 @@ func init() { } type mockWithContext struct { - ReadFunc func(ctx context.Context, method string, oprot thrift.TProtocol) error - WriteFunc func(ctx context.Context, oprot thrift.TProtocol) error + ReadFunc func(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error + WriteFunc func(ctx context.Context, method string, oprot thrift.TProtocol) error } -func (m *mockWithContext) Read(ctx context.Context, method string, oprot thrift.TProtocol) error { +func (m *mockWithContext) Read(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error { if m.ReadFunc != nil { - return m.ReadFunc(ctx, method, oprot) + return m.ReadFunc(ctx, method, dataLen, oprot) } return nil } -func (m *mockWithContext) Write(ctx context.Context, oprot thrift.TProtocol) error { +func (m *mockWithContext) Write(ctx context.Context, method string, oprot thrift.TProtocol) error { if m.WriteFunc != nil { - return m.WriteFunc(ctx, oprot) + return m.WriteFunc(ctx, method, oprot) } return nil } @@ -85,7 +85,7 @@ func TestWithContext(t *testing.T) { t.Run(tb.Name, func(t *testing.T) { ctx := context.Background() - req := &mockWithContext{WriteFunc: func(ctx context.Context, oprot thrift.TProtocol) error { + req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, oprot thrift.TProtocol) error { return nil }} ink := rpcinfo.NewInvocation("", "mock") @@ -98,7 +98,9 @@ func TestWithContext(t *testing.T) { buf.Flush() { - resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, oprot thrift.TProtocol) error { return nil }} + resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error { + return nil + }} ink := rpcinfo.NewInvocation("", "mock") ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) msg := remote.NewMessage(resp, svcInfo, ri, remote.Call, remote.Client) diff --git a/pkg/remote/message.go b/pkg/remote/message.go index ec33de8030..c2401fb210 100644 --- a/pkg/remote/message.go +++ b/pkg/remote/message.go @@ -186,7 +186,7 @@ func (m *message) ServiceInfo() *serviceinfo.ServiceInfo { } func (m *message) SpecifyServiceInfo(svcName, methodName string) (*serviceinfo.ServiceInfo, error) { - // for non-multi-service including generic server scenario + // for single service scenario if m.targetSvcInfo != nil { if mt := m.targetSvcInfo.MethodInfo(methodName); mt == nil { return nil, NewTransErrorWithMsg(UnknownMethod, fmt.Sprintf("unknown method %s", methodName)) diff --git a/pkg/serviceinfo/serviceinfo.go b/pkg/serviceinfo/serviceinfo.go index 422ca02482..96fd4a7cab 100644 --- a/pkg/serviceinfo/serviceinfo.go +++ b/pkg/serviceinfo/serviceinfo.go @@ -87,7 +87,7 @@ func (i *ServiceInfo) MethodInfo(name string) MethodInfo { if i == nil { return nil } - if i.ServiceName == GenericService { + if _, ok := i.Extra["generic"]; ok { if i.GenericMethod != nil { return i.GenericMethod(name) } diff --git a/server/genericserver/server.go b/server/genericserver/server.go index 5ed44989f1..7823bf3c3e 100644 --- a/server/genericserver/server.go +++ b/server/genericserver/server.go @@ -25,7 +25,7 @@ import ( // NewServer creates a generic server with the given handler and options. func NewServer(handler generic.Service, g generic.Generic, opts ...server.Option) server.Server { - svcInfo := generic.ServiceInfo(g.PayloadCodecType()) + svcInfo := generic.ServiceInfoWithCodec(g) return NewServerWithServiceInfo(handler, g, svcInfo, opts...) } diff --git a/server/genericserver/server_test.go b/server/genericserver/server_test.go index 756a2d34eb..8c6b10fcea 100644 --- a/server/genericserver/server_test.go +++ b/server/genericserver/server_test.go @@ -55,7 +55,7 @@ func TestNewServerWithServiceInfo(t *testing.T) { }) test.PanicAt(t, func() { - NewServerWithServiceInfo(nil, g, generic.ServiceInfo(g.PayloadCodecType())) + NewServerWithServiceInfo(nil, g, generic.ServiceInfoWithCodec(g)) }, func(err interface{}) bool { if errMsg, ok := err.(error); ok { return strings.Contains(errMsg.Error(), "handler is nil.") From 1e5b1fcaa759f8e9ff1617717e827aff1618b671 Mon Sep 17 00:00:00 2001 From: YangruiEmma Date: Tue, 25 Jun 2024 11:21:52 +0800 Subject: [PATCH 03/70] chore: upgrade go directive version to 1.17 of go.mod (#1415) --- go.mod | 28 +++++++++++++++++++++++++++- go.sum | 1 - 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 5abb6405d7..b6db0d6bbf 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/cloudwego/kitex -go 1.13 +go 1.17 require ( github.com/apache/thrift v0.13.0 @@ -28,3 +28,29 @@ require ( google.golang.org/protobuf v1.28.1 gopkg.in/yaml.v3 v3.0.1 ) + +require ( + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect + github.com/fatih/structtag v1.2.0 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect + github.com/iancoleman/strcase v0.2.0 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/oleiade/lane v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect + github.com/smartystreets/goconvey v1.6.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + golang.org/x/arch v0.2.0 // indirect + golang.org/x/text v0.13.0 // indirect +) diff --git a/go.sum b/go.sum index c352926d2a..11e3173e31 100644 --- a/go.sum +++ b/go.sum @@ -213,7 +213,6 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= From 12c48e2a2216d4fa2aa8d19750bf5a94ccb56a74 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Tue, 2 Jul 2024 13:27:39 +0800 Subject: [PATCH 04/70] refactor: get rid of apache TApplicationException (#1389) --- client/middlewares.go | 28 +++++++--- client/middlewares_test.go | 16 +++--- client/mocks_test.go | 2 +- pkg/protocol/bthrift/exception.go | 61 ++++++++++++++++++++- pkg/protocol/bthrift/exception_test.go | 22 +++++++- pkg/remote/codec/thrift/thrift.go | 4 +- pkg/remote/codec/thrift/thrift_data.go | 2 +- pkg/remote/codec/thrift/thrift_data_test.go | 3 +- pkg/remote/codec/thrift/thrift_test.go | 7 ++- pkg/utils/thrift.go | 5 +- server/invoke_test.go | 3 +- 11 files changed, 119 insertions(+), 34 deletions(-) diff --git a/client/middlewares.go b/client/middlewares.go index 4f94a549c0..33c8c671ad 100644 --- a/client/middlewares.go +++ b/client/middlewares.go @@ -20,10 +20,10 @@ import ( "context" "errors" "fmt" + "reflect" + "strings" "time" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/internal" "github.com/cloudwego/kitex/pkg/discovery" "github.com/cloudwego/kitex/pkg/endpoint" @@ -159,13 +159,27 @@ func newIOErrorHandleMW(errHandle func(context.Context, error) error) endpoint.M } } +func isRemoteErr(err error) bool { + if err == nil { + return false + } + switch err.(type) { + // for thrift、KitexProtobuf, actually check *remote.TransError is enough + case *remote.TransError, protobuf.PBError: + return true + default: + // case thrift.TApplicationException ? + // XXX: we'd like to get rid of apache pkg, should be ok to check by type name + // for thrift v0.13.0, it's "*thrift.tApplicationException" + } + return strings.HasSuffix(reflect.TypeOf(err).String(), "ApplicationException") +} + // DefaultClientErrorHandler is Default ErrorHandler for client // when no ErrorHandler is specified with Option `client.WithErrorHandler`, this ErrorHandler will be injected. // for thrift、KitexProtobuf, >= v0.4.0 wrap protocol error to TransError, which will be more friendly. func DefaultClientErrorHandler(ctx context.Context, err error) error { - switch err.(type) { - // for thrift、KitexProtobuf, actually check *remote.TransError is enough - case *remote.TransError, thrift.TApplicationException, protobuf.PBError: + if isRemoteErr(err) { // Add 'remote' prefix to distinguish with local err. // Because it cannot make sure which side err when decode err happen return kerrors.ErrRemoteOrNetwork.WithCauseAndExtraMsg(err, "remote") @@ -176,9 +190,7 @@ func DefaultClientErrorHandler(ctx context.Context, err error) error { // ClientErrorHandlerWithAddr is ErrorHandler for client, which will add remote addr info into error func ClientErrorHandlerWithAddr(ctx context.Context, err error) error { addrStr := getRemoteAddr(ctx) - switch err.(type) { - // for thrift、KitexProtobuf, actually check *remote.TransError is enough - case *remote.TransError, thrift.TApplicationException, protobuf.PBError: + if isRemoteErr(err) { // Add 'remote' prefix to distinguish with local err. // Because it cannot make sure which side err when decode err happen extraMsg := "remote" diff --git a/client/middlewares_test.go b/client/middlewares_test.go index a328137853..75b096be00 100644 --- a/client/middlewares_test.go +++ b/client/middlewares_test.go @@ -22,7 +22,6 @@ import ( "net" "testing" - "github.com/apache/thrift/lib/go/thrift" "github.com/golang/mock/gomock" "github.com/cloudwego/kitex/internal/mocks" @@ -33,6 +32,7 @@ import ( "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/event" "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" "github.com/cloudwego/kitex/pkg/proxy" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" @@ -140,18 +140,18 @@ func TestDefaultErrorHandler(t *testing.T) { reqCtx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) // Test TApplicationException - err := DefaultClientErrorHandler(context.Background(), thrift.NewTApplicationException(100, "mock")) + err := DefaultClientErrorHandler(context.Background(), bthrift.NewApplicationException(100, "mock")) test.Assert(t, err.Error() == "remote or network error[remote]: mock", err.Error()) - var te thrift.TApplicationException + var te *bthrift.ApplicationException ok := errors.As(err, &te) test.Assert(t, ok) - test.Assert(t, te.TypeId() == 100) + test.Assert(t, te.TypeID() == 100) // Test TApplicationException with remote addr - err = ClientErrorHandlerWithAddr(reqCtx, thrift.NewTApplicationException(100, "mock")) + err = ClientErrorHandlerWithAddr(reqCtx, bthrift.NewApplicationException(100, "mock")) test.Assert(t, err.Error() == "remote or network error[remote-"+tcpAddrStr+"]: mock", err.Error()) ok = errors.As(err, &te) test.Assert(t, ok) - test.Assert(t, te.TypeId() == 100) + test.Assert(t, te.TypeID() == 100) // Test PbError err = DefaultClientErrorHandler(context.Background(), protobuf.NewPbError(100, "mock")) @@ -159,13 +159,13 @@ func TestDefaultErrorHandler(t *testing.T) { var pe protobuf.PBError ok = errors.As(err, &pe) test.Assert(t, ok) - test.Assert(t, te.TypeId() == 100) + test.Assert(t, te.TypeID() == 100) // Test PbError with remote addr err = ClientErrorHandlerWithAddr(reqCtx, protobuf.NewPbError(100, "mock")) test.Assert(t, err.Error() == "remote or network error[remote-"+tcpAddrStr+"]: mock", err.Error()) ok = errors.As(err, &pe) test.Assert(t, ok) - test.Assert(t, te.TypeId() == 100) + test.Assert(t, te.TypeID() == 100) // Test status.Error err = DefaultClientErrorHandler(context.Background(), status.Err(100, "mock")) diff --git a/client/mocks_test.go b/client/mocks_test.go index c3adc5eea9..4064c70703 100644 --- a/client/mocks_test.go +++ b/client/mocks_test.go @@ -17,7 +17,7 @@ package client import ( - "github.com/apache/thrift/lib/go/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // MockTStruct implements the thrift.TStruct interface. diff --git a/pkg/protocol/bthrift/exception.go b/pkg/protocol/bthrift/exception.go index d229dd9eb0..9854520c87 100644 --- a/pkg/protocol/bthrift/exception.go +++ b/pkg/protocol/bthrift/exception.go @@ -29,6 +29,9 @@ type ApplicationException struct { m string } +// check interface only. TO BE REMOVED in the future +var _ thrift.TApplicationException = &ApplicationException{} + // NewApplicationException creates an ApplicationException instance func NewApplicationException(t int32, msg string) *ApplicationException { return &ApplicationException{t: t, m: msg} @@ -40,6 +43,9 @@ func (e *ApplicationException) Msg() string { return e.m } // TypeID ... func (e *ApplicationException) TypeID() int32 { return e.t } +// TypeId ... for apache ApplicationException compatibility +func (e *ApplicationException) TypeId() int32 { return e.t } + // BLength returns the len of encoded buffer. func (e *ApplicationException) BLength() int { // Msg Field: 1 (type) + 2 (id) + 4(strlen) + len(m) @@ -48,7 +54,7 @@ func (e *ApplicationException) BLength() int { return (1 + 2 + 4 + len(e.m)) + (1 + 2 + 4) + 1 } -// Read ... +// FastRead ... func (e *ApplicationException) FastRead(b []byte) (off int, err error) { for i := 0; i < 2; i++ { _, tp, id, l, err := Binary.ReadFieldBegin(b[off:]) @@ -80,7 +86,7 @@ func (e *ApplicationException) FastRead(b []byte) (off int, err error) { return off, nil } -// Write ... +// FastWrite ... func (e *ApplicationException) FastWrite(b []byte) (off int) { off += Binary.WriteFieldBegin(b[off:], "", thrift.STRING, 1) off += Binary.WriteString(b[off:], e.m) @@ -95,6 +101,57 @@ func (e *ApplicationException) FastWriteNocopy(b []byte, binaryWriter BinaryWrit return e.FastWrite(b) } +// Read implements Read interface of TStruct +// it only supports binary protocol. +// Deprecated: use FastRead instead +func (e *ApplicationException) Read(in thrift.TProtocol) error { + for { + _, ttype, id, err := in.ReadFieldBegin() + if err != nil { + return err + } + if ttype == thrift.STOP { + break + } + switch { + case id == 1 && ttype == thrift.STRING: + e.m, err = in.ReadString() + if err != nil { + return err + } + case id == 2 && ttype == thrift.I32: + e.t, err = in.ReadI32() + if err != nil { + return err + } + default: + if err = thrift.SkipDefaultDepth(in, ttype); err != nil { + return err + } + } + } + return nil +} + +// Write implements Write interface of TStruct +// it only supports binary protocol. +// Deprecated: use FastWrite instead +func (e *ApplicationException) Write(out thrift.TProtocol) error { + if err := out.WriteFieldBegin("message", thrift.STRING, 1); err != nil { + return err + } + if err := out.WriteString(e.m); err != nil { + return err + } + if err := out.WriteFieldBegin("type", thrift.I32, 2); err != nil { + return err + } + if err := out.WriteI32(e.t); err != nil { + return err + } + return out.WriteFieldStop() +} + // originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/exception.go var defaultApplicationExceptionMessage = map[int32]string{ thrift.UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception", diff --git a/pkg/protocol/bthrift/exception_test.go b/pkg/protocol/bthrift/exception_test.go index 25c25e4cbe..9a625c9266 100644 --- a/pkg/protocol/bthrift/exception_test.go +++ b/pkg/protocol/bthrift/exception_test.go @@ -37,11 +37,27 @@ func TestApplicationException(t *testing.T) { test.Assert(t, ex2.TypeID() == 1) test.Assert(t, ex2.Msg() == "t1") - // compatibility test only, can be removed in the future + // ================= + // the code below, it's for compatibility test only. + // it can be removed in the future along with Read/Write method + trans := thrift.NewTMemoryBufferLen(100) proto := thrift.NewTBinaryProtocol(trans, true, true) - ex0 := thrift.NewTApplicationException(1, "t1") - err = ex0.Write(proto) + ex9 := thrift.NewTApplicationException(1, "t1") + err = ex9.Write(proto) + test.Assert(t, err == nil, err) + test.Assert(t, bytes.Equal(b, trans.Bytes())) + + trans = thrift.NewTMemoryBufferLen(100) + proto = thrift.NewTBinaryProtocol(trans, true, true) + ex3 := NewApplicationException(1, "t1") + err = ex3.Write(proto) test.Assert(t, err == nil, err) test.Assert(t, bytes.Equal(b, trans.Bytes())) + + ex4 := NewApplicationException(0, "") + err = ex4.Read(proto) + test.Assert(t, err == nil, err) + test.Assert(t, ex4.TypeID() == 1) + test.Assert(t, ex4.Msg() == "t1") } diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index fe7ffe68c4..0d65f224c7 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -276,11 +276,11 @@ func getValidData(methodName string, message remote.Message) (interface{}, error transErr, isTransErr := data.(*remote.TransError) if !isTransErr { if err, isError := data.(error); isError { - encodeErr := thrift.NewTApplicationException(remote.InternalError, err.Error()) + encodeErr := bthrift.NewApplicationException(remote.InternalError, err.Error()) return encodeErr, nil } return nil, errors.New("exception relay need error type data") } - encodeErr := thrift.NewTApplicationException(transErr.TypeID(), transErr.Error()) + encodeErr := bthrift.NewApplicationException(transErr.TypeID(), transErr.Error()) return encodeErr, nil } diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index 30b707f4da..548b634a68 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -110,7 +110,7 @@ func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data in // UnmarshalThriftException decode thrift exception from tProt // If your input is []byte, you can wrap it with `NewBinaryProtocol(remote.NewReaderBuffer(buf))` func UnmarshalThriftException(tProt thrift.TProtocol) error { - exception := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") + exception := bthrift.NewApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") if err := exception.Read(tProt); err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal Exception failed: %s", err.Error())) } diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index b5fe3dcd43..47f9ce5dcd 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -24,6 +24,7 @@ import ( "github.com/cloudwego/kitex/internal/mocks/thrift/fast" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" ) @@ -163,7 +164,7 @@ func TestUnmarshalThriftException(t *testing.T) { transport := thrift.NewTMemoryBufferLen(marshalThriftBufferSize) tProt := thrift.NewTBinaryProtocol(transport, true, true) errMessage := "test: invalid protocol" - exc := thrift.NewTApplicationException(thrift.INVALID_PROTOCOL, errMessage) + exc := bthrift.NewApplicationException(thrift.INVALID_PROTOCOL, errMessage) err := exc.Write(tProt) test.Assert(t, err == nil, err) diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index 5a12d7b0df..d7bef3b027 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/kitex/internal/mocks" mt "github.com/cloudwego/kitex/internal/mocks/thrift/fast" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" @@ -209,13 +210,13 @@ func TestException(t *testing.T) { func TestTransErrorUnwrap(t *testing.T) { errMsg := "mock err" - transErr := remote.NewTransError(remote.InternalError, thrift.NewTApplicationException(1000, errMsg)) - uwErr, ok := transErr.Unwrap().(thrift.TApplicationException) + transErr := remote.NewTransError(remote.InternalError, bthrift.NewApplicationException(1000, errMsg)) + uwErr, ok := transErr.Unwrap().(*bthrift.ApplicationException) test.Assert(t, ok) test.Assert(t, uwErr.TypeId() == 1000) test.Assert(t, transErr.Error() == errMsg) - uwErr2, ok := errors.Unwrap(transErr).(thrift.TApplicationException) + uwErr2, ok := errors.Unwrap(transErr).(*bthrift.ApplicationException) test.Assert(t, ok) test.Assert(t, uwErr2.TypeId() == 1000) test.Assert(t, uwErr2.Error() == errMsg) diff --git a/pkg/utils/thrift.go b/pkg/utils/thrift.go index 8741aea9dc..f05f1d4612 100644 --- a/pkg/utils/thrift.go +++ b/pkg/utils/thrift.go @@ -75,7 +75,7 @@ func (t *ThriftMessageCodec) Decode(b []byte, msg thrift.TStruct) (method string return } if msgType == thrift.EXCEPTION { - exception := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") + exception := bthrift.NewApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") if err = exception.Read(t.tProt); err != nil { return } @@ -146,6 +146,5 @@ func UnmarshalError(b []byte) error { if _, err := ex.FastRead(b[off:]); err != nil { return err } - // XXX: for compatibility, consider to remove it in the future - return thrift.NewTApplicationException(ex.TypeID(), ex.Msg()) + return ex } diff --git a/server/invoke_test.go b/server/invoke_test.go index c03a2e754b..77089b4389 100644 --- a/server/invoke_test.go +++ b/server/invoke_test.go @@ -22,10 +22,9 @@ import ( "sync/atomic" "testing" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/trans/invoke" "github.com/cloudwego/kitex/pkg/utils" ) From 8acf3658487307eac19d2c440e066819ce6d91f4 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Wed, 3 Jul 2024 12:57:23 +0800 Subject: [PATCH 05/70] feat(generic): support grpc json generic for client (#1411) --- .../genericclient/generic_stream_service.go | 100 ++ .../generic_stream_service_test.go | 57 ++ client/genericclient/stream.go | 219 +++++ .../proto/kitex_gen/pbapi/mock/client.go | 157 ++++ .../proto/kitex_gen/pbapi/mock/invoker.go | 39 + .../mocks/proto/kitex_gen/pbapi/mock/mock.go | 857 ++++++++++++++++++ .../kitex_gen/pbapi/mock/pbapi.pb.fast.go | 151 +++ .../proto/kitex_gen/pbapi/mock/pbapi.pb.go | 276 ++++++ .../proto/kitex_gen/pbapi/mock/server.go | 39 + internal/mocks/proto/pbapi.proto | 19 + pkg/generic/descriptor/descriptor.go | 3 + pkg/generic/generic.go | 7 +- pkg/generic/generic_service.go | 2 - pkg/generic/generic_service_test.go | 2 +- pkg/generic/grpcjsonpb_test/generic_init.go | 148 +++ pkg/generic/grpcjsonpb_test/generic_test.go | 163 ++++ pkg/generic/grpcjsonpb_test/idl/pbapi.proto | 19 + pkg/generic/httppbthrift_codec.go | 9 +- pkg/generic/httppbthrift_codec_test.go | 2 + pkg/generic/httpthrift_codec.go | 6 +- pkg/generic/httpthrift_codec_test.go | 3 + pkg/generic/jsonpb_codec.go | 20 +- pkg/generic/jsonpb_codec_test.go | 31 +- pkg/generic/jsonthrift_codec.go | 6 +- pkg/generic/jsonthrift_codec_test.go | 5 + pkg/generic/mapthrift_codec.go | 6 +- pkg/generic/mapthrift_codec_test.go | 4 + pkg/generic/thrift/json.go | 2 +- pkg/generic/thrift/parse.go | 24 + pkg/generic/thrift/parse_test.go | 44 + pkg/remote/codec/grpc/grpc.go | 20 + pkg/serviceinfo/serviceinfo.go | 7 + 32 files changed, 2422 insertions(+), 25 deletions(-) create mode 100644 client/genericclient/generic_stream_service.go create mode 100644 client/genericclient/generic_stream_service_test.go create mode 100644 client/genericclient/stream.go create mode 100644 internal/mocks/proto/kitex_gen/pbapi/mock/client.go create mode 100644 internal/mocks/proto/kitex_gen/pbapi/mock/invoker.go create mode 100644 internal/mocks/proto/kitex_gen/pbapi/mock/mock.go create mode 100644 internal/mocks/proto/kitex_gen/pbapi/mock/pbapi.pb.fast.go create mode 100644 internal/mocks/proto/kitex_gen/pbapi/mock/pbapi.pb.go create mode 100644 internal/mocks/proto/kitex_gen/pbapi/mock/server.go create mode 100644 internal/mocks/proto/pbapi.proto create mode 100644 pkg/generic/grpcjsonpb_test/generic_init.go create mode 100644 pkg/generic/grpcjsonpb_test/generic_test.go create mode 100644 pkg/generic/grpcjsonpb_test/idl/pbapi.proto diff --git a/client/genericclient/generic_stream_service.go b/client/genericclient/generic_stream_service.go new file mode 100644 index 0000000000..a3b995a305 --- /dev/null +++ b/client/genericclient/generic_stream_service.go @@ -0,0 +1,100 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genericclient + +import ( + "github.com/cloudwego/kitex/pkg/generic" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo { + readerWriter := g.MessageReaderWriter() + if readerWriter == nil { + // TODO: support grpc binary generic + panic("binary generic streaming is not supported") + } + + methods := map[string]serviceinfo.MethodInfo{ + serviceinfo.GenericClientStreamingMethod: serviceinfo.NewMethodInfo( + nil, + func() interface{} { + args := &generic.Args{} + args.SetCodec(readerWriter) + return args + }, + func() interface{} { + result := &generic.Result{} + result.SetCodec(readerWriter) + return result + }, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingClient), + ), + serviceinfo.GenericServerStreamingMethod: serviceinfo.NewMethodInfo( + nil, + func() interface{} { + args := &generic.Args{} + args.SetCodec(readerWriter) + return args + }, + func() interface{} { + result := &generic.Result{} + result.SetCodec(readerWriter) + return result + }, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingServer), + ), + serviceinfo.GenericBidirectionalStreamingMethod: serviceinfo.NewMethodInfo( + nil, + func() interface{} { + args := &generic.Args{} + args.SetCodec(readerWriter) + return args + }, + func() interface{} { + result := &generic.Result{} + result.SetCodec(readerWriter) + return result + }, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), + serviceinfo.GenericMethod: serviceinfo.NewMethodInfo( + nil, + func() interface{} { + args := &generic.Args{} + args.SetCodec(readerWriter) + return args + }, + func() interface{} { + result := &generic.Result{} + result.SetCodec(readerWriter) + return result + }, + false, + ), + } + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: g.IDLServiceName(), + Methods: methods, + PayloadCodec: g.PayloadCodecType(), + Extra: make(map[string]interface{}), + } + svcInfo.Extra["generic"] = true + return svcInfo +} diff --git a/client/genericclient/generic_stream_service_test.go b/client/genericclient/generic_stream_service_test.go new file mode 100644 index 0000000000..3fe71f8bd3 --- /dev/null +++ b/client/genericclient/generic_stream_service_test.go @@ -0,0 +1,57 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genericclient + +import ( + "context" + "testing" + + dproto "github.com/cloudwego/dynamicgo/proto" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/generic" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +func TestGenericStreamService(t *testing.T) { + p, err := generic.NewPbFileProviderWithDynamicGo("../../pkg/generic/grpcjsonpb_test/idl/pbapi.proto", context.Background(), dproto.Options{}) + test.Assert(t, err == nil) + g, err := generic.JSONPbGeneric(p) + test.Assert(t, err == nil) + + svcInfo := newClientStreamingServiceInfo(g) + test.Assert(t, svcInfo.Extra["generic"] == true) + svcInfo.GenericMethod = func(name string) serviceinfo.MethodInfo { + return svcInfo.Methods[name] + } + mtInfo := svcInfo.MethodInfo(serviceinfo.GenericClientStreamingMethod) + _, ok := mtInfo.NewArgs().(*generic.Args) + test.Assert(t, ok) + _, ok = mtInfo.NewResult().(*generic.Result) + test.Assert(t, ok) + test.Assert(t, !mtInfo.OneWay()) + test.Assert(t, mtInfo.IsStreaming()) + test.Assert(t, mtInfo.StreamingMode() == serviceinfo.StreamingClient, mtInfo.StreamingMode()) + mtInfo = svcInfo.MethodInfo(serviceinfo.GenericMethod) + test.Assert(t, !mtInfo.IsStreaming()) + + test.PanicAt(t, func() { + newClientStreamingServiceInfo(generic.BinaryThriftGeneric()) + }, func(err interface{}) bool { + return err.(string) == "binary generic streaming is not supported" + }) +} diff --git a/client/genericclient/stream.go b/client/genericclient/stream.go new file mode 100644 index 0000000000..05d260f2a6 --- /dev/null +++ b/client/genericclient/stream.go @@ -0,0 +1,219 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package genericclient + +import ( + "context" + "errors" + "fmt" + "runtime" + + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/client/callopt" + "github.com/cloudwego/kitex/pkg/generic" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streaming" +) + +type ClientStreaming interface { + streaming.Stream + Send(req interface{}) error + CloseAndRecv() (resp interface{}, err error) +} + +type ServerStreaming interface { + streaming.Stream + Recv() (resp interface{}, err error) +} + +type BidirectionalStreaming interface { + streaming.Stream + Send(req interface{}) error + Recv() (resp interface{}, err error) +} + +func NewStreamingClient(destService string, g generic.Generic, opts ...client.Option) (Client, error) { + svcInfo := newClientStreamingServiceInfo(g) + return NewStreamingClientWithServiceInfo(destService, g, svcInfo, opts...) +} + +func NewStreamingClientWithServiceInfo(destService string, g generic.Generic, svcInfo *serviceinfo.ServiceInfo, opts ...client.Option) (Client, error) { + var options []client.Option + options = append(options, client.WithGeneric(g)) + options = append(options, client.WithDestService(destService)) + options = append(options, opts...) + + kc, err := client.NewClient(svcInfo, options...) + if err != nil { + return nil, err + } + cli := &genericServiceClient{ + svcInfo: svcInfo, + kClient: kc, + g: g, + } + runtime.SetFinalizer(cli, (*genericServiceClient).Close) + + svcInfo.GenericMethod = func(name string) serviceinfo.MethodInfo { + n, err := g.GetMethod(nil, name) + var key string + if err != nil { + // TODO: support fallback solution for binary + key = serviceinfo.GenericMethod + } else { + key = getGenericStreamingMethodInfoKey(n.StreamingMode) + } + m := svcInfo.Methods[key] + return &methodInfo{ + MethodInfo: m, + oneway: n.Oneway, + } + } + + return cli, nil +} + +func getGenericStreamingMethodInfoKey(streamingMode serviceinfo.StreamingMode) string { + switch streamingMode { + case serviceinfo.StreamingClient: + return serviceinfo.GenericClientStreamingMethod + case serviceinfo.StreamingServer: + return serviceinfo.GenericServerStreamingMethod + case serviceinfo.StreamingBidirectional: + return serviceinfo.GenericBidirectionalStreamingMethod + default: + return serviceinfo.GenericMethod + } +} + +type clientStreamingClient struct { + streaming.Stream + method string + methodInfo serviceinfo.MethodInfo +} + +func NewClientStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (ClientStreaming, error) { + gCli, ok := genericCli.(*genericServiceClient) + if !ok { + return nil, errors.New("invalid generic client") + } + stream, err := getStream(ctx, genericCli, method, callOpts...) + if err != nil { + return nil, err + } + return &clientStreamingClient{stream, method, gCli.svcInfo.MethodInfo(method)}, nil +} + +func (cs *clientStreamingClient) Send(req interface{}) error { + _args := cs.methodInfo.NewArgs().(*generic.Args) + _args.Method = cs.method + _args.Request = req + return cs.Stream.SendMsg(_args) +} + +func (cs *clientStreamingClient) CloseAndRecv() (resp interface{}, err error) { + if err := cs.Stream.Close(); err != nil { + return nil, err + } + _result := cs.methodInfo.NewResult().(*generic.Result) + if err = cs.Stream.RecvMsg(_result); err != nil { + return nil, err + } + return _result.GetSuccess(), nil +} + +type serverStreamingClient struct { + streaming.Stream + methodInfo serviceinfo.MethodInfo +} + +func NewServerStreaming(ctx context.Context, genericCli Client, method string, req interface{}, callOpts ...callopt.Option) (ServerStreaming, error) { + gCli, ok := genericCli.(*genericServiceClient) + if !ok { + return nil, errors.New("invalid generic client") + } + stream, err := getStream(ctx, genericCli, method, callOpts...) + if err != nil { + return nil, err + } + ss := &serverStreamingClient{stream, gCli.svcInfo.MethodInfo(method)} + _args := gCli.svcInfo.MethodInfo(method).NewArgs().(*generic.Args) + _args.Method = method + _args.Request = req + if err = ss.Stream.SendMsg(_args); err != nil { + return nil, err + } + if err = ss.Stream.Close(); err != nil { + return nil, err + } + return ss, nil +} + +func (ss *serverStreamingClient) Recv() (resp interface{}, err error) { + _result := ss.methodInfo.NewResult().(*generic.Result) + if err = ss.Stream.RecvMsg(_result); err != nil { + return nil, err + } + return _result.GetSuccess(), nil +} + +type bidirectionalStreamingClient struct { + streaming.Stream + method string + methodInfo serviceinfo.MethodInfo +} + +func NewBidirectionalStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (BidirectionalStreaming, error) { + gCli, ok := genericCli.(*genericServiceClient) + if !ok { + return nil, errors.New("invalid generic client") + } + stream, err := getStream(ctx, genericCli, method, callOpts...) + if err != nil { + return nil, err + } + return &bidirectionalStreamingClient{stream, method, gCli.svcInfo.MethodInfo(method)}, nil +} + +func (bs *bidirectionalStreamingClient) Send(req interface{}) error { + _args := bs.methodInfo.NewArgs().(*generic.Args) + _args.Method = bs.method + _args.Request = req + return bs.Stream.SendMsg(_args) +} + +func (bs *bidirectionalStreamingClient) Recv() (resp interface{}, err error) { + _result := bs.methodInfo.NewResult().(*generic.Result) + if err = bs.Stream.RecvMsg(_result); err != nil { + return nil, err + } + return _result.GetSuccess(), nil +} + +func getStream(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (streaming.Stream, error) { + ctx = client.NewCtxWithCallOptions(ctx, callOpts) + streamClient, ok := genericCli.(*genericServiceClient).kClient.(client.Streaming) + if !ok { + return nil, fmt.Errorf("client not support streaming") + } + res := new(streaming.Result) + err := streamClient.Stream(ctx, method, nil, res) + if err != nil { + return nil, err + } + return res.Stream, nil +} diff --git a/internal/mocks/proto/kitex_gen/pbapi/mock/client.go b/internal/mocks/proto/kitex_gen/pbapi/mock/client.go new file mode 100644 index 0000000000..6957af77ec --- /dev/null +++ b/internal/mocks/proto/kitex_gen/pbapi/mock/client.go @@ -0,0 +1,157 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by Kitex v0.9.1. DO NOT EDIT. + +package mock + +import ( + "context" + client "github.com/cloudwego/kitex/client" + callopt "github.com/cloudwego/kitex/client/callopt" + "github.com/cloudwego/kitex/client/callopt/streamcall" + "github.com/cloudwego/kitex/client/streamclient" + streaming "github.com/cloudwego/kitex/pkg/streaming" + transport "github.com/cloudwego/kitex/transport" +) + +// Client is designed to provide IDL-compatible methods with call-option parameter for kitex framework. +type Client interface { + UnaryTest(ctx context.Context, Req *MockReq, callOptions ...callopt.Option) (r *MockResp, err error) + ClientStreamingTest(ctx context.Context, callOptions ...callopt.Option) (stream Mock_ClientStreamingTestClient, err error) + ServerStreamingTest(ctx context.Context, Req *MockReq, callOptions ...callopt.Option) (stream Mock_ServerStreamingTestClient, err error) + BidirectionalStreamingTest(ctx context.Context, callOptions ...callopt.Option) (stream Mock_BidirectionalStreamingTestClient, err error) +} + +// StreamClient is designed to provide Interface for Streaming APIs. +type StreamClient interface { + ClientStreamingTest(ctx context.Context, callOptions ...streamcall.Option) (stream Mock_ClientStreamingTestClient, err error) + ServerStreamingTest(ctx context.Context, Req *MockReq, callOptions ...streamcall.Option) (stream Mock_ServerStreamingTestClient, err error) + BidirectionalStreamingTest(ctx context.Context, callOptions ...streamcall.Option) (stream Mock_BidirectionalStreamingTestClient, err error) +} + +type Mock_ClientStreamingTestClient interface { + streaming.Stream + Send(*MockReq) error + CloseAndRecv() (*MockResp, error) +} + +type Mock_ServerStreamingTestClient interface { + streaming.Stream + Recv() (*MockResp, error) +} + +type Mock_BidirectionalStreamingTestClient interface { + streaming.Stream + Send(*MockReq) error + Recv() (*MockResp, error) +} + +// NewClient creates a client for the service defined in IDL. +func NewClient(destService string, opts ...client.Option) (Client, error) { + var options []client.Option + options = append(options, client.WithDestService(destService)) + + options = append(options, client.WithTransportProtocol(transport.GRPC)) + + options = append(options, opts...) + + kc, err := client.NewClient(serviceInfo(), options...) + if err != nil { + return nil, err + } + return &kMockClient{ + kClient: newServiceClient(kc), + }, nil +} + +// MustNewClient creates a client for the service defined in IDL. It panics if any error occurs. +func MustNewClient(destService string, opts ...client.Option) Client { + kc, err := NewClient(destService, opts...) + if err != nil { + panic(err) + } + return kc +} + +type kMockClient struct { + *kClient +} + +func (p *kMockClient) UnaryTest(ctx context.Context, Req *MockReq, callOptions ...callopt.Option) (r *MockResp, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.UnaryTest(ctx, Req) +} + +func (p *kMockClient) ClientStreamingTest(ctx context.Context, callOptions ...callopt.Option) (stream Mock_ClientStreamingTestClient, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.ClientStreamingTest(ctx) +} + +func (p *kMockClient) ServerStreamingTest(ctx context.Context, Req *MockReq, callOptions ...callopt.Option) (stream Mock_ServerStreamingTestClient, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.ServerStreamingTest(ctx, Req) +} + +func (p *kMockClient) BidirectionalStreamingTest(ctx context.Context, callOptions ...callopt.Option) (stream Mock_BidirectionalStreamingTestClient, err error) { + ctx = client.NewCtxWithCallOptions(ctx, callOptions) + return p.kClient.BidirectionalStreamingTest(ctx) +} + +// NewStreamClient creates a stream client for the service's streaming APIs defined in IDL. +func NewStreamClient(destService string, opts ...streamclient.Option) (StreamClient, error) { + var options []client.Option + options = append(options, client.WithDestService(destService)) + options = append(options, client.WithTransportProtocol(transport.GRPC)) + options = append(options, streamclient.GetClientOptions(opts)...) + + kc, err := client.NewClient(serviceInfoForStreamClient(), options...) + if err != nil { + return nil, err + } + return &kMockStreamClient{ + kClient: newServiceClient(kc), + }, nil +} + +// MustNewStreamClient creates a stream client for the service's streaming APIs defined in IDL. +// It panics if any error occurs. +func MustNewStreamClient(destService string, opts ...streamclient.Option) StreamClient { + kc, err := NewStreamClient(destService, opts...) + if err != nil { + panic(err) + } + return kc +} + +type kMockStreamClient struct { + *kClient +} + +func (p *kMockStreamClient) ClientStreamingTest(ctx context.Context, callOptions ...streamcall.Option) (stream Mock_ClientStreamingTestClient, err error) { + ctx = client.NewCtxWithCallOptions(ctx, streamcall.GetCallOptions(callOptions)) + return p.kClient.ClientStreamingTest(ctx) +} + +func (p *kMockStreamClient) ServerStreamingTest(ctx context.Context, Req *MockReq, callOptions ...streamcall.Option) (stream Mock_ServerStreamingTestClient, err error) { + ctx = client.NewCtxWithCallOptions(ctx, streamcall.GetCallOptions(callOptions)) + return p.kClient.ServerStreamingTest(ctx, Req) +} + +func (p *kMockStreamClient) BidirectionalStreamingTest(ctx context.Context, callOptions ...streamcall.Option) (stream Mock_BidirectionalStreamingTestClient, err error) { + ctx = client.NewCtxWithCallOptions(ctx, streamcall.GetCallOptions(callOptions)) + return p.kClient.BidirectionalStreamingTest(ctx) +} diff --git a/internal/mocks/proto/kitex_gen/pbapi/mock/invoker.go b/internal/mocks/proto/kitex_gen/pbapi/mock/invoker.go new file mode 100644 index 0000000000..61ad657eff --- /dev/null +++ b/internal/mocks/proto/kitex_gen/pbapi/mock/invoker.go @@ -0,0 +1,39 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by Kitex v0.9.1. DO NOT EDIT. + +package mock + +import ( + server "github.com/cloudwego/kitex/server" +) + +// NewInvoker creates a server.Invoker with the given handler and options. +func NewInvoker(handler Mock, opts ...server.Option) server.Invoker { + var options []server.Option + + options = append(options, opts...) + + s := server.NewInvoker(options...) + if err := s.RegisterService(serviceInfo(), handler); err != nil { + panic(err) + } + if err := s.Init(); err != nil { + panic(err) + } + return s +} diff --git a/internal/mocks/proto/kitex_gen/pbapi/mock/mock.go b/internal/mocks/proto/kitex_gen/pbapi/mock/mock.go new file mode 100644 index 0000000000..51139aa393 --- /dev/null +++ b/internal/mocks/proto/kitex_gen/pbapi/mock/mock.go @@ -0,0 +1,857 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by Kitex v0.9.1. DO NOT EDIT. + +package mock + +import ( + "context" + "errors" + "fmt" + client "github.com/cloudwego/kitex/client" + kitex "github.com/cloudwego/kitex/pkg/serviceinfo" + streaming "github.com/cloudwego/kitex/pkg/streaming" + proto "google.golang.org/protobuf/proto" +) + +var errInvalidMessageType = errors.New("invalid message type for service method handler") + +var serviceMethods = map[string]kitex.MethodInfo{ + "UnaryTest": kitex.NewMethodInfo( + unaryTestHandler, + newUnaryTestArgs, + newUnaryTestResult, + false, + kitex.WithStreamingMode(kitex.StreamingUnary), + ), + "ClientStreamingTest": kitex.NewMethodInfo( + clientStreamingTestHandler, + newClientStreamingTestArgs, + newClientStreamingTestResult, + false, + kitex.WithStreamingMode(kitex.StreamingClient), + ), + "ServerStreamingTest": kitex.NewMethodInfo( + serverStreamingTestHandler, + newServerStreamingTestArgs, + newServerStreamingTestResult, + false, + kitex.WithStreamingMode(kitex.StreamingServer), + ), + "BidirectionalStreamingTest": kitex.NewMethodInfo( + bidirectionalStreamingTestHandler, + newBidirectionalStreamingTestArgs, + newBidirectionalStreamingTestResult, + false, + kitex.WithStreamingMode(kitex.StreamingBidirectional), + ), +} + +var ( + mockServiceInfo = NewServiceInfo() + mockServiceInfoForClient = NewServiceInfoForClient() + mockServiceInfoForStreamClient = NewServiceInfoForStreamClient() +) + +// for server +func serviceInfo() *kitex.ServiceInfo { + return mockServiceInfo +} + +// for client +func serviceInfoForStreamClient() *kitex.ServiceInfo { + return mockServiceInfoForStreamClient +} + +// for stream client +func serviceInfoForClient() *kitex.ServiceInfo { + return mockServiceInfoForClient +} + +// NewServiceInfo creates a new ServiceInfo containing all methods +func NewServiceInfo() *kitex.ServiceInfo { + return newServiceInfo(true, true, true) +} + +// NewServiceInfo creates a new ServiceInfo containing non-streaming methods +func NewServiceInfoForClient() *kitex.ServiceInfo { + return newServiceInfo(false, false, true) +} +func NewServiceInfoForStreamClient() *kitex.ServiceInfo { + return newServiceInfo(true, true, false) +} + +func newServiceInfo(hasStreaming bool, keepStreamingMethods bool, keepNonStreamingMethods bool) *kitex.ServiceInfo { + serviceName := "Mock" + handlerType := (*Mock)(nil) + methods := map[string]kitex.MethodInfo{} + for name, m := range serviceMethods { + if m.IsStreaming() && !keepStreamingMethods { + continue + } + if !m.IsStreaming() && !keepNonStreamingMethods { + continue + } + methods[name] = m + } + extra := map[string]interface{}{ + "PackageName": "pbapi", + } + if hasStreaming { + extra["streaming"] = hasStreaming + } + svcInfo := &kitex.ServiceInfo{ + ServiceName: serviceName, + HandlerType: handlerType, + Methods: methods, + PayloadCodec: kitex.Protobuf, + KiteXGenVersion: "v0.9.1", + Extra: extra, + } + return svcInfo +} + +func unaryTestHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + switch s := arg.(type) { + case *streaming.Args: + st := s.Stream + req := new(MockReq) + if err := st.RecvMsg(req); err != nil { + return err + } + resp, err := handler.(Mock).UnaryTest(ctx, req) + if err != nil { + return err + } + return st.SendMsg(resp) + case *UnaryTestArgs: + success, err := handler.(Mock).UnaryTest(ctx, s.Req) + if err != nil { + return err + } + realResult := result.(*UnaryTestResult) + realResult.Success = success + return nil + default: + return errInvalidMessageType + } +} +func newUnaryTestArgs() interface{} { + return &UnaryTestArgs{} +} + +func newUnaryTestResult() interface{} { + return &UnaryTestResult{} +} + +type UnaryTestArgs struct { + Req *MockReq +} + +func (p *UnaryTestArgs) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + if !p.IsSetReq() { + p.Req = new(MockReq) + } + return p.Req.FastRead(buf, _type, number) +} + +func (p *UnaryTestArgs) FastWrite(buf []byte) (n int) { + if !p.IsSetReq() { + return 0 + } + return p.Req.FastWrite(buf) +} + +func (p *UnaryTestArgs) Size() (n int) { + if !p.IsSetReq() { + return 0 + } + return p.Req.Size() +} + +func (p *UnaryTestArgs) Marshal(out []byte) ([]byte, error) { + if !p.IsSetReq() { + return out, nil + } + return proto.Marshal(p.Req) +} + +func (p *UnaryTestArgs) Unmarshal(in []byte) error { + msg := new(MockReq) + if err := proto.Unmarshal(in, msg); err != nil { + return err + } + p.Req = msg + return nil +} + +var UnaryTestArgs_Req_DEFAULT *MockReq + +func (p *UnaryTestArgs) GetReq() *MockReq { + if !p.IsSetReq() { + return UnaryTestArgs_Req_DEFAULT + } + return p.Req +} + +func (p *UnaryTestArgs) IsSetReq() bool { + return p.Req != nil +} + +func (p *UnaryTestArgs) GetFirstArgument() interface{} { + return p.Req +} + +type UnaryTestResult struct { + Success *MockResp +} + +var UnaryTestResult_Success_DEFAULT *MockResp + +func (p *UnaryTestResult) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + if !p.IsSetSuccess() { + p.Success = new(MockResp) + } + return p.Success.FastRead(buf, _type, number) +} + +func (p *UnaryTestResult) FastWrite(buf []byte) (n int) { + if !p.IsSetSuccess() { + return 0 + } + return p.Success.FastWrite(buf) +} + +func (p *UnaryTestResult) Size() (n int) { + if !p.IsSetSuccess() { + return 0 + } + return p.Success.Size() +} + +func (p *UnaryTestResult) Marshal(out []byte) ([]byte, error) { + if !p.IsSetSuccess() { + return out, nil + } + return proto.Marshal(p.Success) +} + +func (p *UnaryTestResult) Unmarshal(in []byte) error { + msg := new(MockResp) + if err := proto.Unmarshal(in, msg); err != nil { + return err + } + p.Success = msg + return nil +} + +func (p *UnaryTestResult) GetSuccess() *MockResp { + if !p.IsSetSuccess() { + return UnaryTestResult_Success_DEFAULT + } + return p.Success +} + +func (p *UnaryTestResult) SetSuccess(x interface{}) { + p.Success = x.(*MockResp) +} + +func (p *UnaryTestResult) IsSetSuccess() bool { + return p.Success != nil +} + +func (p *UnaryTestResult) GetResult() interface{} { + return p.Success +} + +func clientStreamingTestHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + streamingArgs, ok := arg.(*streaming.Args) + if !ok { + return errInvalidMessageType + } + st := streamingArgs.Stream + stream := &mockClientStreamingTestServer{st} + return handler.(Mock).ClientStreamingTest(stream) +} + +type mockClientStreamingTestClient struct { + streaming.Stream +} + +func (x *mockClientStreamingTestClient) DoFinish(err error) { + if finisher, ok := x.Stream.(streaming.WithDoFinish); ok { + finisher.DoFinish(err) + } else { + panic(fmt.Sprintf("streaming.WithDoFinish is not implemented by %T", x.Stream)) + } +} +func (x *mockClientStreamingTestClient) Send(m *MockReq) error { + return x.Stream.SendMsg(m) +} +func (x *mockClientStreamingTestClient) CloseAndRecv() (*MockResp, error) { + if err := x.Stream.Close(); err != nil { + return nil, err + } + m := new(MockResp) + return m, x.Stream.RecvMsg(m) +} + +type mockClientStreamingTestServer struct { + streaming.Stream +} + +func (x *mockClientStreamingTestServer) SendAndClose(m *MockResp) error { + return x.Stream.SendMsg(m) +} + +func (x *mockClientStreamingTestServer) Recv() (*MockReq, error) { + m := new(MockReq) + return m, x.Stream.RecvMsg(m) +} + +func newClientStreamingTestArgs() interface{} { + return &ClientStreamingTestArgs{} +} + +func newClientStreamingTestResult() interface{} { + return &ClientStreamingTestResult{} +} + +type ClientStreamingTestArgs struct { + Req *MockReq +} + +func (p *ClientStreamingTestArgs) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + if !p.IsSetReq() { + p.Req = new(MockReq) + } + return p.Req.FastRead(buf, _type, number) +} + +func (p *ClientStreamingTestArgs) FastWrite(buf []byte) (n int) { + if !p.IsSetReq() { + return 0 + } + return p.Req.FastWrite(buf) +} + +func (p *ClientStreamingTestArgs) Size() (n int) { + if !p.IsSetReq() { + return 0 + } + return p.Req.Size() +} + +func (p *ClientStreamingTestArgs) Marshal(out []byte) ([]byte, error) { + if !p.IsSetReq() { + return out, nil + } + return proto.Marshal(p.Req) +} + +func (p *ClientStreamingTestArgs) Unmarshal(in []byte) error { + msg := new(MockReq) + if err := proto.Unmarshal(in, msg); err != nil { + return err + } + p.Req = msg + return nil +} + +var ClientStreamingTestArgs_Req_DEFAULT *MockReq + +func (p *ClientStreamingTestArgs) GetReq() *MockReq { + if !p.IsSetReq() { + return ClientStreamingTestArgs_Req_DEFAULT + } + return p.Req +} + +func (p *ClientStreamingTestArgs) IsSetReq() bool { + return p.Req != nil +} + +func (p *ClientStreamingTestArgs) GetFirstArgument() interface{} { + return p.Req +} + +type ClientStreamingTestResult struct { + Success *MockResp +} + +var ClientStreamingTestResult_Success_DEFAULT *MockResp + +func (p *ClientStreamingTestResult) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + if !p.IsSetSuccess() { + p.Success = new(MockResp) + } + return p.Success.FastRead(buf, _type, number) +} + +func (p *ClientStreamingTestResult) FastWrite(buf []byte) (n int) { + if !p.IsSetSuccess() { + return 0 + } + return p.Success.FastWrite(buf) +} + +func (p *ClientStreamingTestResult) Size() (n int) { + if !p.IsSetSuccess() { + return 0 + } + return p.Success.Size() +} + +func (p *ClientStreamingTestResult) Marshal(out []byte) ([]byte, error) { + if !p.IsSetSuccess() { + return out, nil + } + return proto.Marshal(p.Success) +} + +func (p *ClientStreamingTestResult) Unmarshal(in []byte) error { + msg := new(MockResp) + if err := proto.Unmarshal(in, msg); err != nil { + return err + } + p.Success = msg + return nil +} + +func (p *ClientStreamingTestResult) GetSuccess() *MockResp { + if !p.IsSetSuccess() { + return ClientStreamingTestResult_Success_DEFAULT + } + return p.Success +} + +func (p *ClientStreamingTestResult) SetSuccess(x interface{}) { + p.Success = x.(*MockResp) +} + +func (p *ClientStreamingTestResult) IsSetSuccess() bool { + return p.Success != nil +} + +func (p *ClientStreamingTestResult) GetResult() interface{} { + return p.Success +} + +func serverStreamingTestHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + streamingArgs, ok := arg.(*streaming.Args) + if !ok { + return errInvalidMessageType + } + st := streamingArgs.Stream + stream := &mockServerStreamingTestServer{st} + req := new(MockReq) + if err := st.RecvMsg(req); err != nil { + return err + } + return handler.(Mock).ServerStreamingTest(req, stream) +} + +type mockServerStreamingTestClient struct { + streaming.Stream +} + +func (x *mockServerStreamingTestClient) DoFinish(err error) { + if finisher, ok := x.Stream.(streaming.WithDoFinish); ok { + finisher.DoFinish(err) + } else { + panic(fmt.Sprintf("streaming.WithDoFinish is not implemented by %T", x.Stream)) + } +} +func (x *mockServerStreamingTestClient) Recv() (*MockResp, error) { + m := new(MockResp) + return m, x.Stream.RecvMsg(m) +} + +type mockServerStreamingTestServer struct { + streaming.Stream +} + +func (x *mockServerStreamingTestServer) Send(m *MockResp) error { + return x.Stream.SendMsg(m) +} + +func newServerStreamingTestArgs() interface{} { + return &ServerStreamingTestArgs{} +} + +func newServerStreamingTestResult() interface{} { + return &ServerStreamingTestResult{} +} + +type ServerStreamingTestArgs struct { + Req *MockReq +} + +func (p *ServerStreamingTestArgs) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + if !p.IsSetReq() { + p.Req = new(MockReq) + } + return p.Req.FastRead(buf, _type, number) +} + +func (p *ServerStreamingTestArgs) FastWrite(buf []byte) (n int) { + if !p.IsSetReq() { + return 0 + } + return p.Req.FastWrite(buf) +} + +func (p *ServerStreamingTestArgs) Size() (n int) { + if !p.IsSetReq() { + return 0 + } + return p.Req.Size() +} + +func (p *ServerStreamingTestArgs) Marshal(out []byte) ([]byte, error) { + if !p.IsSetReq() { + return out, nil + } + return proto.Marshal(p.Req) +} + +func (p *ServerStreamingTestArgs) Unmarshal(in []byte) error { + msg := new(MockReq) + if err := proto.Unmarshal(in, msg); err != nil { + return err + } + p.Req = msg + return nil +} + +var ServerStreamingTestArgs_Req_DEFAULT *MockReq + +func (p *ServerStreamingTestArgs) GetReq() *MockReq { + if !p.IsSetReq() { + return ServerStreamingTestArgs_Req_DEFAULT + } + return p.Req +} + +func (p *ServerStreamingTestArgs) IsSetReq() bool { + return p.Req != nil +} + +func (p *ServerStreamingTestArgs) GetFirstArgument() interface{} { + return p.Req +} + +type ServerStreamingTestResult struct { + Success *MockResp +} + +var ServerStreamingTestResult_Success_DEFAULT *MockResp + +func (p *ServerStreamingTestResult) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + if !p.IsSetSuccess() { + p.Success = new(MockResp) + } + return p.Success.FastRead(buf, _type, number) +} + +func (p *ServerStreamingTestResult) FastWrite(buf []byte) (n int) { + if !p.IsSetSuccess() { + return 0 + } + return p.Success.FastWrite(buf) +} + +func (p *ServerStreamingTestResult) Size() (n int) { + if !p.IsSetSuccess() { + return 0 + } + return p.Success.Size() +} + +func (p *ServerStreamingTestResult) Marshal(out []byte) ([]byte, error) { + if !p.IsSetSuccess() { + return out, nil + } + return proto.Marshal(p.Success) +} + +func (p *ServerStreamingTestResult) Unmarshal(in []byte) error { + msg := new(MockResp) + if err := proto.Unmarshal(in, msg); err != nil { + return err + } + p.Success = msg + return nil +} + +func (p *ServerStreamingTestResult) GetSuccess() *MockResp { + if !p.IsSetSuccess() { + return ServerStreamingTestResult_Success_DEFAULT + } + return p.Success +} + +func (p *ServerStreamingTestResult) SetSuccess(x interface{}) { + p.Success = x.(*MockResp) +} + +func (p *ServerStreamingTestResult) IsSetSuccess() bool { + return p.Success != nil +} + +func (p *ServerStreamingTestResult) GetResult() interface{} { + return p.Success +} + +func bidirectionalStreamingTestHandler(ctx context.Context, handler interface{}, arg, result interface{}) error { + streamingArgs, ok := arg.(*streaming.Args) + if !ok { + return errInvalidMessageType + } + st := streamingArgs.Stream + stream := &mockBidirectionalStreamingTestServer{st} + return handler.(Mock).BidirectionalStreamingTest(stream) +} + +type mockBidirectionalStreamingTestClient struct { + streaming.Stream +} + +func (x *mockBidirectionalStreamingTestClient) DoFinish(err error) { + if finisher, ok := x.Stream.(streaming.WithDoFinish); ok { + finisher.DoFinish(err) + } else { + panic(fmt.Sprintf("streaming.WithDoFinish is not implemented by %T", x.Stream)) + } +} +func (x *mockBidirectionalStreamingTestClient) Send(m *MockReq) error { + return x.Stream.SendMsg(m) +} +func (x *mockBidirectionalStreamingTestClient) Recv() (*MockResp, error) { + m := new(MockResp) + return m, x.Stream.RecvMsg(m) +} + +type mockBidirectionalStreamingTestServer struct { + streaming.Stream +} + +func (x *mockBidirectionalStreamingTestServer) Send(m *MockResp) error { + return x.Stream.SendMsg(m) +} + +func (x *mockBidirectionalStreamingTestServer) Recv() (*MockReq, error) { + m := new(MockReq) + return m, x.Stream.RecvMsg(m) +} + +func newBidirectionalStreamingTestArgs() interface{} { + return &BidirectionalStreamingTestArgs{} +} + +func newBidirectionalStreamingTestResult() interface{} { + return &BidirectionalStreamingTestResult{} +} + +type BidirectionalStreamingTestArgs struct { + Req *MockReq +} + +func (p *BidirectionalStreamingTestArgs) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + if !p.IsSetReq() { + p.Req = new(MockReq) + } + return p.Req.FastRead(buf, _type, number) +} + +func (p *BidirectionalStreamingTestArgs) FastWrite(buf []byte) (n int) { + if !p.IsSetReq() { + return 0 + } + return p.Req.FastWrite(buf) +} + +func (p *BidirectionalStreamingTestArgs) Size() (n int) { + if !p.IsSetReq() { + return 0 + } + return p.Req.Size() +} + +func (p *BidirectionalStreamingTestArgs) Marshal(out []byte) ([]byte, error) { + if !p.IsSetReq() { + return out, nil + } + return proto.Marshal(p.Req) +} + +func (p *BidirectionalStreamingTestArgs) Unmarshal(in []byte) error { + msg := new(MockReq) + if err := proto.Unmarshal(in, msg); err != nil { + return err + } + p.Req = msg + return nil +} + +var BidirectionalStreamingTestArgs_Req_DEFAULT *MockReq + +func (p *BidirectionalStreamingTestArgs) GetReq() *MockReq { + if !p.IsSetReq() { + return BidirectionalStreamingTestArgs_Req_DEFAULT + } + return p.Req +} + +func (p *BidirectionalStreamingTestArgs) IsSetReq() bool { + return p.Req != nil +} + +func (p *BidirectionalStreamingTestArgs) GetFirstArgument() interface{} { + return p.Req +} + +type BidirectionalStreamingTestResult struct { + Success *MockResp +} + +var BidirectionalStreamingTestResult_Success_DEFAULT *MockResp + +func (p *BidirectionalStreamingTestResult) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + if !p.IsSetSuccess() { + p.Success = new(MockResp) + } + return p.Success.FastRead(buf, _type, number) +} + +func (p *BidirectionalStreamingTestResult) FastWrite(buf []byte) (n int) { + if !p.IsSetSuccess() { + return 0 + } + return p.Success.FastWrite(buf) +} + +func (p *BidirectionalStreamingTestResult) Size() (n int) { + if !p.IsSetSuccess() { + return 0 + } + return p.Success.Size() +} + +func (p *BidirectionalStreamingTestResult) Marshal(out []byte) ([]byte, error) { + if !p.IsSetSuccess() { + return out, nil + } + return proto.Marshal(p.Success) +} + +func (p *BidirectionalStreamingTestResult) Unmarshal(in []byte) error { + msg := new(MockResp) + if err := proto.Unmarshal(in, msg); err != nil { + return err + } + p.Success = msg + return nil +} + +func (p *BidirectionalStreamingTestResult) GetSuccess() *MockResp { + if !p.IsSetSuccess() { + return BidirectionalStreamingTestResult_Success_DEFAULT + } + return p.Success +} + +func (p *BidirectionalStreamingTestResult) SetSuccess(x interface{}) { + p.Success = x.(*MockResp) +} + +func (p *BidirectionalStreamingTestResult) IsSetSuccess() bool { + return p.Success != nil +} + +func (p *BidirectionalStreamingTestResult) GetResult() interface{} { + return p.Success +} + +type kClient struct { + c client.Client +} + +func newServiceClient(c client.Client) *kClient { + return &kClient{ + c: c, + } +} + +func (p *kClient) UnaryTest(ctx context.Context, Req *MockReq) (r *MockResp, err error) { + var _args UnaryTestArgs + _args.Req = Req + var _result UnaryTestResult + if err = p.c.Call(ctx, "UnaryTest", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + +func (p *kClient) ClientStreamingTest(ctx context.Context) (Mock_ClientStreamingTestClient, error) { + streamClient, ok := p.c.(client.Streaming) + if !ok { + return nil, fmt.Errorf("client not support streaming") + } + res := new(streaming.Result) + err := streamClient.Stream(ctx, "ClientStreamingTest", nil, res) + if err != nil { + return nil, err + } + stream := &mockClientStreamingTestClient{res.Stream} + return stream, nil +} + +func (p *kClient) ServerStreamingTest(ctx context.Context, req *MockReq) (Mock_ServerStreamingTestClient, error) { + streamClient, ok := p.c.(client.Streaming) + if !ok { + return nil, fmt.Errorf("client not support streaming") + } + res := new(streaming.Result) + err := streamClient.Stream(ctx, "ServerStreamingTest", nil, res) + if err != nil { + return nil, err + } + stream := &mockServerStreamingTestClient{res.Stream} + + if err := stream.Stream.SendMsg(req); err != nil { + return nil, err + } + if err := stream.Stream.Close(); err != nil { + return nil, err + } + return stream, nil +} + +func (p *kClient) BidirectionalStreamingTest(ctx context.Context) (Mock_BidirectionalStreamingTestClient, error) { + streamClient, ok := p.c.(client.Streaming) + if !ok { + return nil, fmt.Errorf("client not support streaming") + } + res := new(streaming.Result) + err := streamClient.Stream(ctx, "BidirectionalStreamingTest", nil, res) + if err != nil { + return nil, err + } + stream := &mockBidirectionalStreamingTestClient{res.Stream} + return stream, nil +} diff --git a/internal/mocks/proto/kitex_gen/pbapi/mock/pbapi.pb.fast.go b/internal/mocks/proto/kitex_gen/pbapi/mock/pbapi.pb.fast.go new file mode 100644 index 0000000000..ff559b7faf --- /dev/null +++ b/internal/mocks/proto/kitex_gen/pbapi/mock/pbapi.pb.fast.go @@ -0,0 +1,151 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by Fastpb v0.0.2. DO NOT EDIT. + +package mock + +import ( + fmt "fmt" + fastpb "github.com/cloudwego/fastpb" +) + +var ( + _ = fmt.Errorf + _ = fastpb.Skip +) + +func (x *MockReq) FastRead(buf []byte, _type int8, number int32) (offset int, err error) { + switch number { + case 1: + offset, err = x.fastReadField1(buf, _type) + if err != nil { + goto ReadFieldError + } + default: + offset, err = fastpb.Skip(buf, _type, number) + if err != nil { + goto SkipFieldError + } + } + return offset, nil +SkipFieldError: + return offset, fmt.Errorf("%T cannot parse invalid wire-format data, error: %s", x, err) +ReadFieldError: + return offset, fmt.Errorf("%T read field %d '%s' error: %s", x, number, fieldIDToName_MockReq[number], err) +} + +func (x *MockReq) fastReadField1(buf []byte, _type int8) (offset int, err error) { + x.Message, offset, err = fastpb.ReadString(buf, _type) + return offset, err +} + +func (x *MockResp) FastRead(buf []byte, _type int8, number int32) (offset int, err error) { + switch number { + case 1: + offset, err = x.fastReadField1(buf, _type) + if err != nil { + goto ReadFieldError + } + default: + offset, err = fastpb.Skip(buf, _type, number) + if err != nil { + goto SkipFieldError + } + } + return offset, nil +SkipFieldError: + return offset, fmt.Errorf("%T cannot parse invalid wire-format data, error: %s", x, err) +ReadFieldError: + return offset, fmt.Errorf("%T read field %d '%s' error: %s", x, number, fieldIDToName_MockResp[number], err) +} + +func (x *MockResp) fastReadField1(buf []byte, _type int8) (offset int, err error) { + x.Message, offset, err = fastpb.ReadString(buf, _type) + return offset, err +} + +func (x *MockReq) FastWrite(buf []byte) (offset int) { + if x == nil { + return offset + } + offset += x.fastWriteField1(buf[offset:]) + return offset +} + +func (x *MockReq) fastWriteField1(buf []byte) (offset int) { + if x.Message == "" { + return offset + } + offset += fastpb.WriteString(buf[offset:], 1, x.GetMessage()) + return offset +} + +func (x *MockResp) FastWrite(buf []byte) (offset int) { + if x == nil { + return offset + } + offset += x.fastWriteField1(buf[offset:]) + return offset +} + +func (x *MockResp) fastWriteField1(buf []byte) (offset int) { + if x.Message == "" { + return offset + } + offset += fastpb.WriteString(buf[offset:], 1, x.GetMessage()) + return offset +} + +func (x *MockReq) Size() (n int) { + if x == nil { + return n + } + n += x.sizeField1() + return n +} + +func (x *MockReq) sizeField1() (n int) { + if x.Message == "" { + return n + } + n += fastpb.SizeString(1, x.GetMessage()) + return n +} + +func (x *MockResp) Size() (n int) { + if x == nil { + return n + } + n += x.sizeField1() + return n +} + +func (x *MockResp) sizeField1() (n int) { + if x.Message == "" { + return n + } + n += fastpb.SizeString(1, x.GetMessage()) + return n +} + +var fieldIDToName_MockReq = map[int32]string{ + 1: "Message", +} + +var fieldIDToName_MockResp = map[int32]string{ + 1: "Message", +} diff --git a/internal/mocks/proto/kitex_gen/pbapi/mock/pbapi.pb.go b/internal/mocks/proto/kitex_gen/pbapi/mock/pbapi.pb.go new file mode 100644 index 0000000000..174cab556e --- /dev/null +++ b/internal/mocks/proto/kitex_gen/pbapi/mock/pbapi.pb.go @@ -0,0 +1,276 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v4.25.3 +// source: pbapi.proto + +package mock + +import ( + context "context" + streaming "github.com/cloudwego/kitex/pkg/streaming" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type MockReq struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *MockReq) Reset() { + *x = MockReq{} + if protoimpl.UnsafeEnabled { + mi := &file_pbapi_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MockReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MockReq) ProtoMessage() {} + +func (x *MockReq) ProtoReflect() protoreflect.Message { + mi := &file_pbapi_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MockReq.ProtoReflect.Descriptor instead. +func (*MockReq) Descriptor() ([]byte, []int) { + return file_pbapi_proto_rawDescGZIP(), []int{0} +} + +func (x *MockReq) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type MockResp struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *MockResp) Reset() { + *x = MockResp{} + if protoimpl.UnsafeEnabled { + mi := &file_pbapi_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MockResp) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MockResp) ProtoMessage() {} + +func (x *MockResp) ProtoReflect() protoreflect.Message { + mi := &file_pbapi_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MockResp.ProtoReflect.Descriptor instead. +func (*MockResp) Descriptor() ([]byte, []int) { + return file_pbapi_proto_rawDescGZIP(), []int{1} +} + +func (x *MockResp) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +var File_pbapi_proto protoreflect.FileDescriptor + +var file_pbapi_proto_rawDesc = []byte{ + 0x0a, 0x0b, 0x70, 0x62, 0x61, 0x70, 0x69, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x70, + 0x62, 0x61, 0x70, 0x69, 0x22, 0x23, 0x0a, 0x07, 0x4d, 0x6f, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x12, + 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x24, 0x0a, 0x08, 0x4d, 0x6f, 0x63, + 0x6b, 0x52, 0x65, 0x73, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, + 0xf3, 0x01, 0x0a, 0x04, 0x4d, 0x6f, 0x63, 0x6b, 0x12, 0x2e, 0x0a, 0x09, 0x55, 0x6e, 0x61, 0x72, + 0x79, 0x54, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x2e, 0x70, 0x62, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6f, + 0x63, 0x6b, 0x52, 0x65, 0x71, 0x1a, 0x0f, 0x2e, 0x70, 0x62, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6f, + 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x22, 0x00, 0x12, 0x3a, 0x0a, 0x13, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x54, 0x65, 0x73, 0x74, 0x12, + 0x0e, 0x2e, 0x70, 0x62, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6f, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x1a, + 0x0f, 0x2e, 0x70, 0x62, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6f, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, + 0x22, 0x00, 0x28, 0x01, 0x12, 0x3a, 0x0a, 0x13, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x54, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x2e, 0x70, 0x62, + 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6f, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x1a, 0x0f, 0x2e, 0x70, 0x62, + 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6f, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x22, 0x00, 0x30, 0x01, + 0x12, 0x43, 0x0a, 0x1a, 0x42, 0x69, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x61, + 0x6c, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x54, 0x65, 0x73, 0x74, 0x12, 0x0e, + 0x2e, 0x70, 0x62, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6f, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x1a, 0x0f, + 0x2e, 0x70, 0x62, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6f, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x22, + 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x41, 0x5a, 0x3f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, + 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x77, 0x65, 0x67, 0x6f, 0x2f, 0x6b, 0x69, + 0x74, 0x65, 0x78, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6d, 0x6f, 0x63, + 0x6b, 0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6b, 0x69, 0x74, 0x65, 0x78, 0x5f, 0x67, + 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x61, 0x70, 0x69, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pbapi_proto_rawDescOnce sync.Once + file_pbapi_proto_rawDescData = file_pbapi_proto_rawDesc +) + +func file_pbapi_proto_rawDescGZIP() []byte { + file_pbapi_proto_rawDescOnce.Do(func() { + file_pbapi_proto_rawDescData = protoimpl.X.CompressGZIP(file_pbapi_proto_rawDescData) + }) + return file_pbapi_proto_rawDescData +} + +var file_pbapi_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_pbapi_proto_goTypes = []interface{}{ + (*MockReq)(nil), // 0: pbapi.MockReq + (*MockResp)(nil), // 1: pbapi.MockResp +} +var file_pbapi_proto_depIdxs = []int32{ + 0, // 0: pbapi.Mock.UnaryTest:input_type -> pbapi.MockReq + 0, // 1: pbapi.Mock.ClientStreamingTest:input_type -> pbapi.MockReq + 0, // 2: pbapi.Mock.ServerStreamingTest:input_type -> pbapi.MockReq + 0, // 3: pbapi.Mock.BidirectionalStreamingTest:input_type -> pbapi.MockReq + 1, // 4: pbapi.Mock.UnaryTest:output_type -> pbapi.MockResp + 1, // 5: pbapi.Mock.ClientStreamingTest:output_type -> pbapi.MockResp + 1, // 6: pbapi.Mock.ServerStreamingTest:output_type -> pbapi.MockResp + 1, // 7: pbapi.Mock.BidirectionalStreamingTest:output_type -> pbapi.MockResp + 4, // [4:8] is the sub-list for method output_type + 0, // [0:4] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_pbapi_proto_init() } +func file_pbapi_proto_init() { + if File_pbapi_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pbapi_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MockReq); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pbapi_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MockResp); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pbapi_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_pbapi_proto_goTypes, + DependencyIndexes: file_pbapi_proto_depIdxs, + MessageInfos: file_pbapi_proto_msgTypes, + }.Build() + File_pbapi_proto = out.File + file_pbapi_proto_rawDesc = nil + file_pbapi_proto_goTypes = nil + file_pbapi_proto_depIdxs = nil +} + +var _ context.Context + +// Code generated by Kitex v0.9.1. DO NOT EDIT. + +type Mock interface { + UnaryTest(ctx context.Context, req *MockReq) (res *MockResp, err error) + ClientStreamingTest(stream Mock_ClientStreamingTestServer) (err error) + ServerStreamingTest(req *MockReq, stream Mock_ServerStreamingTestServer) (err error) + BidirectionalStreamingTest(stream Mock_BidirectionalStreamingTestServer) (err error) +} + +type Mock_ClientStreamingTestServer interface { + streaming.Stream + Recv() (*MockReq, error) + SendAndClose(*MockResp) error +} + +type Mock_ServerStreamingTestServer interface { + streaming.Stream + Send(*MockResp) error +} + +type Mock_BidirectionalStreamingTestServer interface { + streaming.Stream + Recv() (*MockReq, error) + Send(*MockResp) error +} diff --git a/internal/mocks/proto/kitex_gen/pbapi/mock/server.go b/internal/mocks/proto/kitex_gen/pbapi/mock/server.go new file mode 100644 index 0000000000..73c8c9fbfd --- /dev/null +++ b/internal/mocks/proto/kitex_gen/pbapi/mock/server.go @@ -0,0 +1,39 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by Kitex v0.9.1. DO NOT EDIT. +package mock + +import ( + server "github.com/cloudwego/kitex/server" +) + +// NewServer creates a server.Server with the given handler and options. +func NewServer(handler Mock, opts ...server.Option) server.Server { + var options []server.Option + + options = append(options, opts...) + + svr := server.NewServer(options...) + if err := svr.RegisterService(serviceInfo(), handler); err != nil { + panic(err) + } + return svr +} + +func RegisterService(svr server.Server, handler Mock, opts ...server.RegisterOption) error { + return svr.RegisterService(serviceInfo(), handler, opts...) +} diff --git a/internal/mocks/proto/pbapi.proto b/internal/mocks/proto/pbapi.proto new file mode 100644 index 0000000000..86d1e82039 --- /dev/null +++ b/internal/mocks/proto/pbapi.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; +package pbapi; + +option go_package = "pbapi"; + +message MockReq { + string message = 1; +} + +message MockResp { + string message = 1; +} + +service Mock { + rpc UnaryTest (MockReq) returns (MockResp) {} + rpc ClientStreamingTest (stream MockReq) returns (MockResp) {} + rpc ServerStreamingTest (MockReq) returns (stream MockResp) {} + rpc BidirectionalStreamingTest (stream MockReq) returns (stream MockResp) {} +} \ No newline at end of file diff --git a/pkg/generic/descriptor/descriptor.go b/pkg/generic/descriptor/descriptor.go index 6431f214e8..8f742836b0 100644 --- a/pkg/generic/descriptor/descriptor.go +++ b/pkg/generic/descriptor/descriptor.go @@ -22,6 +22,8 @@ import ( "os" dthrift "github.com/cloudwego/dynamicgo/thrift" + + "github.com/cloudwego/kitex/pkg/serviceinfo" ) var isGoTagAliasDisabled = os.Getenv("KITEX_GENERIC_GOTAG_ALIAS_DISABLED") == "True" @@ -91,6 +93,7 @@ type FunctionDescriptor struct { Request *TypeDescriptor Response *TypeDescriptor HasRequestBase bool + StreamingMode serviceinfo.StreamingMode } // ServiceDescriptor idl service descriptor diff --git a/pkg/generic/generic.go b/pkg/generic/generic.go index 93eb8869e4..fbf84073e6 100644 --- a/pkg/generic/generic.go +++ b/pkg/generic/generic.go @@ -49,8 +49,9 @@ type Generic interface { // Method information type Method struct { - Name string - Oneway bool + Name string + Oneway bool + StreamingMode serviceinfo.StreamingMode } // BinaryThriftGeneric raw thrift binary Generic @@ -217,7 +218,7 @@ func (g *binaryThriftGeneric) PayloadCodec() remote.PayloadCodec { } func (g *binaryThriftGeneric) GetMethod(req interface{}, method string) (*Method, error) { - return &Method{method, false}, nil + return &Method{Name: method, Oneway: false}, nil } func (g *binaryThriftGeneric) Close() error { diff --git a/pkg/generic/generic_service.go b/pkg/generic/generic_service.go index b24b44a401..9d5f09bdb4 100644 --- a/pkg/generic/generic_service.go +++ b/pkg/generic/generic_service.go @@ -120,7 +120,6 @@ var ( _ WithCodec = (*Args)(nil) ) -// Deprecated: it's not used by kitex anymore. // SetCodec ... func (g *Args) SetCodec(inner interface{}) { g.inner = inner @@ -200,7 +199,6 @@ var ( _ WithCodec = (*Result)(nil) ) -// Deprecated: it's not used by kitex anymore. // SetCodec ... func (r *Result) SetCodec(inner interface{}) { r.inner = inner diff --git a/pkg/generic/generic_service_test.go b/pkg/generic/generic_service_test.go index 6b6b288c30..5ffb910db5 100644 --- a/pkg/generic/generic_service_test.go +++ b/pkg/generic/generic_service_test.go @@ -71,7 +71,7 @@ func TestGenericService(t *testing.T) { // read ok a.SetCodec(rInner) err = a.Read(ctx, method, 0, tProto) - test.Assert(t, err == nil) + test.Assert(t, err == nil, err) // Result... result := newGenericServiceCallResult() diff --git a/pkg/generic/grpcjsonpb_test/generic_init.go b/pkg/generic/grpcjsonpb_test/generic_init.go new file mode 100644 index 0000000000..5c83f4d4bb --- /dev/null +++ b/pkg/generic/grpcjsonpb_test/generic_init.go @@ -0,0 +1,148 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "context" + "fmt" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/client/genericclient" + "github.com/cloudwego/kitex/internal/mocks/proto/kitex_gen/pbapi/mock" + "github.com/cloudwego/kitex/pkg/generic" + "github.com/cloudwego/kitex/server" + "github.com/cloudwego/kitex/transport" +) + +func newGenericClient(g generic.Generic, targetIPPort string) genericclient.Client { + cli, err := genericclient.NewStreamingClient("destService", g, + client.WithTransportProtocol(transport.GRPC), + client.WithHostPorts(targetIPPort), + ) + if err != nil { + panic(err) + } + return cli +} + +func newMockTestServer(handler mock.Mock, addr net.Addr, opts ...server.Option) server.Server { + opts = append(opts, server.WithServiceAddr(addr)) + svr := mock.NewServer(handler, opts...) + go func() { + err := svr.Run() + if err != nil { + panic(err) + } + }() + return svr +} + +var _ mock.Mock = &StreamingTestImpl{} + +type StreamingTestImpl struct{} + +func (s *StreamingTestImpl) UnaryTest(ctx context.Context, req *mock.MockReq) (resp *mock.MockResp, err error) { + fmt.Println("UnaryTest called") + resp = &mock.MockResp{} + resp.Message = "hello " + req.Message + return +} + +func (s *StreamingTestImpl) ClientStreamingTest(stream mock.Mock_ClientStreamingTestServer) (err error) { + fmt.Println("ClientStreamingTest called") + var msgs []string + for { + req, err := stream.Recv() + if err != nil { + if err == io.EOF { + break + } + return err + } + fmt.Printf("Recv: %s\n", req.Message) + msgs = append(msgs, req.Message) + time.Sleep(time.Second) + } + return stream.SendAndClose(&mock.MockResp{Message: "all message: " + strings.Join(msgs, ", ")}) +} + +func (s *StreamingTestImpl) ServerStreamingTest(req *mock.MockReq, stream mock.Mock_ServerStreamingTestServer) (err error) { + fmt.Println("ServerStreamingTest called") + resp := &mock.MockResp{} + for i := 0; i < 3; i++ { + resp.Message = fmt.Sprintf("%v -> %dth response", req.Message, i) + err := stream.Send(resp) + if err != nil { + return err + } + time.Sleep(time.Second) + } + return +} + +func (s *StreamingTestImpl) BidirectionalStreamingTest(stream mock.Mock_BidirectionalStreamingTestServer) (err error) { + fmt.Println("BidirectionalStreamingTest called") + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + defer func() { + if p := recover(); p != nil { + err = fmt.Errorf("panic: %v", p) + } + wg.Done() + }() + defer stream.Close() + for { + msg, recvErr := stream.Recv() + if recvErr == io.EOF { + return + } else if recvErr != nil { + err = recvErr + return + } + fmt.Printf("BidirectionaStreamingTest: received message = %s\n", msg.Message) + time.Sleep(time.Second) + } + }() + + go func() { + defer func() { + if p := recover(); p != nil { + err = fmt.Errorf("panic: %v", p) + } + wg.Done() + }() + resp := &mock.MockResp{} + for i := 0; i < 3; i++ { + resp.Message = fmt.Sprintf("%dth response", i) + if sendErr := stream.Send(resp); sendErr != nil { + err = sendErr + return + } + fmt.Printf("BidirectionaStreamingTest: sent message = %s\n", resp) + time.Sleep(time.Second) + } + }() + wg.Wait() + return +} diff --git a/pkg/generic/grpcjsonpb_test/generic_test.go b/pkg/generic/grpcjsonpb_test/generic_test.go new file mode 100644 index 0000000000..e61750c781 --- /dev/null +++ b/pkg/generic/grpcjsonpb_test/generic_test.go @@ -0,0 +1,163 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "context" + "fmt" + "io" + "net" + "reflect" + "strings" + "sync" + "testing" + "time" + + dproto "github.com/cloudwego/dynamicgo/proto" + "github.com/tidwall/gjson" + + "github.com/cloudwego/kitex/client/genericclient" + "github.com/cloudwego/kitex/internal/mocks/proto/kitex_gen/pbapi/mock" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/generic" + "github.com/cloudwego/kitex/server" +) + +func TestClientStreaming(t *testing.T) { + ctx := context.Background() + addr := test.GetLocalAddress() + + svr := initMockTestServer(new(StreamingTestImpl), addr) + defer svr.Stop() + + cli := initStreamingClient(t, ctx, addr, "./idl/pbapi.proto") + streamCli, err := genericclient.NewClientStreaming(ctx, cli, "ClientStreamingTest") + test.Assert(t, err == nil, err) + for i := 0; i < 3; i++ { + req := fmt.Sprintf(`{"message": "grpc client streaming generic %dth request"}`, i) + err = streamCli.Send(req) + test.Assert(t, err == nil) + time.Sleep(time.Second) + } + resp, err := streamCli.CloseAndRecv() + test.Assert(t, err == nil) + strResp, ok := resp.(string) + test.Assert(t, ok) + fmt.Printf("clientStreaming message received: %v\n", strResp) + test.Assert(t, reflect.DeepEqual(gjson.Get(strResp, "message").String(), + "all message: grpc client streaming generic 0th request, grpc client streaming generic 1th request, grpc client streaming generic 2th request")) +} + +func TestServerStreaming(t *testing.T) { + ctx := context.Background() + addr := test.GetLocalAddress() + + svr := initMockTestServer(new(StreamingTestImpl), addr) + defer svr.Stop() + + cli := initStreamingClient(t, ctx, addr, "./idl/pbapi.proto") + streamCli, err := genericclient.NewServerStreaming(ctx, cli, "ServerStreamingTest", `{"message": "grpc server streaming generic request"}`) + test.Assert(t, err == nil, err) + for { + resp, err := streamCli.Recv() + if err != nil { + test.Assert(t, err == io.EOF) + fmt.Println("serverStreaming message receive done") + break + } else { + strResp, ok := resp.(string) + test.Assert(t, ok) + fmt.Printf("serverStreaming message received: %s\n", strResp) + test.Assert(t, strings.Contains(strResp, "grpc server streaming generic request ->")) + } + time.Sleep(time.Second) + } +} + +func TestBidirectionalStreaming(t *testing.T) { + ctx := context.Background() + addr := test.GetLocalAddress() + + svr := initMockTestServer(new(StreamingTestImpl), addr) + defer svr.Stop() + + cli := initStreamingClient(t, ctx, addr, "./idl/pbapi.proto") + streamCli, err := genericclient.NewBidirectionalStreaming(ctx, cli, "BidirectionalStreamingTest") + test.Assert(t, err == nil) + + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + defer streamCli.Close() + for i := 0; i < 3; i++ { + req := fmt.Sprintf(`{"message": "grpc bidirectional streaming generic %dth request"}`, i) + err = streamCli.Send(req) + test.Assert(t, err == nil) + fmt.Printf("BidirectionalStreamingTest send: req = %s\n", req) + } + }() + + go func() { + defer wg.Done() + for { + resp, err := streamCli.Recv() + if err != nil { + test.Assert(t, err == io.EOF) + fmt.Println("bidirectionalStreaming message receive done") + break + } else { + strResp, ok := resp.(string) + test.Assert(t, ok) + fmt.Printf("bidirectionalStreaming message received: %s\n", strResp) + test.Assert(t, strings.Contains(strResp, "th response")) + } + time.Sleep(time.Second) + } + }() + wg.Wait() +} + +func TestUnary(t *testing.T) { + ctx := context.Background() + addr := test.GetLocalAddress() + + svr := initMockTestServer(new(StreamingTestImpl), addr) + defer svr.Stop() + + cli := initStreamingClient(t, ctx, addr, "./idl/pbapi.proto") + resp, err := cli.GenericCall(ctx, "UnaryTest", `{"message": "unary request"}`) + test.Assert(t, err == nil) + strResp, ok := resp.(string) + test.Assert(t, ok) + test.Assert(t, reflect.DeepEqual(gjson.Get(strResp, "message").String(), "hello unary request")) +} + +func initStreamingClient(t *testing.T, ctx context.Context, addr, idl string) genericclient.Client { + dOpts := dproto.Options{} + p, err := generic.NewPbFileProviderWithDynamicGo(idl, ctx, dOpts) + test.Assert(t, err == nil) + g, err := generic.JSONPbGeneric(p) + test.Assert(t, err == nil) + return newGenericClient(g, addr) +} + +func initMockTestServer(handler mock.Mock, address string) server.Server { + addr, _ := net.ResolveTCPAddr("tcp", address) + return newMockTestServer(handler, addr) +} diff --git a/pkg/generic/grpcjsonpb_test/idl/pbapi.proto b/pkg/generic/grpcjsonpb_test/idl/pbapi.proto new file mode 100644 index 0000000000..86d1e82039 --- /dev/null +++ b/pkg/generic/grpcjsonpb_test/idl/pbapi.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; +package pbapi; + +option go_package = "pbapi"; + +message MockReq { + string message = 1; +} + +message MockResp { + string message = 1; +} + +service Mock { + rpc UnaryTest (MockReq) returns (MockResp) {} + rpc ClientStreamingTest (stream MockReq) returns (MockResp) {} + rpc ServerStreamingTest (MockReq) returns (stream MockResp) {} + rpc BidirectionalStreamingTest (stream MockReq) returns (stream MockResp) {} +} \ No newline at end of file diff --git a/pkg/generic/httppbthrift_codec.go b/pkg/generic/httppbthrift_codec.go index cde66eec43..ea59513f81 100644 --- a/pkg/generic/httppbthrift_codec.go +++ b/pkg/generic/httppbthrift_codec.go @@ -20,7 +20,6 @@ import ( "context" "errors" "io" - "io/ioutil" "net/http" "strings" "sync/atomic" @@ -86,7 +85,7 @@ func (c *httpPbThriftCodec) getMethod(req interface{}) (*Method, error) { if err != nil { return nil, err } - return &Method{function.Name, function.Oneway}, nil + return &Method{function.Name, function.Oneway, function.StreamingMode}, nil } func (c *httpPbThriftCodec) getMessageReaderWriter() interface{} { @@ -122,7 +121,7 @@ func (c *httpPbThriftCodec) Close() error { } } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *httpPbThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { @@ -138,7 +137,7 @@ func (c *httpPbThriftCodec) Marshal(ctx context.Context, msg remote.Message, out return thriftCodec.Marshal(ctx, msg, out) } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *httpPbThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err @@ -177,7 +176,7 @@ func FromHTTPPbRequest(req *http.Request) (*HTTPRequest, error) { // body == nil if from Get request return customReq, nil } - if customReq.RawBody, err = ioutil.ReadAll(b); err != nil { + if customReq.RawBody, err = io.ReadAll(b); err != nil { return nil, err } if len(customReq.RawBody) == 0 { diff --git a/pkg/generic/httppbthrift_codec_test.go b/pkg/generic/httppbthrift_codec_test.go index 7871fac1cf..2af104805f 100644 --- a/pkg/generic/httppbthrift_codec_test.go +++ b/pkg/generic/httppbthrift_codec_test.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/kitex/internal/test" gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) func TestFromHTTPPbRequest(t *testing.T) { @@ -65,6 +66,7 @@ func TestHTTPPbThriftCodec(t *testing.T) { method, err = htc.getMethod(hreq) test.Assert(t, err == nil, err) test.Assert(t, method.Name == "Echo") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, htc.svcName == "ExampleService") rw := htc.getMessageReaderWriter() diff --git a/pkg/generic/httpthrift_codec.go b/pkg/generic/httpthrift_codec.go index 8e40a7ba8e..b66a37ed11 100644 --- a/pkg/generic/httpthrift_codec.go +++ b/pkg/generic/httpthrift_codec.go @@ -119,7 +119,7 @@ func (c *httpThriftCodec) getMethod(req interface{}) (*Method, error) { if err != nil { return nil, err } - return &Method{function.Name, function.Oneway}, nil + return &Method{function.Name, function.Oneway, function.StreamingMode}, nil } func (c *httpThriftCodec) Name() string { @@ -130,7 +130,7 @@ func (c *httpThriftCodec) Close() error { return c.provider.Close() } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *httpThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) if !ok { @@ -147,7 +147,7 @@ func (c *httpThriftCodec) Marshal(ctx context.Context, msg remote.Message, out r return thriftCodec.Marshal(ctx, msg, out) } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *httpThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err diff --git a/pkg/generic/httpthrift_codec_test.go b/pkg/generic/httpthrift_codec_test.go index 9186478224..245b1906fc 100644 --- a/pkg/generic/httpthrift_codec_test.go +++ b/pkg/generic/httpthrift_codec_test.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/kitex/internal/test" gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) var customJson = sonic.Config{ @@ -62,6 +63,7 @@ func TestHttpThriftCodec(t *testing.T) { // right method, err = htc.getMethod(req) test.Assert(t, err == nil && method.Name == "BinaryEcho") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, htc.svcName == "ExampleService") rw := htc.getMessageReaderWriter() @@ -97,6 +99,7 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) { // right method, err = htc.getMethod(req) test.Assert(t, err == nil && method.Name == "BinaryEcho") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, htc.svcName == "ExampleService") rw := htc.getMessageReaderWriter() diff --git a/pkg/generic/jsonpb_codec.go b/pkg/generic/jsonpb_codec.go index 4914ebba03..ac58e21ed1 100644 --- a/pkg/generic/jsonpb_codec.go +++ b/pkg/generic/jsonpb_codec.go @@ -81,7 +81,7 @@ func (c *jsonPbCodec) getMethod(req interface{}, method string) (*Method, error) } // protobuf does not have oneway methods - return &Method{method, false}, nil + return &Method{method, false, getStreamingMode(fnSvc)}, nil } func (c *jsonPbCodec) Name() string { @@ -92,7 +92,7 @@ func (c *jsonPbCodec) Close() error { return c.provider.Close() } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *jsonPbCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { method := msg.RPCInfo().Invocation().MethodName() if method == "" { @@ -110,7 +110,7 @@ func (c *jsonPbCodec) Marshal(ctx context.Context, msg remote.Message, out remot return pbCodec.Marshal(ctx, msg, out) } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *jsonPbCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err @@ -123,3 +123,17 @@ func (c *jsonPbCodec) Unmarshal(ctx context.Context, msg remote.Message, in remo return pbCodec.Unmarshal(ctx, msg, in) } + +func getStreamingMode(fnSvc *dproto.MethodDescriptor) serviceinfo.StreamingMode { + streamingMode := serviceinfo.StreamingNone + isClientStreaming := fnSvc.IsClientStreaming() + isServerStreaming := fnSvc.IsServerStreaming() + if isClientStreaming && isServerStreaming { + streamingMode = serviceinfo.StreamingBidirectional + } else if isClientStreaming { + streamingMode = serviceinfo.StreamingClient + } else if isServerStreaming { + streamingMode = serviceinfo.StreamingServer + } + return streamingMode +} diff --git a/pkg/generic/jsonpb_codec_test.go b/pkg/generic/jsonpb_codec_test.go index a1b823508d..8913d3369b 100644 --- a/pkg/generic/jsonpb_codec_test.go +++ b/pkg/generic/jsonpb_codec_test.go @@ -25,9 +25,13 @@ import ( "github.com/cloudwego/kitex/internal/test" gproto "github.com/cloudwego/kitex/pkg/generic/proto" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) -var echoIDLPath = "./jsonpb_test/idl/echo.proto" +var ( + echoIDLPath = "./jsonpb_test/idl/echo.proto" + testIDLPath = "./grpcjsonpb_test/idl/pbapi.proto" +) func TestRun(t *testing.T) { t.Run("TestJsonPbCodec", TestJsonPbCodec) @@ -46,6 +50,7 @@ func TestJsonPbCodec(t *testing.T) { method, err := jpc.getMethod(nil, "Echo") test.Assert(t, err == nil) test.Assert(t, method.Name == "Echo") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, jpc.svcName == "Echo") rw := jpc.getMessageReaderWriter() @@ -53,4 +58,28 @@ func TestJsonPbCodec(t *testing.T) { test.Assert(t, ok) _, ok = rw.(gproto.MessageReader) test.Assert(t, ok) + + p, err = NewPbFileProviderWithDynamicGo(testIDLPath, context.Background(), opts) + test.Assert(t, err == nil) + + jpc = newJsonPbCodec(p, gOpts) + defer jpc.Close() + + method, err = jpc.getMethod(nil, "ClientStreamingTest") + test.Assert(t, err == nil) + test.Assert(t, method.Name == "ClientStreamingTest") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingClient) + test.Assert(t, jpc.svcName == "Mock") + + method, err = jpc.getMethod(nil, "ServerStreamingTest") + test.Assert(t, err == nil) + test.Assert(t, method.Name == "ServerStreamingTest") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingServer) + test.Assert(t, jpc.svcName == "Mock") + + method, err = jpc.getMethod(nil, "BidirectionalStreamingTest") + test.Assert(t, err == nil) + test.Assert(t, method.Name == "BidirectionalStreamingTest") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingBidirectional) + test.Assert(t, jpc.svcName == "Mock") } diff --git a/pkg/generic/jsonthrift_codec.go b/pkg/generic/jsonthrift_codec.go index d418e9844c..377d507ca1 100644 --- a/pkg/generic/jsonthrift_codec.go +++ b/pkg/generic/jsonthrift_codec.go @@ -108,7 +108,7 @@ func (c *jsonThriftCodec) getMethod(req interface{}, method string) (*Method, er if err != nil { return nil, err } - return &Method{method, fnSvc.Oneway}, nil + return &Method{method, fnSvc.Oneway, fnSvc.StreamingMode}, nil } func (c *jsonThriftCodec) Name() string { @@ -119,7 +119,7 @@ func (c *jsonThriftCodec) Close() error { return c.provider.Close() } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *jsonThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { method := msg.RPCInfo().Invocation().MethodName() if method == "" { @@ -143,7 +143,7 @@ func (c *jsonThriftCodec) Marshal(ctx context.Context, msg remote.Message, out r return thriftCodec.Marshal(ctx, msg, out) } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *jsonThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err diff --git a/pkg/generic/jsonthrift_codec_test.go b/pkg/generic/jsonthrift_codec_test.go index ecdc351889..82ff0bb1fe 100644 --- a/pkg/generic/jsonthrift_codec_test.go +++ b/pkg/generic/jsonthrift_codec_test.go @@ -23,6 +23,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic/thrift" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) func TestJsonThriftCodec(t *testing.T) { @@ -40,6 +41,7 @@ func TestJsonThriftCodec(t *testing.T) { method, err := jtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, jtc.svcName == "Mock") rw := jtc.getMessageReaderWriter() @@ -69,6 +71,7 @@ func TestJsonThriftCodecWithDynamicGo(t *testing.T) { method, err := jtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) rw := jtc.getMessageReaderWriter() _, ok := rw.(thrift.MessageWriter) @@ -88,6 +91,7 @@ func TestJsonThriftCodec_SelfRef(t *testing.T) { method, err := jtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) rw := jtc.getMessageReaderWriter() _, ok := rw.(thrift.MessageWriter) @@ -107,6 +111,7 @@ func TestJsonThriftCodec_SelfRef(t *testing.T) { method, err := jtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, jtc.svcName == "Mock") rw := jtc.getMessageReaderWriter() diff --git a/pkg/generic/mapthrift_codec.go b/pkg/generic/mapthrift_codec.go index 953734349f..78e2a449e4 100644 --- a/pkg/generic/mapthrift_codec.go +++ b/pkg/generic/mapthrift_codec.go @@ -100,7 +100,7 @@ func (c *mapThriftCodec) getMethod(req interface{}, method string) (*Method, err if err != nil { return nil, err } - return &Method{method, fnSvc.Oneway}, nil + return &Method{method, fnSvc.Oneway, fnSvc.StreamingMode}, nil } func (c *mapThriftCodec) Name() string { @@ -111,7 +111,7 @@ func (c *mapThriftCodec) Close() error { return c.provider.Close() } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *mapThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { method := msg.RPCInfo().Invocation().MethodName() if method == "" { @@ -130,7 +130,7 @@ func (c *mapThriftCodec) Marshal(ctx context.Context, msg remote.Message, out re return thriftCodec.Marshal(ctx, msg, out) } -// Deprecated: it's not used by kitex anymore. replaced by GetMessageReaderWriter +// Deprecated: it's not used by kitex anymore. replaced by generic.MessageReaderWriter func (c *mapThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { return err diff --git a/pkg/generic/mapthrift_codec_test.go b/pkg/generic/mapthrift_codec_test.go index 4d9b7948e9..86630b842d 100644 --- a/pkg/generic/mapthrift_codec_test.go +++ b/pkg/generic/mapthrift_codec_test.go @@ -21,6 +21,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic/thrift" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) func TestMapThriftCodec(t *testing.T) { @@ -33,6 +34,7 @@ func TestMapThriftCodec(t *testing.T) { method, err := mtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, mtc.svcName == "Mock") rw := mtc.getMessageReaderWriter() @@ -52,6 +54,7 @@ func TestMapThriftCodecSelfRef(t *testing.T) { method, err := mtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, mtc.svcName == "Mock") rw := mtc.getMessageReaderWriter() @@ -71,6 +74,7 @@ func TestMapThriftCodecForJSON(t *testing.T) { method, err := mtc.getMethod(nil, "Test") test.Assert(t, err == nil) test.Assert(t, method.Name == "Test") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) test.Assert(t, mtc.svcName == "Mock") rw := mtc.getMessageReaderWriter() diff --git a/pkg/generic/thrift/json.go b/pkg/generic/thrift/json.go index 6c2be7fca8..c697c29f15 100644 --- a/pkg/generic/thrift/json.go +++ b/pkg/generic/thrift/json.go @@ -84,7 +84,7 @@ func (m *WriteJSON) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) func (m *WriteJSON) originalWrite(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { fnDsc, err := m.svcDsc.LookupFunctionByMethod(method) if err != nil { - return fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, m.svcDsc.DynamicGoDsc.Name()) + return fmt.Errorf("missing method: %s in service: %s", method, m.svcDsc.Name) } typeDsc := fnDsc.Request if !isClient { diff --git a/pkg/generic/thrift/parse.go b/pkg/generic/thrift/parse.go index 03a80b9b3b..7576a7953b 100644 --- a/pkg/generic/thrift/parse.go +++ b/pkg/generic/thrift/parse.go @@ -22,12 +22,14 @@ import ( "fmt" "runtime/debug" + "github.com/cloudwego/thriftgo/generator/golang/streaming" "github.com/cloudwego/thriftgo/parser" "github.com/cloudwego/thriftgo/semantic" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) const ( @@ -140,6 +142,11 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv if len(fn.Arguments) == 0 { return fmt.Errorf("empty arguments in function: %s", fn.Name) } + st, err := streaming.ParseStreaming(fn) + if err != nil { + return err + } + mode := streamingMode(st) // only support single argument field := fn.Arguments[0] req := &descriptor.TypeDescriptor{ @@ -213,6 +220,7 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv Request: req, Response: resp, HasRequestBase: hasRequestBase, + StreamingMode: mode, } defer func() { if ret := recover(); ret != nil { @@ -234,6 +242,22 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv return nil } +func streamingMode(st *streaming.Streaming) serviceinfo.StreamingMode { + if st.BidirectionalStreaming { + return serviceinfo.StreamingBidirectional + } + if st.ClientStreaming { + return serviceinfo.StreamingClient + } + if st.ServerStreaming { + return serviceinfo.StreamingServer + } + if st.Unary { + return serviceinfo.StreamingUnary + } + return serviceinfo.StreamingNone +} + // reuse builtin types var builtinTypes = map[string]*descriptor.TypeDescriptor{ "void": {Name: "void", Type: descriptor.VOID, Struct: new(descriptor.StructDescriptor)}, diff --git a/pkg/generic/thrift/parse_test.go b/pkg/generic/thrift/parse_test.go index 418be486bf..30a1088f5e 100644 --- a/pkg/generic/thrift/parse_test.go +++ b/pkg/generic/thrift/parse_test.go @@ -25,6 +25,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic/descriptor" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) var httpIDL = ` @@ -407,6 +408,25 @@ service DemoService { } ` +const streamingIDL = ` +namespace go echo + +struct Request { + 1: optional string message, +} + +struct Response { + 1: optional string message, +} + +service TestService { + Response EchoBidirectional (1: Request req) (streaming.mode="bidirectional"), + Response EchoClient (1: Request req) (streaming.mode="client"), + Response EchoServer (1: Request req) (streaming.mode="server"), + Response EchoUnary (1: Request req) (streaming.mode="unary"), +} +` + func TestDefaultValue(t *testing.T) { demo, err := parser.ParseString("demo.thrift", defaultValueDemoIDL) test.Assert(t, err == nil) @@ -440,6 +460,30 @@ func TestDefaultValue(t *testing.T) { } } +func TestStreamingMode(t *testing.T) { + streaming, err := parser.ParseString("streaming.thrift", streamingIDL) + test.Assert(t, err == nil) + + dp, err := Parse(streaming, DefaultParseMode()) + test.Assert(t, err == nil) + + method, err := dp.LookupFunctionByMethod("EchoBidirectional") + test.Assert(t, err == nil) + test.Assert(t, method.StreamingMode == serviceinfo.StreamingBidirectional) + + method, err = dp.LookupFunctionByMethod("EchoClient") + test.Assert(t, err == nil) + test.Assert(t, method.StreamingMode == serviceinfo.StreamingClient) + + method, err = dp.LookupFunctionByMethod("EchoServer") + test.Assert(t, err == nil) + test.Assert(t, method.StreamingMode == serviceinfo.StreamingServer) + + method, err = dp.LookupFunctionByMethod("EchoUnary") + test.Assert(t, err == nil) + test.Assert(t, method.StreamingMode == serviceinfo.StreamingUnary) +} + func defaultValueDeepEqual(t *testing.T, defaultValue func(name string) interface{}) { test.Assert(t, defaultValue("a") == true) test.Assert(t, defaultValue("b") == byte(1)) diff --git a/pkg/remote/codec/grpc/grpc.go b/pkg/remote/codec/grpc/grpc.go index 0b417d2606..53a0e3708a 100644 --- a/pkg/remote/codec/grpc/grpc.go +++ b/pkg/remote/codec/grpc/grpc.go @@ -27,6 +27,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/perrors" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -132,6 +133,19 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo payload, err = proto.Marshal(t) case protobuf.ProtobufMsgCodec: payload, err = t.Marshal(nil) + case protobuf.MessageWriterWithContext: + methodName := message.RPCInfo().Invocation().MethodName() + if methodName == "" { + return errors.New("empty methodName in grpc Encode") + } + actualMsg, err := t.WritePb(ctx, methodName) + if err != nil { + return err + } + payload, ok = actualMsg.([]byte) + if !ok { + return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("grpc marshal message failed: %s", err.Error())) + } default: return ErrInvalidPayload } @@ -194,6 +208,12 @@ func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remot return proto.Unmarshal(d, t) case protobuf.ProtobufMsgCodec: return t.Unmarshal(d) + case protobuf.MessageReaderWithMethodWithContext: + methodName := message.RPCInfo().Invocation().MethodName() + if methodName == "" { + return errors.New("empty methodName in grpc Decode") + } + return t.ReadPb(ctx, methodName, d) default: return ErrInvalidPayload } diff --git a/pkg/serviceinfo/serviceinfo.go b/pkg/serviceinfo/serviceinfo.go index 96fd4a7cab..c93126e48c 100644 --- a/pkg/serviceinfo/serviceinfo.go +++ b/pkg/serviceinfo/serviceinfo.go @@ -39,6 +39,12 @@ const ( CombineService = "CombineService" // CombineService_ is used when idl has a service named "CombineService" CombineService_ = "CombineService_" + // GenericClientStreamingMethod name + GenericClientStreamingMethod = "$GenericClientStreamingMethod" + // GenericServerStreamingMethod name + GenericServerStreamingMethod = "$GenericServerStreamingMethod" + // GenericBidirectionalStreamingMethod name + GenericBidirectionalStreamingMethod = "$GenericBidirectionalStreamingMethod" ) // ServiceInfo to record meta info of service @@ -91,6 +97,7 @@ func (i *ServiceInfo) MethodInfo(name string) MethodInfo { if i.GenericMethod != nil { return i.GenericMethod(name) } + // TODO: modify when server side supports grpc generic return i.Methods[GenericMethod] } return i.Methods[name] From 5ba0642482d166d502ff50a94f59f35922c94cb4 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Wed, 3 Jul 2024 18:01:34 +0800 Subject: [PATCH 06/70] feat: add PrependError for thriftgo (#1420) --- pkg/generic/thrift/base.go | 105 ++++++++++--------- pkg/protocol/bthrift/apache/exception.go | 28 ----- pkg/protocol/bthrift/exception.go | 54 +++++++++- pkg/protocol/bthrift/exception_test.go | 39 +++++++ pkg/remote/trans/netpollmux/control_frame.go | 19 ++-- 5 files changed, 155 insertions(+), 90 deletions(-) delete mode 100644 pkg/protocol/bthrift/apache/exception.go diff --git a/pkg/generic/thrift/base.go b/pkg/generic/thrift/base.go index 8a3f05febb..0139ca24c3 100644 --- a/pkg/generic/thrift/base.go +++ b/pkg/generic/thrift/base.go @@ -19,6 +19,7 @@ package thrift import ( "fmt" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) @@ -110,18 +111,18 @@ func (p *TrafficEnv) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_TrafficEnv[fieldId]), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_TrafficEnv[fieldId]), err) SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *TrafficEnv) ReadField1(iprot thrift.TProtocol) error { @@ -166,13 +167,13 @@ func (p *TrafficEnv) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *TrafficEnv) writeField1(oprot thrift.TProtocol) (err error) { @@ -187,9 +188,9 @@ func (p *TrafficEnv) writeField1(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } func (p *TrafficEnv) writeField2(oprot thrift.TProtocol) (err error) { @@ -204,9 +205,9 @@ func (p *TrafficEnv) writeField2(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } func (p *TrafficEnv) String() string { @@ -408,18 +409,18 @@ func (p *Base) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Base[fieldId]), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Base[fieldId]), err) SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *Base) ReadField1(iprot thrift.TProtocol) error { @@ -535,13 +536,13 @@ func (p *Base) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *Base) writeField1(oprot thrift.TProtocol) (err error) { @@ -556,9 +557,9 @@ func (p *Base) writeField1(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } func (p *Base) writeField2(oprot thrift.TProtocol) (err error) { @@ -573,9 +574,9 @@ func (p *Base) writeField2(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } func (p *Base) writeField3(oprot thrift.TProtocol) (err error) { @@ -590,9 +591,9 @@ func (p *Base) writeField3(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } func (p *Base) writeField4(oprot thrift.TProtocol) (err error) { @@ -607,9 +608,9 @@ func (p *Base) writeField4(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) } func (p *Base) writeField5(oprot thrift.TProtocol) (err error) { @@ -626,9 +627,9 @@ func (p *Base) writeField5(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) } func (p *Base) writeField6(oprot thrift.TProtocol) (err error) { @@ -658,9 +659,9 @@ func (p *Base) writeField6(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) } func (p *Base) String() string { @@ -788,18 +789,18 @@ func (p *BaseResp) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_BaseResp[fieldId]), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_BaseResp[fieldId]), err) SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *BaseResp) ReadField1(iprot thrift.TProtocol) error { @@ -877,13 +878,13 @@ func (p *BaseResp) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *BaseResp) writeField1(oprot thrift.TProtocol) (err error) { @@ -898,9 +899,9 @@ func (p *BaseResp) writeField1(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } func (p *BaseResp) writeField2(oprot thrift.TProtocol) (err error) { @@ -915,9 +916,9 @@ func (p *BaseResp) writeField2(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } func (p *BaseResp) writeField3(oprot thrift.TProtocol) (err error) { @@ -947,9 +948,9 @@ func (p *BaseResp) writeField3(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } func (p *BaseResp) String() string { diff --git a/pkg/protocol/bthrift/apache/exception.go b/pkg/protocol/bthrift/apache/exception.go deleted file mode 100644 index 2a0a1f67ff..0000000000 --- a/pkg/protocol/bthrift/apache/exception.go +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/exception.go - -// Generic Thrift exception -type TException interface { - error -} - -var PrependError = thrift.PrependError diff --git a/pkg/protocol/bthrift/exception.go b/pkg/protocol/bthrift/exception.go index 9854520c87..e3ed256840 100644 --- a/pkg/protocol/bthrift/exception.go +++ b/pkg/protocol/bthrift/exception.go @@ -17,12 +17,13 @@ package bthrift import ( + "errors" "fmt" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) -// ApplicationException represents the application exception decoder for replacing apache.TApplicationException +// ApplicationException is for replacing apache.TApplicationException // it implements ThriftMsgFastCodec interface. type ApplicationException struct { t int32 @@ -177,3 +178,54 @@ func (e *ApplicationException) Error() string { } return fmt.Sprintf("unknown exception type [%d]", e.t) } + +// TransportException is for replacing apache.TransportException +// it implements ThriftMsgFastCodec interface. +type TransportException struct { + ApplicationException // same implementation ... +} + +// NewTransportException ... +func NewTransportException(t int32, m string) *TransportException { + ret := TransportException{} + ret.t = t + ret.m = m + return &ret +} + +// ProtocolException is for replacing apache.ProtocolException +// it implements ThriftMsgFastCodec interface. +type ProtocolException struct { + ApplicationException // same implementation ... +} + +// NewTransportException ... +func NewProtocolException(t int32, m string) *ProtocolException { + ret := ProtocolException{} + ret.t = t + ret.m = m + return &ret +} + +// Generic Thrift exception with TypeId method +type tException interface { + Error() string + TypeId() int32 +} + +// Prepends additional information to an error without losing the Thrift exception interface +func PrependError(prepend string, err error) error { + if t, ok := err.(*TransportException); ok { + return NewTransportException(t.TypeID(), prepend+t.Error()) + } + if t, ok := err.(*ProtocolException); ok { + return NewProtocolException(t.TypeID(), prepend+err.Error()) + } + if t, ok := err.(*ApplicationException); ok { + return NewApplicationException(t.TypeID(), prepend+t.Error()) + } + if t, ok := err.(tException); ok { // apache thrift exception? + return NewApplicationException(t.TypeId(), prepend+t.Error()) + } + return errors.New(prepend + err.Error()) +} diff --git a/pkg/protocol/bthrift/exception_test.go b/pkg/protocol/bthrift/exception_test.go index 9a625c9266..574653d2cf 100644 --- a/pkg/protocol/bthrift/exception_test.go +++ b/pkg/protocol/bthrift/exception_test.go @@ -18,6 +18,7 @@ package bthrift import ( "bytes" + "errors" "testing" "github.com/cloudwego/kitex/internal/test" @@ -61,3 +62,41 @@ func TestApplicationException(t *testing.T) { test.Assert(t, ex4.TypeID() == 1) test.Assert(t, ex4.Msg() == "t1") } + +func TestPrependError(t *testing.T) { + var ok bool + ex0 := NewTransportException(1, "world") + err0 := PrependError("hello ", ex0) + ex0, ok = err0.(*TransportException) + test.Assert(t, ok) + test.Assert(t, ex0.TypeID() == 1) + test.Assert(t, ex0.Error() == "hello world") + + ex1 := NewProtocolException(2, "world") + err1 := PrependError("hello ", ex1) + ex1, ok = err1.(*ProtocolException) + test.Assert(t, ok) + test.Assert(t, ex1.TypeID() == 2) + test.Assert(t, ex1.Error() == "hello world") + + ex2 := NewApplicationException(3, "world") + err2 := PrependError("hello ", ex2) + ex2, ok = err2.(*ApplicationException) + test.Assert(t, ok) + test.Assert(t, ex2.TypeID() == 3) + test.Assert(t, ex2.Error() == "hello world") + + err3 := PrependError("hello ", errors.New("world")) + _, ok = err3.(tException) + test.Assert(t, !ok) + test.Assert(t, err3.Error() == "hello world") + + // the code below, it's for compatibility test only. + // it can be removed in the future along with Read/Write method + ex9 := thrift.NewTApplicationException(9, "world") + err9 := PrependError("hello ", ex9) + ex, ok := err9.(tException) + test.Assert(t, ok) + test.Assert(t, ex.TypeId() == 9) + test.Assert(t, ex.Error() == "hello world") +} diff --git a/pkg/remote/trans/netpollmux/control_frame.go b/pkg/remote/trans/netpollmux/control_frame.go index 9c50813fbe..4c060a24d7 100644 --- a/pkg/remote/trans/netpollmux/control_frame.go +++ b/pkg/remote/trans/netpollmux/control_frame.go @@ -25,7 +25,8 @@ package netpollmux import ( "fmt" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) type ControlFrame struct{} @@ -66,16 +67,16 @@ func (p *ControlFrame) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) SkipFieldTypeError: - return thrift.PrependError(fmt.Sprintf("%T skip field type %d error", p, fieldTypeId), err) + return bthrift.PrependError(fmt.Sprintf("%T skip field type %d error", p, fieldTypeId), err) ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *ControlFrame) Write(oprot thrift.TProtocol) (err error) { @@ -92,11 +93,11 @@ func (p *ControlFrame) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *ControlFrame) String() string { From 1dccc6769f24e85a6f7ed928602679aead5e8077 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Mon, 8 Jul 2024 10:50:18 +0800 Subject: [PATCH 07/70] feat(thrift): generic fastcodec (#1424) --- internal/mocks/thrift/fast/test.go | 1548 ----------------- internal/mocks/thrift/gen.sh | 5 + internal/mocks/thrift/k-consts.go | 20 + internal/mocks/thrift/{fast => }/k-test.go | 74 +- internal/mocks/thrift/test.go | 380 +--- pkg/generic/binary_test/generic_init.go | 7 +- pkg/generic/binary_test/generic_test.go | 12 +- pkg/generic/binarythrift_codec_test.go | 8 +- pkg/protocol/bthrift/binary.go | 5 +- pkg/protocol/bthrift/exception.go | 6 +- pkg/protocol/bthrift/interface.go | 7 + pkg/remote/codec/thrift/thrift.go | 12 +- pkg/remote/codec/thrift/thrift_data.go | 6 +- pkg/remote/codec/thrift/thrift_data_test.go | 24 +- pkg/remote/codec/thrift/thrift_frugal_test.go | 6 +- pkg/remote/codec/thrift/thrift_test.go | 2 +- pkg/utils/fastthrift/fast_thrift.go | 36 - pkg/utils/fastthrift/fastthrift.go | 82 + ...fast_thrift_test.go => fastthrift_test.go} | 33 +- pkg/utils/kitexutil/kitexutil_test.go | 2 +- 20 files changed, 313 insertions(+), 1962 deletions(-) delete mode 100644 internal/mocks/thrift/fast/test.go create mode 100755 internal/mocks/thrift/gen.sh create mode 100644 internal/mocks/thrift/k-consts.go rename internal/mocks/thrift/{fast => }/k-test.go (97%) delete mode 100644 pkg/utils/fastthrift/fast_thrift.go create mode 100644 pkg/utils/fastthrift/fastthrift.go rename pkg/utils/fastthrift/{fast_thrift_test.go => fastthrift_test.go} (60%) diff --git a/internal/mocks/thrift/fast/test.go b/internal/mocks/thrift/fast/test.go deleted file mode 100644 index ee1928076a..0000000000 --- a/internal/mocks/thrift/fast/test.go +++ /dev/null @@ -1,1548 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Code generated by thriftgo (0.2.1). DO NOT EDIT. - -package fast - -import ( - "context" - "fmt" - "strings" - - "github.com/apache/thrift/lib/go/thrift" -) - -type MockReq struct { - Msg string `thrift:"Msg,1" json:"Msg"` - StrMap map[string]string `thrift:"strMap,2" json:"strMap"` - StrList []string `thrift:"strList,3" json:"strList"` -} - -func NewMockReq() *MockReq { - return &MockReq{} -} - -func (p *MockReq) GetMsg() (v string) { - return p.Msg -} - -func (p *MockReq) GetStrMap() (v map[string]string) { - return p.StrMap -} - -func (p *MockReq) GetStrList() (v []string) { - return p.StrList -} -func (p *MockReq) SetMsg(val string) { - p.Msg = val -} -func (p *MockReq) SetStrMap(val map[string]string) { - p.StrMap = val -} -func (p *MockReq) SetStrList(val []string) { - p.StrList = val -} - -var fieldIDToName_MockReq = map[int16]string{ - 1: "Msg", - 2: "strMap", - 3: "strList", -} - -func (p *MockReq) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.STRING { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 2: - if fieldTypeId == thrift.MAP { - if err = p.ReadField2(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.LIST { - if err = p.ReadField3(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockReq[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockReq) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Msg = v - } - return nil -} - -func (p *MockReq) ReadField2(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.StrMap = make(map[string]string, size) - for i := 0; i < size; i++ { - var _key string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } - - var _val string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _val = v - } - - p.StrMap[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *MockReq) ReadField3(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.StrList = make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _elem = v - } - - p.StrList = append(p.StrList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *MockReq) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("MockReq"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField2(oprot); err != nil { - fieldId = 2 - goto WriteFieldError - } - if err = p.writeField3(oprot); err != nil { - fieldId = 3 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockReq) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Msg", thrift.STRING, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Msg); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *MockReq) writeField2(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("strMap", thrift.MAP, 2); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.StrMap)); err != nil { - return err - } - for k, v := range p.StrMap { - - if err := oprot.WriteString(k); err != nil { - return err - } - - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) -} - -func (p *MockReq) writeField3(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("strList", thrift.LIST, 3); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRING, len(p.StrList)); err != nil { - return err - } - for _, v := range p.StrList { - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) -} - -func (p *MockReq) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("MockReq(%+v)", *p) -} - -func (p *MockReq) DeepEqual(ano *MockReq) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field1DeepEqual(ano.Msg) { - return false - } - if !p.Field2DeepEqual(ano.StrMap) { - return false - } - if !p.Field3DeepEqual(ano.StrList) { - return false - } - return true -} - -func (p *MockReq) Field1DeepEqual(src string) bool { - - if strings.Compare(p.Msg, src) != 0 { - return false - } - return true -} -func (p *MockReq) Field2DeepEqual(src map[string]string) bool { - - if len(p.StrMap) != len(src) { - return false - } - for k, v := range p.StrMap { - _src := src[k] - if strings.Compare(v, _src) != 0 { - return false - } - } - return true -} -func (p *MockReq) Field3DeepEqual(src []string) bool { - - if len(p.StrList) != len(src) { - return false - } - for i, v := range p.StrList { - _src := src[i] - if strings.Compare(v, _src) != 0 { - return false - } - } - return true -} - -type Exception struct { - Code int32 `thrift:"code,1" json:"code"` - Msg string `thrift:"msg,255" json:"msg"` -} - -func NewException() *Exception { - return &Exception{} -} - -func (p *Exception) GetCode() (v int32) { - return p.Code -} - -func (p *Exception) GetMsg() (v string) { - return p.Msg -} -func (p *Exception) SetCode(val int32) { - p.Code = val -} -func (p *Exception) SetMsg(val string) { - p.Msg = val -} - -var fieldIDToName_Exception = map[int16]string{ - 1: "code", - 255: "msg", -} - -func (p *Exception) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 255: - if fieldTypeId == thrift.STRING { - if err = p.ReadField255(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Exception[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *Exception) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - p.Code = v - } - return nil -} - -func (p *Exception) ReadField255(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Msg = v - } - return nil -} - -func (p *Exception) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Exception"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField255(oprot); err != nil { - fieldId = 255 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *Exception) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("code", thrift.I32, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(p.Code); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *Exception) writeField255(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("msg", thrift.STRING, 255); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Msg); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 255 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 255 end error: ", p), err) -} - -func (p *Exception) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("Exception(%+v)", *p) -} -func (p *Exception) Error() string { - return p.String() -} - -func (p *Exception) DeepEqual(ano *Exception) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field1DeepEqual(ano.Code) { - return false - } - if !p.Field255DeepEqual(ano.Msg) { - return false - } - return true -} - -func (p *Exception) Field1DeepEqual(src int32) bool { - - if p.Code != src { - return false - } - return true -} -func (p *Exception) Field255DeepEqual(src string) bool { - - if strings.Compare(p.Msg, src) != 0 { - return false - } - return true -} - -type Mock interface { - Test(ctx context.Context, req *MockReq) (r string, err error) - - ExceptionTest(ctx context.Context, req *MockReq) (r string, err error) -} - -type MockClient struct { - c thrift.TClient -} - -func NewMockClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *MockClient { - return &MockClient{ - c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t)), - } -} - -func NewMockClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *MockClient { - return &MockClient{ - c: thrift.NewTStandardClient(iprot, oprot), - } -} - -func NewMockClient(c thrift.TClient) *MockClient { - return &MockClient{ - c: c, - } -} - -func (p *MockClient) Client_() thrift.TClient { - return p.c -} - -func (p *MockClient) Test(ctx context.Context, req *MockReq) (r string, err error) { - var _args MockTestArgs - _args.Req = req - var _result MockTestResult - if err = p.Client_().Call(ctx, "Test", &_args, &_result); err != nil { - return - } - return _result.GetSuccess(), nil -} -func (p *MockClient) ExceptionTest(ctx context.Context, req *MockReq) (r string, err error) { - var _args MockExceptionTestArgs - _args.Req = req - var _result MockExceptionTestResult - if err = p.Client_().Call(ctx, "ExceptionTest", &_args, &_result); err != nil { - return - } - switch { - case _result.Err != nil: - return r, _result.Err - } - return _result.GetSuccess(), nil -} - -type MockProcessor struct { - processorMap map[string]thrift.TProcessorFunction - handler Mock -} - -func (p *MockProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) { - p.processorMap[key] = processor -} - -func (p *MockProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) { - processor, ok = p.processorMap[key] - return processor, ok -} - -func (p *MockProcessor) ProcessorMap() map[string]thrift.TProcessorFunction { - return p.processorMap -} - -func NewMockProcessor(handler Mock) *MockProcessor { - self := &MockProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)} - self.AddToProcessorMap("Test", &mockProcessorTest{handler: handler}) - self.AddToProcessorMap("ExceptionTest", &mockProcessorExceptionTest{handler: handler}) - return self -} -func (p *MockProcessor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - name, _, seqId, err := iprot.ReadMessageBegin() - if err != nil { - return false, err - } - if processor, ok := p.GetProcessorFunction(name); ok { - return processor.Process(ctx, seqId, iprot, oprot) - } - iprot.Skip(thrift.STRUCT) - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) - oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, x -} - -type mockProcessorTest struct { - handler Mock -} - -func (p *mockProcessorTest) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := MockTestArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("Test", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err - } - - iprot.ReadMessageEnd() - var err2 error - result := MockTestResult{} - var retval string - if retval, err2 = p.handler.Test(ctx, args.Req); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Test: "+err2.Error()) - oprot.WriteMessageBegin("Test", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } else { - result.Success = &retval - } - if err2 = oprot.WriteMessageBegin("Test", thrift.REPLY, seqId); err2 != nil { - err = err2 - } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 - } - if err != nil { - return - } - return true, err -} - -type mockProcessorExceptionTest struct { - handler Mock -} - -func (p *mockProcessorExceptionTest) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := MockExceptionTestArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("ExceptionTest", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err - } - - iprot.ReadMessageEnd() - var err2 error - result := MockExceptionTestResult{} - var retval string - if retval, err2 = p.handler.ExceptionTest(ctx, args.Req); err2 != nil { - switch v := err2.(type) { - case *Exception: - result.Err = v - default: - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing ExceptionTest: "+err2.Error()) - oprot.WriteMessageBegin("ExceptionTest", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } - } else { - result.Success = &retval - } - if err2 = oprot.WriteMessageBegin("ExceptionTest", thrift.REPLY, seqId); err2 != nil { - err = err2 - } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 - } - if err != nil { - return - } - return true, err -} - -type MockTestArgs struct { - Req *MockReq `thrift:"req,1" json:"req"` -} - -func NewMockTestArgs() *MockTestArgs { - return &MockTestArgs{} -} - -var MockTestArgs_Req_DEFAULT *MockReq - -func (p *MockTestArgs) GetReq() (v *MockReq) { - if !p.IsSetReq() { - return MockTestArgs_Req_DEFAULT - } - return p.Req -} -func (p *MockTestArgs) SetReq(val *MockReq) { - p.Req = val -} - -var fieldIDToName_MockTestArgs = map[int16]string{ - 1: "req", -} - -func (p *MockTestArgs) IsSetReq() bool { - return p.Req != nil -} - -func (p *MockTestArgs) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestArgs[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockTestArgs) ReadField1(iprot thrift.TProtocol) error { - p.Req = NewMockReq() - if err := p.Req.Read(iprot); err != nil { - return err - } - return nil -} - -func (p *MockTestArgs) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Test_args"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockTestArgs) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { - goto WriteFieldBeginError - } - if err := p.Req.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *MockTestArgs) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("MockTestArgs(%+v)", *p) -} - -func (p *MockTestArgs) DeepEqual(ano *MockTestArgs) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field1DeepEqual(ano.Req) { - return false - } - return true -} - -func (p *MockTestArgs) Field1DeepEqual(src *MockReq) bool { - - if !p.Req.DeepEqual(src) { - return false - } - return true -} - -type MockTestResult struct { - Success *string `thrift:"success,0,optional" json:"success,omitempty"` -} - -func NewMockTestResult() *MockTestResult { - return &MockTestResult{} -} - -var MockTestResult_Success_DEFAULT string - -func (p *MockTestResult) GetSuccess() (v string) { - if !p.IsSetSuccess() { - return MockTestResult_Success_DEFAULT - } - return *p.Success -} -func (p *MockTestResult) SetSuccess(x interface{}) { - p.Success = x.(*string) -} - -var fieldIDToName_MockTestResult = map[int16]string{ - 0: "success", -} - -func (p *MockTestResult) IsSetSuccess() bool { - return p.Success != nil -} - -func (p *MockTestResult) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 0: - if fieldTypeId == thrift.STRING { - if err = p.ReadField0(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestResult[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockTestResult) ReadField0(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Success = &v - } - return nil -} - -func (p *MockTestResult) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Test_result"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField0(oprot); err != nil { - fieldId = 0 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockTestResult) writeField0(oprot thrift.TProtocol) (err error) { - if p.IsSetSuccess() { - if err = oprot.WriteFieldBegin("success", thrift.STRING, 0); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(*p.Success); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 0 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) -} - -func (p *MockTestResult) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("MockTestResult(%+v)", *p) -} - -func (p *MockTestResult) DeepEqual(ano *MockTestResult) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field0DeepEqual(ano.Success) { - return false - } - return true -} - -func (p *MockTestResult) Field0DeepEqual(src *string) bool { - - if p.Success == src { - return true - } else if p.Success == nil || src == nil { - return false - } - if strings.Compare(*p.Success, *src) != 0 { - return false - } - return true -} - -type MockExceptionTestArgs struct { - Req *MockReq `thrift:"req,1" json:"req"` -} - -func NewMockExceptionTestArgs() *MockExceptionTestArgs { - return &MockExceptionTestArgs{} -} - -var MockExceptionTestArgs_Req_DEFAULT *MockReq - -func (p *MockExceptionTestArgs) GetReq() (v *MockReq) { - if !p.IsSetReq() { - return MockExceptionTestArgs_Req_DEFAULT - } - return p.Req -} -func (p *MockExceptionTestArgs) SetReq(val *MockReq) { - p.Req = val -} - -var fieldIDToName_MockExceptionTestArgs = map[int16]string{ - 1: "req", -} - -func (p *MockExceptionTestArgs) IsSetReq() bool { - return p.Req != nil -} - -func (p *MockExceptionTestArgs) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestArgs[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockExceptionTestArgs) ReadField1(iprot thrift.TProtocol) error { - p.Req = NewMockReq() - if err := p.Req.Read(iprot); err != nil { - return err - } - return nil -} - -func (p *MockExceptionTestArgs) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("ExceptionTest_args"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockExceptionTestArgs) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { - goto WriteFieldBeginError - } - if err := p.Req.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *MockExceptionTestArgs) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("MockExceptionTestArgs(%+v)", *p) -} - -func (p *MockExceptionTestArgs) DeepEqual(ano *MockExceptionTestArgs) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field1DeepEqual(ano.Req) { - return false - } - return true -} - -func (p *MockExceptionTestArgs) Field1DeepEqual(src *MockReq) bool { - - if !p.Req.DeepEqual(src) { - return false - } - return true -} - -type MockExceptionTestResult struct { - Success *string `thrift:"success,0,optional" json:"success,omitempty"` - Err *Exception `thrift:"err,1,optional" json:"err,omitempty"` -} - -func NewMockExceptionTestResult() *MockExceptionTestResult { - return &MockExceptionTestResult{} -} - -var MockExceptionTestResult_Success_DEFAULT string - -func (p *MockExceptionTestResult) GetSuccess() (v string) { - if !p.IsSetSuccess() { - return MockExceptionTestResult_Success_DEFAULT - } - return *p.Success -} - -var MockExceptionTestResult_Err_DEFAULT *Exception - -func (p *MockExceptionTestResult) GetErr() (v *Exception) { - if !p.IsSetErr() { - return MockExceptionTestResult_Err_DEFAULT - } - return p.Err -} -func (p *MockExceptionTestResult) SetSuccess(x interface{}) { - p.Success = x.(*string) -} -func (p *MockExceptionTestResult) SetErr(val *Exception) { - p.Err = val -} - -var fieldIDToName_MockExceptionTestResult = map[int16]string{ - 0: "success", - 1: "err", -} - -func (p *MockExceptionTestResult) IsSetSuccess() bool { - return p.Success != nil -} - -func (p *MockExceptionTestResult) IsSetErr() bool { - return p.Err != nil -} - -func (p *MockExceptionTestResult) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 0: - if fieldTypeId == thrift.STRING { - if err = p.ReadField0(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 1: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestResult[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockExceptionTestResult) ReadField0(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Success = &v - } - return nil -} - -func (p *MockExceptionTestResult) ReadField1(iprot thrift.TProtocol) error { - p.Err = NewException() - if err := p.Err.Read(iprot); err != nil { - return err - } - return nil -} - -func (p *MockExceptionTestResult) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("ExceptionTest_result"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField0(oprot); err != nil { - fieldId = 0 - goto WriteFieldError - } - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockExceptionTestResult) writeField0(oprot thrift.TProtocol) (err error) { - if p.IsSetSuccess() { - if err = oprot.WriteFieldBegin("success", thrift.STRING, 0); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(*p.Success); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 0 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) -} - -func (p *MockExceptionTestResult) writeField1(oprot thrift.TProtocol) (err error) { - if p.IsSetErr() { - if err = oprot.WriteFieldBegin("err", thrift.STRUCT, 1); err != nil { - goto WriteFieldBeginError - } - if err := p.Err.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *MockExceptionTestResult) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("MockExceptionTestResult(%+v)", *p) -} - -func (p *MockExceptionTestResult) DeepEqual(ano *MockExceptionTestResult) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field0DeepEqual(ano.Success) { - return false - } - if !p.Field1DeepEqual(ano.Err) { - return false - } - return true -} - -func (p *MockExceptionTestResult) Field0DeepEqual(src *string) bool { - - if p.Success == src { - return true - } else if p.Success == nil || src == nil { - return false - } - if strings.Compare(*p.Success, *src) != 0 { - return false - } - return true -} -func (p *MockExceptionTestResult) Field1DeepEqual(src *Exception) bool { - - if !p.Err.DeepEqual(src) { - return false - } - return true -} diff --git a/internal/mocks/thrift/gen.sh b/internal/mocks/thrift/gen.sh new file mode 100755 index 0000000000..df178c759b --- /dev/null +++ b/internal/mocks/thrift/gen.sh @@ -0,0 +1,5 @@ +#! /bin/bash + +kitex -module github.com/cloudwego/kitex -gen-path .. ./test.thrift + +rm -rf ./mock # not in use, rm it diff --git a/internal/mocks/thrift/k-consts.go b/internal/mocks/thrift/k-consts.go new file mode 100644 index 0000000000..84f8d35254 --- /dev/null +++ b/internal/mocks/thrift/k-consts.go @@ -0,0 +1,20 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +// KitexUnusedProtection is used to prevent 'imported and not used' error. +var KitexUnusedProtection = struct{}{} diff --git a/internal/mocks/thrift/fast/k-test.go b/internal/mocks/thrift/k-test.go similarity index 97% rename from internal/mocks/thrift/fast/k-test.go rename to internal/mocks/thrift/k-test.go index c31f69fd23..2020f27c4d 100644 --- a/internal/mocks/thrift/fast/k-test.go +++ b/internal/mocks/thrift/k-test.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -// Code generated by Kitex v0.4.3. DO NOT EDIT. +// Code generated by Kitex v0.10.1. DO NOT EDIT. -package fast +package thrift import ( "bytes" @@ -141,14 +141,16 @@ ReadStructEndError: func (p *MockReq) FastReadField1(buf []byte) (int, error) { offset := 0 + var _field string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Msg = v + _field = v } + p.Msg = _field return offset, nil } @@ -160,7 +162,7 @@ func (p *MockReq) FastReadField2(buf []byte) (int, error) { if err != nil { return offset, err } - p.StrMap = make(map[string]string, size) + _field := make(map[string]string, size) for i := 0; i < size; i++ { var _key string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -182,13 +184,14 @@ func (p *MockReq) FastReadField2(buf []byte) (int, error) { } - p.StrMap[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.StrMap = _field return offset, nil } @@ -200,7 +203,7 @@ func (p *MockReq) FastReadField3(buf []byte) (int, error) { if err != nil { return offset, err } - p.StrList = make([]string, 0, size) + _field := make([]string, 0, size) for i := 0; i < size; i++ { var _elem string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -212,13 +215,14 @@ func (p *MockReq) FastReadField3(buf []byte) (int, error) { } - p.StrList = append(p.StrList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.StrList = _field return offset, nil } @@ -257,7 +261,6 @@ func (p *MockReq) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Msg", thrift.STRING, 1) offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Msg) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -270,11 +273,8 @@ func (p *MockReq) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) var length int for k, v := range p.StrMap { length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRING, length) offset += bthrift.Binary.WriteMapEnd(buf[offset:]) @@ -291,7 +291,6 @@ func (p *MockReq) fastWriteField3(buf []byte, binaryWriter bthrift.BinaryWriter) for _, v := range p.StrList { length++ offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -303,7 +302,6 @@ func (p *MockReq) field1Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Msg", thrift.STRING, 1) l += bthrift.Binary.StringLengthNocopy(p.Msg) - l += bthrift.Binary.FieldEndLength() return l } @@ -315,9 +313,7 @@ func (p *MockReq) field2Length() int { for k, v := range p.StrMap { l += bthrift.Binary.StringLengthNocopy(k) - l += bthrift.Binary.StringLengthNocopy(v) - } l += bthrift.Binary.MapEndLength() l += bthrift.Binary.FieldEndLength() @@ -330,7 +326,6 @@ func (p *MockReq) field3Length() int { l += bthrift.Binary.ListBeginLength(thrift.STRING, len(p.StrList)) for _, v := range p.StrList { l += bthrift.Binary.StringLengthNocopy(v) - } l += bthrift.Binary.ListEndLength() l += bthrift.Binary.FieldEndLength() @@ -425,28 +420,32 @@ ReadStructEndError: func (p *Exception) FastReadField1(buf []byte) (int, error) { offset := 0 + var _field int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Code = v + _field = v } + p.Code = _field return offset, nil } func (p *Exception) FastReadField255(buf []byte) (int, error) { offset := 0 + var _field string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Msg = v + _field = v } + p.Msg = _field return offset, nil } @@ -483,7 +482,6 @@ func (p *Exception) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWrite offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "code", thrift.I32, 1) offset += bthrift.Binary.WriteI32(buf[offset:], p.Code) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -492,7 +490,6 @@ func (p *Exception) fastWriteField255(buf []byte, binaryWriter bthrift.BinaryWri offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "msg", thrift.STRING, 255) offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Msg) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -501,7 +498,6 @@ func (p *Exception) field1Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("code", thrift.I32, 1) l += bthrift.Binary.I32Length(p.Code) - l += bthrift.Binary.FieldEndLength() return l } @@ -510,7 +506,6 @@ func (p *Exception) field255Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("msg", thrift.STRING, 255) l += bthrift.Binary.StringLengthNocopy(p.Msg) - l += bthrift.Binary.FieldEndLength() return l } @@ -588,14 +583,13 @@ ReadStructEndError: func (p *MockTestArgs) FastReadField1(buf []byte) (int, error) { offset := 0 - - tmp := NewMockReq() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewMockReq() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.Req = tmp + p.Req = _field return offset, nil } @@ -716,13 +710,15 @@ ReadStructEndError: func (p *MockTestResult) FastReadField0(buf []byte) (int, error) { offset := 0 + var _field *string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Success = &v + _field = &v } + p.Success = _field return offset, nil } @@ -758,7 +754,6 @@ func (p *MockTestResult) fastWriteField0(buf []byte, binaryWriter bthrift.Binary if p.IsSetSuccess() { offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "success", thrift.STRING, 0) offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, *p.Success) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) } return offset @@ -769,7 +764,6 @@ func (p *MockTestResult) field0Length() int { if p.IsSetSuccess() { l += bthrift.Binary.FieldBeginLength("success", thrift.STRING, 0) l += bthrift.Binary.StringLengthNocopy(*p.Success) - l += bthrift.Binary.FieldEndLength() } return l @@ -848,14 +842,13 @@ ReadStructEndError: func (p *MockExceptionTestArgs) FastReadField1(buf []byte) (int, error) { offset := 0 - - tmp := NewMockReq() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewMockReq() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.Req = tmp + p.Req = _field return offset, nil } @@ -990,26 +983,27 @@ ReadStructEndError: func (p *MockExceptionTestResult) FastReadField0(buf []byte) (int, error) { offset := 0 + var _field *string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Success = &v + _field = &v } + p.Success = _field return offset, nil } func (p *MockExceptionTestResult) FastReadField1(buf []byte) (int, error) { offset := 0 - - tmp := NewException() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewException() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.Err = tmp + p.Err = _field return offset, nil } @@ -1047,7 +1041,6 @@ func (p *MockExceptionTestResult) fastWriteField0(buf []byte, binaryWriter bthri if p.IsSetSuccess() { offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "success", thrift.STRING, 0) offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, *p.Success) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) } return offset @@ -1068,7 +1061,6 @@ func (p *MockExceptionTestResult) field0Length() int { if p.IsSetSuccess() { l += bthrift.Binary.FieldBeginLength("success", thrift.STRING, 0) l += bthrift.Binary.StringLengthNocopy(*p.Success) - l += bthrift.Binary.FieldEndLength() } return l diff --git a/internal/mocks/thrift/test.go b/internal/mocks/thrift/test.go index 4d8c4eb26c..6eba70afe6 100644 --- a/internal/mocks/thrift/test.go +++ b/internal/mocks/thrift/test.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,28 +14,30 @@ * limitations under the License. */ -// Code generated by thriftgo (0.2.1). DO NOT EDIT. +// Code generated by thriftgo (0.3.13). DO NOT EDIT. package thrift import ( "context" "fmt" - "strings" - "github.com/apache/thrift/lib/go/thrift" + "strings" ) type MockReq struct { - Msg string `thrift:"Msg,1" json:"Msg"` - StrMap map[string]string `thrift:"strMap,2" json:"strMap"` - StrList []string `thrift:"strList,3" json:"strList"` + Msg string `thrift:"Msg,1" frugal:"1,default,string" json:"Msg"` + StrMap map[string]string `thrift:"strMap,2" frugal:"2,default,map" json:"strMap"` + StrList []string `thrift:"strList,3" frugal:"3,default,list" json:"strList"` } func NewMockReq() *MockReq { return &MockReq{} } +func (p *MockReq) InitDefault() { +} + func (p *MockReq) GetMsg() (v string) { return p.Msg } @@ -87,37 +89,30 @@ func (p *MockReq) Read(iprot thrift.TProtocol) (err error) { if err = p.ReadField1(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } case 2: if fieldTypeId == thrift.MAP { if err = p.ReadField2(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } case 3: if fieldTypeId == thrift.LIST { if err = p.ReadField3(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } } - if err = iprot.ReadFieldEnd(); err != nil { goto ReadFieldEndError } @@ -143,20 +138,22 @@ ReadStructEndError: } func (p *MockReq) ReadField1(iprot thrift.TProtocol) error { + + var _field string if v, err := iprot.ReadString(); err != nil { return err } else { - p.Msg = v + _field = v } + p.Msg = _field return nil } - func (p *MockReq) ReadField2(iprot thrift.TProtocol) error { _, _, size, err := iprot.ReadMapBegin() if err != nil { return err } - p.StrMap = make(map[string]string, size) + _field := make(map[string]string, size) for i := 0; i < size; i++ { var _key string if v, err := iprot.ReadString(); err != nil { @@ -172,21 +169,22 @@ func (p *MockReq) ReadField2(iprot thrift.TProtocol) error { _val = v } - p.StrMap[_key] = _val + _field[_key] = _val } if err := iprot.ReadMapEnd(); err != nil { return err } + p.StrMap = _field return nil } - func (p *MockReq) ReadField3(iprot thrift.TProtocol) error { _, size, err := iprot.ReadListBegin() if err != nil { return err } - p.StrList = make([]string, 0, size) + _field := make([]string, 0, size) for i := 0; i < size; i++ { + var _elem string if v, err := iprot.ReadString(); err != nil { return err @@ -194,11 +192,12 @@ func (p *MockReq) ReadField3(iprot thrift.TProtocol) error { _elem = v } - p.StrList = append(p.StrList, _elem) + _field = append(_field, _elem) } if err := iprot.ReadListEnd(); err != nil { return err } + p.StrList = _field return nil } @@ -220,7 +219,6 @@ func (p *MockReq) Write(oprot thrift.TProtocol) (err error) { fieldId = 3 goto WriteFieldError } - } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -264,11 +262,9 @@ func (p *MockReq) writeField2(oprot thrift.TProtocol) (err error) { return err } for k, v := range p.StrMap { - if err := oprot.WriteString(k); err != nil { return err } - if err := oprot.WriteString(v); err != nil { return err } @@ -316,6 +312,7 @@ func (p *MockReq) String() string { return "" } return fmt.Sprintf("MockReq(%+v)", *p) + } func (p *MockReq) DeepEqual(ano *MockReq) bool { @@ -371,14 +368,17 @@ func (p *MockReq) Field3DeepEqual(src []string) bool { } type Exception struct { - Code int32 `thrift:"code,1" json:"code"` - Msg string `thrift:"msg,255" json:"msg"` + Code int32 `thrift:"code,1" frugal:"1,default,i32" json:"code"` + Msg string `thrift:"msg,255" frugal:"255,default,string" json:"msg"` } func NewException() *Exception { return &Exception{} } +func (p *Exception) InitDefault() { +} + func (p *Exception) GetCode() (v int32) { return p.Code } @@ -422,27 +422,22 @@ func (p *Exception) Read(iprot thrift.TProtocol) (err error) { if err = p.ReadField1(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } case 255: if fieldTypeId == thrift.STRING { if err = p.ReadField255(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } } - if err = iprot.ReadFieldEnd(); err != nil { goto ReadFieldEndError } @@ -468,20 +463,25 @@ ReadStructEndError: } func (p *Exception) ReadField1(iprot thrift.TProtocol) error { + + var _field int32 if v, err := iprot.ReadI32(); err != nil { return err } else { - p.Code = v + _field = v } + p.Code = _field return nil } - func (p *Exception) ReadField255(iprot thrift.TProtocol) error { + + var _field string if v, err := iprot.ReadString(); err != nil { return err } else { - p.Msg = v + _field = v } + p.Msg = _field return nil } @@ -499,7 +499,6 @@ func (p *Exception) Write(oprot thrift.TProtocol) (err error) { fieldId = 255 goto WriteFieldError } - } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -557,6 +556,7 @@ func (p *Exception) String() string { return "" } return fmt.Sprintf("Exception(%+v)", *p) + } func (p *Exception) Error() string { return p.String() @@ -598,206 +598,17 @@ type Mock interface { ExceptionTest(ctx context.Context, req *MockReq) (r string, err error) } -type MockClient struct { - c thrift.TClient -} - -func NewMockClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *MockClient { - return &MockClient{ - c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t)), - } -} - -func NewMockClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *MockClient { - return &MockClient{ - c: thrift.NewTStandardClient(iprot, oprot), - } -} - -func NewMockClient(c thrift.TClient) *MockClient { - return &MockClient{ - c: c, - } -} - -func (p *MockClient) Client_() thrift.TClient { - return p.c -} - -func (p *MockClient) Test(ctx context.Context, req *MockReq) (r string, err error) { - var _args MockTestArgs - _args.Req = req - var _result MockTestResult - if err = p.Client_().Call(ctx, "Test", &_args, &_result); err != nil { - return - } - return _result.GetSuccess(), nil -} -func (p *MockClient) ExceptionTest(ctx context.Context, req *MockReq) (r string, err error) { - var _args MockExceptionTestArgs - _args.Req = req - var _result MockExceptionTestResult - if err = p.Client_().Call(ctx, "ExceptionTest", &_args, &_result); err != nil { - return - } - switch { - case _result.Err != nil: - return r, _result.Err - } - return _result.GetSuccess(), nil -} - -type MockProcessor struct { - processorMap map[string]thrift.TProcessorFunction - handler Mock -} - -func (p *MockProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) { - p.processorMap[key] = processor -} - -func (p *MockProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) { - processor, ok = p.processorMap[key] - return processor, ok -} - -func (p *MockProcessor) ProcessorMap() map[string]thrift.TProcessorFunction { - return p.processorMap -} - -func NewMockProcessor(handler Mock) *MockProcessor { - self := &MockProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)} - self.AddToProcessorMap("Test", &mockProcessorTest{handler: handler}) - self.AddToProcessorMap("ExceptionTest", &mockProcessorExceptionTest{handler: handler}) - return self -} -func (p *MockProcessor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - name, _, seqId, err := iprot.ReadMessageBegin() - if err != nil { - return false, err - } - if processor, ok := p.GetProcessorFunction(name); ok { - return processor.Process(ctx, seqId, iprot, oprot) - } - iprot.Skip(thrift.STRUCT) - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) - oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, x -} - -type mockProcessorTest struct { - handler Mock -} - -func (p *mockProcessorTest) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := MockTestArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("Test", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err - } - - iprot.ReadMessageEnd() - var err2 error - result := MockTestResult{} - var retval string - if retval, err2 = p.handler.Test(ctx, args.Req); err2 != nil { - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Test: "+err2.Error()) - oprot.WriteMessageBegin("Test", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } else { - result.Success = &retval - } - if err2 = oprot.WriteMessageBegin("Test", thrift.REPLY, seqId); err2 != nil { - err = err2 - } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 - } - if err != nil { - return - } - return true, err -} - -type mockProcessorExceptionTest struct { - handler Mock -} - -func (p *mockProcessorExceptionTest) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { - args := MockExceptionTestArgs{} - if err = args.Read(iprot); err != nil { - iprot.ReadMessageEnd() - x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) - oprot.WriteMessageBegin("ExceptionTest", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return false, err - } - - iprot.ReadMessageEnd() - var err2 error - result := MockExceptionTestResult{} - var retval string - if retval, err2 = p.handler.ExceptionTest(ctx, args.Req); err2 != nil { - switch v := err2.(type) { - case *Exception: - result.Err = v - default: - x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing ExceptionTest: "+err2.Error()) - oprot.WriteMessageBegin("ExceptionTest", thrift.EXCEPTION, seqId) - x.Write(oprot) - oprot.WriteMessageEnd() - oprot.Flush(ctx) - return true, err2 - } - } else { - result.Success = &retval - } - if err2 = oprot.WriteMessageBegin("ExceptionTest", thrift.REPLY, seqId); err2 != nil { - err = err2 - } - if err2 = result.Write(oprot); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { - err = err2 - } - if err2 = oprot.Flush(ctx); err == nil && err2 != nil { - err = err2 - } - if err != nil { - return - } - return true, err -} - type MockTestArgs struct { - Req *MockReq `thrift:"req,1" json:"req"` + Req *MockReq `thrift:"req,1" frugal:"1,default,MockReq" json:"req"` } func NewMockTestArgs() *MockTestArgs { return &MockTestArgs{} } +func (p *MockTestArgs) InitDefault() { +} + var MockTestArgs_Req_DEFAULT *MockReq func (p *MockTestArgs) GetReq() (v *MockReq) { @@ -810,10 +621,6 @@ func (p *MockTestArgs) SetReq(val *MockReq) { p.Req = val } -func (p *MockTestArgs) GetFirstArgument() (interface{}) { - return p.Req -} - var fieldIDToName_MockTestArgs = map[int16]string{ 1: "req", } @@ -846,17 +653,14 @@ func (p *MockTestArgs) Read(iprot thrift.TProtocol) (err error) { if err = p.ReadField1(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } } - if err = iprot.ReadFieldEnd(); err != nil { goto ReadFieldEndError } @@ -882,10 +686,11 @@ ReadStructEndError: } func (p *MockTestArgs) ReadField1(iprot thrift.TProtocol) error { - p.Req = NewMockReq() - if err := p.Req.Read(iprot); err != nil { + _field := NewMockReq() + if err := _field.Read(iprot); err != nil { return err } + p.Req = _field return nil } @@ -899,7 +704,6 @@ func (p *MockTestArgs) Write(oprot thrift.TProtocol) (err error) { fieldId = 1 goto WriteFieldError } - } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -940,6 +744,7 @@ func (p *MockTestArgs) String() string { return "" } return fmt.Sprintf("MockTestArgs(%+v)", *p) + } func (p *MockTestArgs) DeepEqual(ano *MockTestArgs) bool { @@ -963,13 +768,16 @@ func (p *MockTestArgs) Field1DeepEqual(src *MockReq) bool { } type MockTestResult struct { - Success *string `thrift:"success,0,optional" json:"success,omitempty"` + Success *string `thrift:"success,0,optional" frugal:"0,optional,string" json:"success,omitempty"` } func NewMockTestResult() *MockTestResult { return &MockTestResult{} } +func (p *MockTestResult) InitDefault() { +} + var MockTestResult_Success_DEFAULT string func (p *MockTestResult) GetSuccess() (v string) { @@ -982,10 +790,6 @@ func (p *MockTestResult) SetSuccess(x interface{}) { p.Success = x.(*string) } -func (p *MockTestResult) GetResult() interface{} { - return p.Success -} - var fieldIDToName_MockTestResult = map[int16]string{ 0: "success", } @@ -1018,17 +822,14 @@ func (p *MockTestResult) Read(iprot thrift.TProtocol) (err error) { if err = p.ReadField0(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } } - if err = iprot.ReadFieldEnd(); err != nil { goto ReadFieldEndError } @@ -1054,11 +855,14 @@ ReadStructEndError: } func (p *MockTestResult) ReadField0(iprot thrift.TProtocol) error { + + var _field *string if v, err := iprot.ReadString(); err != nil { return err } else { - p.Success = &v + _field = &v } + p.Success = _field return nil } @@ -1072,7 +876,6 @@ func (p *MockTestResult) Write(oprot thrift.TProtocol) (err error) { fieldId = 0 goto WriteFieldError } - } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -1115,6 +918,7 @@ func (p *MockTestResult) String() string { return "" } return fmt.Sprintf("MockTestResult(%+v)", *p) + } func (p *MockTestResult) DeepEqual(ano *MockTestResult) bool { @@ -1143,13 +947,16 @@ func (p *MockTestResult) Field0DeepEqual(src *string) bool { } type MockExceptionTestArgs struct { - Req *MockReq `thrift:"req,1" json:"req"` + Req *MockReq `thrift:"req,1" frugal:"1,default,MockReq" json:"req"` } func NewMockExceptionTestArgs() *MockExceptionTestArgs { return &MockExceptionTestArgs{} } +func (p *MockExceptionTestArgs) InitDefault() { +} + var MockExceptionTestArgs_Req_DEFAULT *MockReq func (p *MockExceptionTestArgs) GetReq() (v *MockReq) { @@ -1194,17 +1001,14 @@ func (p *MockExceptionTestArgs) Read(iprot thrift.TProtocol) (err error) { if err = p.ReadField1(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } } - if err = iprot.ReadFieldEnd(); err != nil { goto ReadFieldEndError } @@ -1230,10 +1034,11 @@ ReadStructEndError: } func (p *MockExceptionTestArgs) ReadField1(iprot thrift.TProtocol) error { - p.Req = NewMockReq() - if err := p.Req.Read(iprot); err != nil { + _field := NewMockReq() + if err := _field.Read(iprot); err != nil { return err } + p.Req = _field return nil } @@ -1247,7 +1052,6 @@ func (p *MockExceptionTestArgs) Write(oprot thrift.TProtocol) (err error) { fieldId = 1 goto WriteFieldError } - } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -1288,6 +1092,7 @@ func (p *MockExceptionTestArgs) String() string { return "" } return fmt.Sprintf("MockExceptionTestArgs(%+v)", *p) + } func (p *MockExceptionTestArgs) DeepEqual(ano *MockExceptionTestArgs) bool { @@ -1311,14 +1116,17 @@ func (p *MockExceptionTestArgs) Field1DeepEqual(src *MockReq) bool { } type MockExceptionTestResult struct { - Success *string `thrift:"success,0,optional" json:"success,omitempty"` - Err *Exception `thrift:"err,1,optional" json:"err,omitempty"` + Success *string `thrift:"success,0,optional" frugal:"0,optional,string" json:"success,omitempty"` + Err *Exception `thrift:"err,1,optional" frugal:"1,optional,Exception" json:"err,omitempty"` } func NewMockExceptionTestResult() *MockExceptionTestResult { return &MockExceptionTestResult{} } +func (p *MockExceptionTestResult) InitDefault() { +} + var MockExceptionTestResult_Success_DEFAULT string func (p *MockExceptionTestResult) GetSuccess() (v string) { @@ -1380,27 +1188,22 @@ func (p *MockExceptionTestResult) Read(iprot thrift.TProtocol) (err error) { if err = p.ReadField0(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } case 1: if fieldTypeId == thrift.STRUCT { if err = p.ReadField1(iprot); err != nil { goto ReadFieldError } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } } - if err = iprot.ReadFieldEnd(); err != nil { goto ReadFieldEndError } @@ -1426,19 +1229,22 @@ ReadStructEndError: } func (p *MockExceptionTestResult) ReadField0(iprot thrift.TProtocol) error { + + var _field *string if v, err := iprot.ReadString(); err != nil { return err } else { - p.Success = &v + _field = &v } + p.Success = _field return nil } - func (p *MockExceptionTestResult) ReadField1(iprot thrift.TProtocol) error { - p.Err = NewException() - if err := p.Err.Read(iprot); err != nil { + _field := NewException() + if err := _field.Read(iprot); err != nil { return err } + p.Err = _field return nil } @@ -1456,7 +1262,6 @@ func (p *MockExceptionTestResult) Write(oprot thrift.TProtocol) (err error) { fieldId = 1 goto WriteFieldError } - } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -1518,6 +1323,7 @@ func (p *MockExceptionTestResult) String() string { return "" } return fmt.Sprintf("MockExceptionTestResult(%+v)", *p) + } func (p *MockExceptionTestResult) DeepEqual(ano *MockExceptionTestResult) bool { diff --git a/pkg/generic/binary_test/generic_init.go b/pkg/generic/binary_test/generic_init.go index d7f5bd4418..6b5ae69fe2 100644 --- a/pkg/generic/binary_test/generic_init.go +++ b/pkg/generic/binary_test/generic_init.go @@ -34,7 +34,7 @@ import ( thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/transmeta" - "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/kitex/pkg/utils/fastthrift" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/genericserver" "github.com/cloudwego/kitex/transport" @@ -99,11 +99,10 @@ type GenericServiceMockImpl struct{} // GenericCall ... func (g *GenericServiceMockImpl) GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) { - rc := utils.NewThriftMessageCodec() buf := request.([]byte) var args2 kt.MockTestArgs - mth, seqID, err := rc.Decode(buf, &args2) + mth, seqID, err := fastthrift.UnmarshalMsg(buf, &args2) if err != nil { return nil, err } @@ -116,7 +115,7 @@ func (g *GenericServiceMockImpl) GenericCall(ctx context.Context, method string, result := kt.NewMockTestResult() result.Success = &resp - buf, err = rc.Encode(mth, thrift.REPLY, seqID, result) + buf, err = fastthrift.MarshalMsg(mth, fastthrift.REPLY, seqID, result) return buf, err } diff --git a/pkg/generic/binary_test/generic_test.go b/pkg/generic/binary_test/generic_test.go index f5d7613b64..57d0c34b56 100644 --- a/pkg/generic/binary_test/generic_test.go +++ b/pkg/generic/binary_test/generic_test.go @@ -34,7 +34,7 @@ import ( "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/kerrors" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/kitex/pkg/utils/fastthrift" "github.com/cloudwego/kitex/server" ) @@ -111,8 +111,7 @@ func rawThriftBinaryMockReq(t *testing.T) { args.Req = req // encode - rc := utils.NewThriftMessageCodec() - buf, err := rc.Encode("Test", thrift.CALL, 100, args) + buf, err := fastthrift.MarshalMsg("Test", fastthrift.CALL, 100, args) test.Assert(t, err == nil, err) resp, err := cli.GenericCall(context.Background(), "Test", buf) @@ -121,7 +120,7 @@ func rawThriftBinaryMockReq(t *testing.T) { // decode buf = resp.([]byte) var result kt.MockTestResult - method, seqID, err := rc.Decode(buf, &result) + method, seqID, err := fastthrift.UnmarshalMsg(buf, &result) test.Assert(t, err == nil, err) test.Assert(t, method == "Test", method) test.Assert(t, seqID != 100, seqID) @@ -148,8 +147,7 @@ func rawThriftBinary2NormalServer(t *testing.T) { args.Req = req // encode - rc := utils.NewThriftMessageCodec() - buf, err := rc.Encode("Test", thrift.CALL, 100, args) + buf, err := fastthrift.MarshalMsg("Test", fastthrift.CALL, 100, args) test.Assert(t, err == nil, err) resp, err := cli.GenericCall(context.Background(), "Test", buf, callopt.WithRPCTimeout(100*time.Second)) @@ -158,7 +156,7 @@ func rawThriftBinary2NormalServer(t *testing.T) { // decode buf = resp.([]byte) var result kt.MockTestResult - method, seqID, err := rc.Decode(buf, &result) + method, seqID, err := fastthrift.UnmarshalMsg(buf, &result) test.Assert(t, err == nil, err) test.Assert(t, method == "Test", method) // seqID会在kitex中覆盖,避免TTHeader和Payload codec 不一致问题 diff --git a/pkg/generic/binarythrift_codec_test.go b/pkg/generic/binarythrift_codec_test.go index 5393b16efc..0a211d3994 100644 --- a/pkg/generic/binarythrift_codec_test.go +++ b/pkg/generic/binarythrift_codec_test.go @@ -22,11 +22,10 @@ import ( kt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/kitex/pkg/utils/fastthrift" ) func TestBinaryThriftCodec(t *testing.T) { @@ -34,8 +33,7 @@ func TestBinaryThriftCodec(t *testing.T) { args := kt.NewMockTestArgs() args.Req = req // encode - rc := utils.NewThriftMessageCodec() - buf, err := rc.Encode("mock", thrift.CALL, 100, args) + buf, err := fastthrift.MarshalMsg("mock", fastthrift.CALL, 100, args) test.Assert(t, err == nil, err) btc := &binaryThriftCodec{thriftCodec} @@ -92,7 +90,7 @@ func TestBinaryThriftCodec(t *testing.T) { test.Assert(t, seqID == 1, seqID) var req2 kt.MockTestArgs - method, seqID2, err2 := rc.Decode(reqBuf, &req2) + method, seqID2, err2 := fastthrift.UnmarshalMsg(reqBuf, &req2) test.Assert(t, err2 == nil, err) test.Assert(t, seqID2 == 1, seqID) test.Assert(t, method == "mock", method) diff --git a/pkg/protocol/bthrift/binary.go b/pkg/protocol/bthrift/binary.go index 4eaba2b900..e10dc2ecf8 100644 --- a/pkg/protocol/bthrift/binary.go +++ b/pkg/protocol/bthrift/binary.go @@ -172,9 +172,8 @@ func (binaryProtocol) WriteBinaryNocopy(buf []byte, binaryWriter BinaryWriter, v return l + len(value) } -func (binaryProtocol) MessageBeginLength(name string, typeID thrift.TMessageType, seqid int32) int { - version := uint32(thrift.VERSION_1) | uint32(typeID) - return Binary.I32Length(int32(version)) + Binary.StringLength(name) + Binary.I32Length(seqid) +func (binaryProtocol) MessageBeginLength(name string, _ thrift.TMessageType, _ int32) int { + return 4 + Binary.StringLength(name) + 4 } func (binaryProtocol) MessageEndLength() int { diff --git a/pkg/protocol/bthrift/exception.go b/pkg/protocol/bthrift/exception.go index e3ed256840..598a47ec9d 100644 --- a/pkg/protocol/bthrift/exception.go +++ b/pkg/protocol/bthrift/exception.go @@ -24,7 +24,7 @@ import ( ) // ApplicationException is for replacing apache.TApplicationException -// it implements ThriftMsgFastCodec interface. +// it implements ThriftFastCodec interface. type ApplicationException struct { t int32 m string @@ -180,7 +180,7 @@ func (e *ApplicationException) Error() string { } // TransportException is for replacing apache.TransportException -// it implements ThriftMsgFastCodec interface. +// it implements ThriftFastCodec interface. type TransportException struct { ApplicationException // same implementation ... } @@ -194,7 +194,7 @@ func NewTransportException(t int32, m string) *TransportException { } // ProtocolException is for replacing apache.ProtocolException -// it implements ThriftMsgFastCodec interface. +// it implements ThriftFastCodec interface. type ProtocolException struct { ApplicationException // same implementation ... } diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index f3922b1e64..62bbb23dc6 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -26,6 +26,13 @@ type BinaryWriter interface { WriteDirect(b []byte, remainCap int) error } +// ThriftFastCodec represents the interface of thrift fastcodec generated structs +type ThriftFastCodec interface { + BLength() int + FastWriteNocopy(buf []byte, binaryWriter BinaryWriter) int + FastRead(buf []byte) (int, error) +} + // BTProtocol . type BTProtocol interface { WriteMessageBegin(buf []byte, name string, typeID thrift.TMessageType, seqid int32) int diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index 0d65f224c7..9a9bab0b02 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -111,7 +111,7 @@ func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out re // encode with FastWrite if c.CodecType&FastWrite != 0 { - if msg, ok := data.(ThriftMsgFastCodec); ok { + if msg, ok := data.(bthrift.ThriftFastCodec); ok { return encodeFastThrift(out, methodName, msgType, seqID, msg) } } @@ -131,7 +131,7 @@ func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out re } // encodeFastThrift encode with the FastCodec way -func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, msg ThriftMsgFastCodec) error { +func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, msg bthrift.ThriftFastCodec) error { nw, _ := out.(remote.NocopyWrite) // nocopy write is a special implementation of linked buffer, only bytebuffer implement NocopyWrite do FastWrite msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) @@ -259,11 +259,9 @@ type MessageReaderWithMethodWithContext interface { Read(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error } -type ThriftMsgFastCodec interface { - BLength() int - FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int - FastRead(buf []byte) (int, error) -} +// ThriftMsgFastCodec ... +// Deprecated: use `bthrift.ThriftFastCodec` +type ThriftMsgFastCodec = bthrift.ThriftFastCodec func getValidData(methodName string, message remote.Message) (interface{}, error) { if err := codec.NewDataIfNeeded(methodName, message); err != nil { diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index 548b634a68..7d42784d7a 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -51,7 +51,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ // encode with FastWrite if c.CodecType&FastWrite != 0 { - if msg, ok := data.(ThriftMsgFastCodec); ok { + if msg, ok := data.(bthrift.ThriftFastCodec); ok { payloadSize := msg.BLength() payload := mcache.Malloc(payloadSize) msg.FastWriteNocopy(payload, nil) @@ -146,12 +146,12 @@ func (c thriftCodec) fastMessageUnmarshalAvailable(data interface{}, payloadLen if payloadLen == 0 && c.CodecType&EnableSkipDecoder == 0 { return false } - _, ok := data.(ThriftMsgFastCodec) + _, ok := data.(bthrift.ThriftFastCodec) return ok } func (c thriftCodec) fastUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error { - msg := data.(ThriftMsgFastCodec) + msg := data.(bthrift.ThriftFastCodec) if dataLen > 0 { buf, err := tProt.next(dataLen) if err != nil { diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index 47f9ce5dcd..7ba4a5ae57 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -22,7 +22,7 @@ import ( "strings" "testing" - "github.com/cloudwego/kitex/internal/mocks/thrift/fast" + mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" @@ -30,7 +30,7 @@ import ( ) var ( - mockReq = &fast.MockReq{ + mockReq = &mocks.MockReq{ Msg: "hello", } mockReqThrift = []byte{ @@ -77,26 +77,26 @@ func TestMarshalThriftData(t *testing.T) { func Test_decodeBasicThriftData(t *testing.T) { t.Run("empty-input", func(t *testing.T) { - req := &fast.MockReq{} + req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{})) err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) test.Assert(t, err != nil, err) }) t.Run("invalid-input", func(t *testing.T) { - req := &fast.MockReq{} + req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{0xff})) err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) test.Assert(t, err != nil, err) }) t.Run("normal-input", func(t *testing.T) { - req := &fast.MockReq{} + req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) checkDecodeResult(t, err, req) }) } -func checkDecodeResult(t *testing.T, err error, req *fast.MockReq) { +func checkDecodeResult(t *testing.T, err error, req *mocks.MockReq) { test.Assert(t, err == nil, err) test.Assert(t, req.Msg == mockReq.Msg, req.Msg, mockReq.Msg) test.Assert(t, len(req.StrMap) == 0, req.StrMap) @@ -105,17 +105,17 @@ func checkDecodeResult(t *testing.T, err error, req *fast.MockReq) { func TestUnmarshalThriftData(t *testing.T) { t.Run("NoCodec(=FastCodec)", func(t *testing.T) { - req := &fast.MockReq{} + req := &mocks.MockReq{} err := UnmarshalThriftData(context.Background(), nil, "mock", mockReqThrift, req) checkDecodeResult(t, err, req) }) t.Run("FastCodec", func(t *testing.T) { - req := &fast.MockReq{} + req := &mocks.MockReq{} err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(FastRead|FastWrite), "mock", mockReqThrift, req) checkDecodeResult(t, err, req) }) t.Run("BasicCodec", func(t *testing.T) { - req := &fast.MockReq{} + req := &mocks.MockReq{} err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), "mock", mockReqThrift, req) checkDecodeResult(t, err, req) }) @@ -124,13 +124,13 @@ func TestUnmarshalThriftData(t *testing.T) { func TestThriftCodec_unmarshalThriftData(t *testing.T) { t.Run("FastCodec with SkipDecoder enabled", func(t *testing.T) { - req := &fast.MockReq{} + req := &mocks.MockReq{} codec := &thriftCodec{FastRead | EnableSkipDecoder} tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) - checkDecodeResult(t, err, &fast.MockReq{ + checkDecodeResult(t, err, &mocks.MockReq{ Msg: req.Msg, StrList: req.StrList, StrMap: req.StrMap, @@ -138,7 +138,7 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { }) t.Run("FastCodec with SkipDecoder enabled, failed in using SkipDecoder Buffer", func(t *testing.T) { - req := &fast.MockReq{} + req := &mocks.MockReq{} codec := &thriftCodec{FastRead | EnableSkipDecoder} // these bytes are mapped to // Msg string `thrift:"Msg,1" json:"Msg"` diff --git a/pkg/remote/codec/thrift/thrift_frugal_test.go b/pkg/remote/codec/thrift/thrift_frugal_test.go index f2546da553..469bf23c0d 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_test.go @@ -29,7 +29,7 @@ import ( "strings" "testing" - "github.com/cloudwego/kitex/internal/mocks/thrift/fast" + mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -220,7 +220,7 @@ func TestUnmarshalThriftDataFrugal(t *testing.T) { } for _, codec := range successfulCodecs { err := UnmarshalThriftData(context.Background(), codec, "mock", mockReqThrift, req) - checkDecodeResult(t, err, &fast.MockReq{ + checkDecodeResult(t, err, &mocks.MockReq{ Msg: req.Msg, StrList: req.StrList, StrMap: req.StrMap, @@ -241,7 +241,7 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) - checkDecodeResult(t, err, &fast.MockReq{ + checkDecodeResult(t, err, &mocks.MockReq{ Msg: req.Msg, StrList: req.StrList, StrMap: req.StrMap, diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index d7bef3b027..408bdfddda 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -24,7 +24,7 @@ import ( "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/internal/mocks" - mt "github.com/cloudwego/kitex/internal/mocks/thrift/fast" + mt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" diff --git a/pkg/utils/fastthrift/fast_thrift.go b/pkg/utils/fastthrift/fast_thrift.go deleted file mode 100644 index c15d40f4dc..0000000000 --- a/pkg/utils/fastthrift/fast_thrift.go +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package fastthrift - -import ( - "github.com/bytedance/gopkg/lang/dirtmake" - - "github.com/cloudwego/kitex/pkg/remote/codec/thrift" -) - -// FastMarshal marshals the msg to buf. The msg should be generated by Kitex tool and implement ThriftMsgFastCodec. -func FastMarshal(msg thrift.ThriftMsgFastCodec) []byte { - buf := dirtmake.Bytes(msg.BLength(), msg.BLength()) - msg.FastWriteNocopy(buf, nil) - return buf -} - -// FastUnmarshal unmarshal the buf into msg. The msg should be generated by Kitex tool and implement ThriftMsgFastCodec. -func FastUnmarshal(buf []byte, msg thrift.ThriftMsgFastCodec) error { - _, err := msg.FastRead(buf) - return err -} diff --git a/pkg/utils/fastthrift/fastthrift.go b/pkg/utils/fastthrift/fastthrift.go new file mode 100644 index 0000000000..0d27d849f1 --- /dev/null +++ b/pkg/utils/fastthrift/fastthrift.go @@ -0,0 +1,82 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastthrift + +import ( + "errors" + + "github.com/bytedance/gopkg/lang/dirtmake" + + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" +) + +// FastMarshal marshals the msg to buf. The msg should be generated by Kitex tool and implement ThriftFastCodec. +func FastMarshal(msg bthrift.ThriftFastCodec) []byte { + sz := msg.BLength() + buf := dirtmake.Bytes(sz, sz) + msg.FastWriteNocopy(buf, nil) + return buf +} + +// FastUnmarshal unmarshal the buf into msg. The msg should be generated by Kitex tool and implement ThriftFastCodec. +func FastUnmarshal(buf []byte, msg bthrift.ThriftFastCodec) error { + _, err := msg.FastRead(buf) + return err +} + +// for msgType of MarshalMsg +// Please use theses consts instead of relying on apache thrift.TMessageType +const ( + CALL = uint8(1) + REPLY = uint8(2) + EXCEPTION = uint8(3) + ONEWAY = uint8(4) +) + +// MarshalMsg encodes the given msg to buf for generic thrift RPC. +func MarshalMsg(method string, msgType uint8, seq int32, msg bthrift.ThriftFastCodec) ([]byte, error) { + if method == "" { + return nil, errors.New("method not set") + } + sz := bthrift.Binary.MessageBeginLength(method, thrift.TMessageType(msgType), seq) + msg.BLength() + b := dirtmake.Bytes(sz, sz) + i := bthrift.Binary.WriteMessageBegin(b, method, thrift.TMessageType(msgType), seq) + _ = msg.FastWriteNocopy(b[i:], nil) + return b, nil +} + +// UnmarshalMsg parses the given buf and stores the result to msg for generic thrift RPC. +// for EXCEPTION msgType, it will returns `err` with *bthrift.ApplicationException type without storing the result to msg. +func UnmarshalMsg(b []byte, msg bthrift.ThriftFastCodec) (method string, seq int32, err error) { + method, msgType, seq, i, err := bthrift.Binary.ReadMessageBegin(b) + if err != nil { + return "", 0, err + } + b = b[i:] + + if uint8(msgType) == EXCEPTION { + ex := bthrift.NewApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") + _, err = ex.FastRead(b) + if err != nil { + return method, seq, err + } + return method, seq, ex + } + _, err = msg.FastRead(b) + return method, seq, err +} diff --git a/pkg/utils/fastthrift/fast_thrift_test.go b/pkg/utils/fastthrift/fastthrift_test.go similarity index 60% rename from pkg/utils/fastthrift/fast_thrift_test.go rename to pkg/utils/fastthrift/fastthrift_test.go index 7336ed3839..31525da8fa 100644 --- a/pkg/utils/fastthrift/fast_thrift_test.go +++ b/pkg/utils/fastthrift/fastthrift_test.go @@ -19,8 +19,10 @@ package fastthrift import ( "testing" - mocks "github.com/cloudwego/kitex/internal/mocks/thrift/fast" + mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) var ( @@ -49,6 +51,35 @@ func TestFastThrift(t *testing.T) { test.Assert(t, len(req1.StrList) == len(req2.StrList)) } +func TestMarshalMsg(t *testing.T) { + // CALL and REPLY + + req := &mocks.MockReq{} + req.Msg = "Hello" + b, err := MarshalMsg("Echo", CALL, 1, req) + test.Assert(t, err == nil, err) + + resp := &mocks.MockReq{} + method, seq, err := UnmarshalMsg(b, resp) + test.Assert(t, err == nil, err) + test.Assert(t, method == "Echo", method) + test.Assert(t, seq == 1, seq) + test.Assert(t, resp.Msg == req.Msg, resp.Msg) + + // EXCEPTION + + ex := bthrift.NewApplicationException(thrift.WRONG_METHOD_NAME, "Ex!") + b, err = MarshalMsg("ExMethod", EXCEPTION, 2, ex) + test.Assert(t, err == nil, err) + method, seq, err = UnmarshalMsg(b, nil) + test.Assert(t, err != nil) + test.Assert(t, method == "ExMethod") + test.Assert(t, seq == 2) + e, ok := err.(*bthrift.ApplicationException) + test.Assert(t, ok) + test.Assert(t, e.TypeID() == ex.TypeID() && e.Error() == ex.Error()) +} + func BenchmarkFastUnmarshal(b *testing.B) { buf := FastMarshal(newRequest()) b.ResetTimer() diff --git a/pkg/utils/kitexutil/kitexutil_test.go b/pkg/utils/kitexutil/kitexutil_test.go index 5c76097494..6766aafcad 100644 --- a/pkg/utils/kitexutil/kitexutil_test.go +++ b/pkg/utils/kitexutil/kitexutil_test.go @@ -23,7 +23,7 @@ import ( "reflect" "testing" - mocks "github.com/cloudwego/kitex/internal/mocks/thrift/fast" + mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/utils" From 020a96a6e114480c55f5b56ae62b18f80ab94f97 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Mon, 8 Jul 2024 15:03:01 +0800 Subject: [PATCH 08/70] feat(tool): no apache for fastcodec (#1426) * use bthrift.PrependError * updated kitex tool MinDepVersion to 0.11.0 * fixed `undefined: bthrift.KitexUnusedProtection` --- internal/mocks/thrift/k-test.go | 77 +- pkg/protocol/bthrift/test/gen.sh | 2 +- .../bthrift/test/kitex_gen/test/k-test.go | 448 +-- .../bthrift/test/kitex_gen/test/test.go | 3558 +---------------- pkg/protocol/bthrift/test/unknown_test.go | 38 - tool/cmd/kitex/main.go | 2 +- .../pluginmode/thriftgo/file_tpl.go | 2 +- .../pluginmode/thriftgo/patcher.go | 7 + .../pluginmode/thriftgo/struct_tpl.go | 12 +- version.go | 2 +- 10 files changed, 417 insertions(+), 3731 deletions(-) diff --git a/internal/mocks/thrift/k-test.go b/internal/mocks/thrift/k-test.go index 2020f27c4d..f21df2427a 100644 --- a/internal/mocks/thrift/k-test.go +++ b/internal/mocks/thrift/k-test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -// Code generated by Kitex v0.10.1. DO NOT EDIT. +// Code generated by Kitex v0.11.0. DO NOT EDIT. package thrift @@ -24,9 +24,8 @@ import ( "reflect" "strings" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // unused protection @@ -125,17 +124,17 @@ func (p *MockReq) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockReq[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockReq[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *MockReq) FastReadField1(buf []byte) (int, error) { @@ -404,17 +403,17 @@ func (p *Exception) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Exception[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Exception[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *Exception) FastReadField1(buf []byte) (int, error) { @@ -568,17 +567,17 @@ func (p *MockTestArgs) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestArgs[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestArgs[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *MockTestArgs) FastReadField1(buf []byte) (int, error) { @@ -694,17 +693,17 @@ func (p *MockTestResult) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestResult[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestResult[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *MockTestResult) FastReadField0(buf []byte) (int, error) { @@ -827,17 +826,17 @@ func (p *MockExceptionTestArgs) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestArgs[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestArgs[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *MockExceptionTestArgs) FastReadField1(buf []byte) (int, error) { @@ -967,17 +966,17 @@ func (p *MockExceptionTestResult) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestResult[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestResult[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *MockExceptionTestResult) FastReadField0(buf []byte) (int, error) { diff --git a/pkg/protocol/bthrift/test/gen.sh b/pkg/protocol/bthrift/test/gen.sh index 178baf5bb7..e064be9d98 100755 --- a/pkg/protocol/bthrift/test/gen.sh +++ b/pkg/protocol/bthrift/test/gen.sh @@ -1,4 +1,4 @@ #!/bin/bash -kitex -module github.com/cloudwego/kitex -thrift keep_unknown_fields test.thrift +kitex -thrift no_default_serdes -module github.com/cloudwego/kitex -thrift keep_unknown_fields test.thrift diff --git a/pkg/protocol/bthrift/test/kitex_gen/test/k-test.go b/pkg/protocol/bthrift/test/kitex_gen/test/k-test.go index 88823b8ccf..fd5929e322 100644 --- a/pkg/protocol/bthrift/test/kitex_gen/test/k-test.go +++ b/pkg/protocol/bthrift/test/kitex_gen/test/k-test.go @@ -1,4 +1,4 @@ -// Code generated by Kitex v0.7.0. DO NOT EDIT. +// Code generated by Kitex v0.11.0. DO NOT EDIT. package test @@ -8,9 +8,8 @@ import ( "reflect" "strings" - "github.com/apache/thrift/lib/go/thrift" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // unused protection @@ -157,43 +156,47 @@ func (p *Inner) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Inner[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Inner[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *Inner) FastReadField1(buf []byte) (int, error) { offset := 0 + var _field int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Num = v + _field = v } + p.Num = _field return offset, nil } func (p *Inner) FastReadField2(buf []byte) (int, error) { offset := 0 + var _field *string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Desc = &v + _field = &v } + p.Desc = _field return offset, nil } @@ -205,7 +208,7 @@ func (p *Inner) FastReadField3(buf []byte) (int, error) { if err != nil { return offset, err } - p.MapOfList = make(map[int64][]int64, size) + _field := make(map[int64][]int64, size) for i := 0; i < size; i++ { var _key int64 if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { @@ -242,13 +245,14 @@ func (p *Inner) FastReadField3(buf []byte) (int, error) { offset += l } - p.MapOfList[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.MapOfList = _field return offset, nil } @@ -260,7 +264,7 @@ func (p *Inner) FastReadField4(buf []byte) (int, error) { if err != nil { return offset, err } - p.MapOfEnumKey = make(map[AEnum]int64, size) + _field := make(map[AEnum]int64, size) for i := 0; i < size; i++ { var _key AEnum if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -282,39 +286,44 @@ func (p *Inner) FastReadField4(buf []byte) (int, error) { } - p.MapOfEnumKey[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.MapOfEnumKey = _field return offset, nil } func (p *Inner) FastReadField5(buf []byte) (int, error) { offset := 0 + var _field *int8 if v, l, err := bthrift.Binary.ReadByte(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Byte1 = &v + _field = &v } + p.Byte1 = _field return offset, nil } func (p *Inner) FastReadField6(buf []byte) (int, error) { offset := 0 + var _field *float64 if v, l, err := bthrift.Binary.ReadDouble(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Double1 = &v + _field = &v } + p.Double1 = _field return offset, nil } @@ -362,7 +371,6 @@ func (p *Inner) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) i if p.IsSetNum() { offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Num", thrift.I32, 1) offset += bthrift.Binary.WriteI32(buf[offset:], p.Num) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) } return offset @@ -373,7 +381,6 @@ func (p *Inner) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) i if p.IsSetDesc() { offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "desc", thrift.STRING, 2) offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, *p.Desc) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) } return offset @@ -388,16 +395,13 @@ func (p *Inner) fastWriteField3(buf []byte, binaryWriter bthrift.BinaryWriter) i var length int for k, v := range p.MapOfList { length++ - offset += bthrift.Binary.WriteI64(buf[offset:], k) - listBeginOffset := offset offset += bthrift.Binary.ListBeginLength(thrift.I64, 0) var length int for _, v := range v { length++ offset += bthrift.Binary.WriteI64(buf[offset:], v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I64, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -418,11 +422,8 @@ func (p *Inner) fastWriteField4(buf []byte, binaryWriter bthrift.BinaryWriter) i var length int for k, v := range p.MapOfEnumKey { length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(k)) - offset += bthrift.Binary.WriteI64(buf[offset:], v) - } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.I64, length) offset += bthrift.Binary.WriteMapEnd(buf[offset:]) @@ -436,7 +437,6 @@ func (p *Inner) fastWriteField5(buf []byte, binaryWriter bthrift.BinaryWriter) i if p.IsSetByte1() { offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Byte1", thrift.BYTE, 5) offset += bthrift.Binary.WriteByte(buf[offset:], *p.Byte1) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) } return offset @@ -447,7 +447,6 @@ func (p *Inner) fastWriteField6(buf []byte, binaryWriter bthrift.BinaryWriter) i if p.IsSetDouble1() { offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Double1", thrift.DOUBLE, 6) offset += bthrift.Binary.WriteDouble(buf[offset:], *p.Double1) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) } return offset @@ -458,7 +457,6 @@ func (p *Inner) field1Length() int { if p.IsSetNum() { l += bthrift.Binary.FieldBeginLength("Num", thrift.I32, 1) l += bthrift.Binary.I32Length(p.Num) - l += bthrift.Binary.FieldEndLength() } return l @@ -469,7 +467,6 @@ func (p *Inner) field2Length() int { if p.IsSetDesc() { l += bthrift.Binary.FieldBeginLength("desc", thrift.STRING, 2) l += bthrift.Binary.StringLengthNocopy(*p.Desc) - l += bthrift.Binary.FieldEndLength() } return l @@ -483,7 +480,6 @@ func (p *Inner) field3Length() int { for k, v := range p.MapOfList { l += bthrift.Binary.I64Length(k) - l += bthrift.Binary.ListBeginLength(thrift.I64, len(v)) var tmpV int64 l += bthrift.Binary.I64Length(int64(tmpV)) * len(v) @@ -503,9 +499,7 @@ func (p *Inner) field4Length() int { for k, v := range p.MapOfEnumKey { l += bthrift.Binary.I32Length(int32(k)) - l += bthrift.Binary.I64Length(v) - } l += bthrift.Binary.MapEndLength() l += bthrift.Binary.FieldEndLength() @@ -518,7 +512,6 @@ func (p *Inner) field5Length() int { if p.IsSetByte1() { l += bthrift.Binary.FieldBeginLength("Byte1", thrift.BYTE, 5) l += bthrift.Binary.ByteLength(*p.Byte1) - l += bthrift.Binary.FieldEndLength() } return l @@ -529,7 +522,6 @@ func (p *Inner) field6Length() int { if p.IsSetDouble1() { l += bthrift.Binary.FieldBeginLength("Double1", thrift.DOUBLE, 6) l += bthrift.Binary.DoubleLength(*p.Double1) - l += bthrift.Binary.FieldEndLength() } return l @@ -599,30 +591,32 @@ func (p *Local) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Local[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Local[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *Local) FastReadField1(buf []byte) (int, error) { offset := 0 + var _field int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - p.L = v + _field = v } + p.L = _field return offset, nil } @@ -659,7 +653,6 @@ func (p *Local) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) i offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "l", thrift.I32, 1) offset += bthrift.Binary.WriteI32(buf[offset:], p.L) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -668,7 +661,6 @@ func (p *Local) field1Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("l", thrift.I32, 1) l += bthrift.Binary.I32Length(p.L) - l += bthrift.Binary.FieldEndLength() return l } @@ -1156,17 +1148,17 @@ func (p *FullStruct) FastRead(buf []byte) (int, error) { } return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FullStruct[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FullStruct[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) RequiredFieldNotSetError: return offset, thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_FullStruct[fieldId])) } @@ -1174,83 +1166,92 @@ RequiredFieldNotSetError: func (p *FullStruct) FastReadField1(buf []byte) (int, error) { offset := 0 + var _field int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Left = v + _field = v } + p.Left = _field return offset, nil } func (p *FullStruct) FastReadField2(buf []byte) (int, error) { offset := 0 + var _field int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Right = v + _field = v } + p.Right = _field return offset, nil } func (p *FullStruct) FastReadField3(buf []byte) (int, error) { offset := 0 + var _field []byte if v, l, err := bthrift.Binary.ReadBinary(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Dummy = []byte(v) + _field = []byte(v) } + p.Dummy = _field return offset, nil } func (p *FullStruct) FastReadField4(buf []byte) (int, error) { offset := 0 - - tmp := NewInner() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewInner() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.InnerReq = tmp + p.InnerReq = _field return offset, nil } func (p *FullStruct) FastReadField5(buf []byte) (int, error) { offset := 0 + var _field HTTPStatus if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Status = HTTPStatus(v) + _field = HTTPStatus(v) } + p.Status = _field return offset, nil } func (p *FullStruct) FastReadField6(buf []byte) (int, error) { offset := 0 + var _field string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Str = v + _field = v } + p.Str = _field return offset, nil } @@ -1262,7 +1263,7 @@ func (p *FullStruct) FastReadField7(buf []byte) (int, error) { if err != nil { return offset, err } - p.EnumList = make([]HTTPStatus, 0, size) + _field := make([]HTTPStatus, 0, size) for i := 0; i < size; i++ { var _elem HTTPStatus if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -1274,13 +1275,14 @@ func (p *FullStruct) FastReadField7(buf []byte) (int, error) { } - p.EnumList = append(p.EnumList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.EnumList = _field return offset, nil } @@ -1292,7 +1294,7 @@ func (p *FullStruct) FastReadField8(buf []byte) (int, error) { if err != nil { return offset, err } - p.Strmap = make(map[int32]string, size) + _field := make(map[int32]string, size) for i := 0; i < size; i++ { var _key int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -1314,27 +1316,30 @@ func (p *FullStruct) FastReadField8(buf []byte) (int, error) { } - p.Strmap[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.Strmap = _field return offset, nil } func (p *FullStruct) FastReadField9(buf []byte) (int, error) { offset := 0 + var _field int64 if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Int64 = v + _field = v } + p.Int64 = _field return offset, nil } @@ -1346,7 +1351,7 @@ func (p *FullStruct) FastReadField10(buf []byte) (int, error) { if err != nil { return offset, err } - p.IntList = make([]int32, 0, size) + _field := make([]int32, 0, size) for i := 0; i < size; i++ { var _elem int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -1358,13 +1363,14 @@ func (p *FullStruct) FastReadField10(buf []byte) (int, error) { } - p.IntList = append(p.IntList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.IntList = _field return offset, nil } @@ -1376,22 +1382,25 @@ func (p *FullStruct) FastReadField11(buf []byte) (int, error) { if err != nil { return offset, err } - p.LocalList = make([]*Local, 0, size) + _field := make([]*Local, 0, size) + values := make([]Local, size) for i := 0; i < size; i++ { - _elem := NewLocal() + _elem := &values[i] + _elem.InitDefault() if l, err := _elem.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.LocalList = append(p.LocalList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.LocalList = _field return offset, nil } @@ -1403,7 +1412,8 @@ func (p *FullStruct) FastReadField12(buf []byte) (int, error) { if err != nil { return offset, err } - p.StrLocalMap = make(map[string]*Local, size) + _field := make(map[string]*Local, size) + values := make([]Local, size) for i := 0; i < size; i++ { var _key string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -1414,20 +1424,23 @@ func (p *FullStruct) FastReadField12(buf []byte) (int, error) { _key = v } - _val := NewLocal() + + _val := &values[i] + _val.InitDefault() if l, err := _val.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.StrLocalMap[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.StrLocalMap = _field return offset, nil } @@ -1439,7 +1452,7 @@ func (p *FullStruct) FastReadField13(buf []byte) (int, error) { if err != nil { return offset, err } - p.NestList = make([][]int32, 0, size) + _field := make([][]int32, 0, size) for i := 0; i < size; i++ { _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) offset += l @@ -1466,26 +1479,26 @@ func (p *FullStruct) FastReadField13(buf []byte) (int, error) { offset += l } - p.NestList = append(p.NestList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.NestList = _field return offset, nil } func (p *FullStruct) FastReadField14(buf []byte) (int, error) { offset := 0 - - tmp := NewLocal() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewLocal() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.RequiredIns = tmp + p.RequiredIns = _field return offset, nil } @@ -1497,7 +1510,7 @@ func (p *FullStruct) FastReadField16(buf []byte) (int, error) { if err != nil { return offset, err } - p.NestMap = make(map[string][]string, size) + _field := make(map[string][]string, size) for i := 0; i < size; i++ { var _key string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -1534,13 +1547,14 @@ func (p *FullStruct) FastReadField16(buf []byte) (int, error) { offset += l } - p.NestMap[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.NestMap = _field return offset, nil } @@ -1552,7 +1566,7 @@ func (p *FullStruct) FastReadField17(buf []byte) (int, error) { if err != nil { return offset, err } - p.NestMap2 = make([]map[string]HTTPStatus, 0, size) + _field := make([]map[string]HTTPStatus, 0, size) for i := 0; i < size; i++ { _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) offset += l @@ -1589,13 +1603,14 @@ func (p *FullStruct) FastReadField17(buf []byte) (int, error) { offset += l } - p.NestMap2 = append(p.NestMap2, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.NestMap2 = _field return offset, nil } @@ -1607,7 +1622,7 @@ func (p *FullStruct) FastReadField18(buf []byte) (int, error) { if err != nil { return offset, err } - p.EnumMap = make(map[int32]HTTPStatus, size) + _field := make(map[int32]HTTPStatus, size) for i := 0; i < size; i++ { var _key int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -1629,13 +1644,14 @@ func (p *FullStruct) FastReadField18(buf []byte) (int, error) { } - p.EnumMap[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.EnumMap = _field return offset, nil } @@ -1647,7 +1663,7 @@ func (p *FullStruct) FastReadField19(buf []byte) (int, error) { if err != nil { return offset, err } - p.Strlist = make([]string, 0, size) + _field := make([]string, 0, size) for i := 0; i < size; i++ { var _elem string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -1659,39 +1675,38 @@ func (p *FullStruct) FastReadField19(buf []byte) (int, error) { } - p.Strlist = append(p.Strlist, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.Strlist = _field return offset, nil } func (p *FullStruct) FastReadField20(buf []byte) (int, error) { offset := 0 - - tmp := NewLocal() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewLocal() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.OptionalIns = tmp + p.OptionalIns = _field return offset, nil } func (p *FullStruct) FastReadField21(buf []byte) (int, error) { offset := 0 - - tmp := NewInner() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewInner() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.AnotherInner = tmp + p.AnotherInner = _field return offset, nil } @@ -1703,7 +1718,7 @@ func (p *FullStruct) FastReadField22(buf []byte) (int, error) { if err != nil { return offset, err } - p.OptNilList = make([]string, 0, size) + _field := make([]string, 0, size) for i := 0; i < size; i++ { var _elem string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -1715,13 +1730,14 @@ func (p *FullStruct) FastReadField22(buf []byte) (int, error) { } - p.OptNilList = append(p.OptNilList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.OptNilList = _field return offset, nil } @@ -1733,7 +1749,7 @@ func (p *FullStruct) FastReadField23(buf []byte) (int, error) { if err != nil { return offset, err } - p.NilList = make([]string, 0, size) + _field := make([]string, 0, size) for i := 0; i < size; i++ { var _elem string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -1745,13 +1761,14 @@ func (p *FullStruct) FastReadField23(buf []byte) (int, error) { } - p.NilList = append(p.NilList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.NilList = _field return offset, nil } @@ -1763,22 +1780,25 @@ func (p *FullStruct) FastReadField24(buf []byte) (int, error) { if err != nil { return offset, err } - p.OptNilInsList = make([]*Inner, 0, size) + _field := make([]*Inner, 0, size) + values := make([]Inner, size) for i := 0; i < size; i++ { - _elem := NewInner() + _elem := &values[i] + _elem.InitDefault() if l, err := _elem.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.OptNilInsList = append(p.OptNilInsList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.OptNilInsList = _field return offset, nil } @@ -1790,37 +1810,42 @@ func (p *FullStruct) FastReadField25(buf []byte) (int, error) { if err != nil { return offset, err } - p.NilInsList = make([]*Inner, 0, size) + _field := make([]*Inner, 0, size) + values := make([]Inner, size) for i := 0; i < size; i++ { - _elem := NewInner() + _elem := &values[i] + _elem.InitDefault() if l, err := _elem.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.NilInsList = append(p.NilInsList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.NilInsList = _field return offset, nil } func (p *FullStruct) FastReadField26(buf []byte) (int, error) { offset := 0 + var _field *HTTPStatus if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l tmp := HTTPStatus(v) - p.OptStatus = &tmp + _field = &tmp } + p.OptStatus = _field return offset, nil } @@ -1832,7 +1857,8 @@ func (p *FullStruct) FastReadField27(buf []byte) (int, error) { if err != nil { return offset, err } - p.EnumKeyMap = make(map[HTTPStatus]*Local, size) + _field := make(map[HTTPStatus]*Local, size) + values := make([]Local, size) for i := 0; i < size; i++ { var _key HTTPStatus if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -1843,20 +1869,23 @@ func (p *FullStruct) FastReadField27(buf []byte) (int, error) { _key = HTTPStatus(v) } - _val := NewLocal() + + _val := &values[i] + _val.InitDefault() if l, err := _val.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.EnumKeyMap[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.EnumKeyMap = _field return offset, nil } @@ -1868,7 +1897,7 @@ func (p *FullStruct) FastReadField28(buf []byte) (int, error) { if err != nil { return offset, err } - p.Complex = make(map[HTTPStatus][]map[string]*Local, size) + _field := make(map[HTTPStatus][]map[string]*Local, size) for i := 0; i < size; i++ { var _key HTTPStatus if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -1893,6 +1922,7 @@ func (p *FullStruct) FastReadField28(buf []byte) (int, error) { return offset, err } _elem := make(map[string]*Local, size) + values := make([]Local, size) for i := 0; i < size; i++ { var _key1 string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -1903,7 +1933,9 @@ func (p *FullStruct) FastReadField28(buf []byte) (int, error) { _key1 = v } - _val1 := NewLocal() + + _val1 := &values[i] + _val1.InitDefault() if l, err := _val1.FastRead(buf[offset:]); err != nil { return offset, err } else { @@ -1926,13 +1958,14 @@ func (p *FullStruct) FastReadField28(buf []byte) (int, error) { offset += l } - p.Complex[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.Complex = _field return offset, nil } @@ -1944,7 +1977,7 @@ func (p *FullStruct) FastReadField29(buf []byte) (int, error) { if err != nil { return offset, err } - p.I64Set = make([]int64, 0, size) + _field := make([]int64, 0, size) for i := 0; i < size; i++ { var _elem int64 if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { @@ -1956,41 +1989,46 @@ func (p *FullStruct) FastReadField29(buf []byte) (int, error) { } - p.I64Set = append(p.I64Set, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadSetEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.I64Set = _field return offset, nil } func (p *FullStruct) FastReadField30(buf []byte) (int, error) { offset := 0 + var _field int16 if v, l, err := bthrift.Binary.ReadI16(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Int16 = v + _field = v } + p.Int16 = _field return offset, nil } func (p *FullStruct) FastReadField31(buf []byte) (int, error) { offset := 0 + var _field bool if v, l, err := bthrift.Binary.ReadBool(buf[offset:]); err != nil { return offset, err } else { offset += l - p.IsSet = v + _field = v } + p.IsSet = _field return offset, nil } @@ -2085,7 +2123,6 @@ func (p *FullStruct) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWrit offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Left", thrift.I32, 1) offset += bthrift.Binary.WriteI32(buf[offset:], p.Left) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -2095,7 +2132,6 @@ func (p *FullStruct) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWrit if p.IsSetRight() { offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Right", thrift.I32, 2) offset += bthrift.Binary.WriteI32(buf[offset:], p.Right) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) } return offset @@ -2105,7 +2141,6 @@ func (p *FullStruct) fastWriteField3(buf []byte, binaryWriter bthrift.BinaryWrit offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Dummy", thrift.STRING, 3) offset += bthrift.Binary.WriteBinaryNocopy(buf[offset:], binaryWriter, []byte(p.Dummy)) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -2122,7 +2157,6 @@ func (p *FullStruct) fastWriteField5(buf []byte, binaryWriter bthrift.BinaryWrit offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "status", thrift.I32, 5) offset += bthrift.Binary.WriteI32(buf[offset:], int32(p.Status)) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -2131,7 +2165,6 @@ func (p *FullStruct) fastWriteField6(buf []byte, binaryWriter bthrift.BinaryWrit offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Str", thrift.STRING, 6) offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Str) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -2145,7 +2178,6 @@ func (p *FullStruct) fastWriteField7(buf []byte, binaryWriter bthrift.BinaryWrit for _, v := range p.EnumList { length++ offset += bthrift.Binary.WriteI32(buf[offset:], int32(v)) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -2162,11 +2194,8 @@ func (p *FullStruct) fastWriteField8(buf []byte, binaryWriter bthrift.BinaryWrit var length int for k, v := range p.Strmap { length++ - offset += bthrift.Binary.WriteI32(buf[offset:], k) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.STRING, length) offset += bthrift.Binary.WriteMapEnd(buf[offset:]) @@ -2179,7 +2208,6 @@ func (p *FullStruct) fastWriteField9(buf []byte, binaryWriter bthrift.BinaryWrit offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Int64", thrift.I64, 9) offset += bthrift.Binary.WriteI64(buf[offset:], p.Int64) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -2194,7 +2222,6 @@ func (p *FullStruct) fastWriteField10(buf []byte, binaryWriter bthrift.BinaryWri for _, v := range p.IntList { length++ offset += bthrift.Binary.WriteI32(buf[offset:], v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -2227,9 +2254,7 @@ func (p *FullStruct) fastWriteField12(buf []byte, binaryWriter bthrift.BinaryWri var length int for k, v := range p.StrLocalMap { length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRUCT, length) @@ -2252,7 +2277,6 @@ func (p *FullStruct) fastWriteField13(buf []byte, binaryWriter bthrift.BinaryWri for _, v := range v { length++ offset += bthrift.Binary.WriteI32(buf[offset:], v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -2279,16 +2303,13 @@ func (p *FullStruct) fastWriteField16(buf []byte, binaryWriter bthrift.BinaryWri var length int for k, v := range p.NestMap { length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - listBeginOffset := offset offset += bthrift.Binary.ListBeginLength(thrift.STRING, 0) var length int for _, v := range v { length++ offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -2312,11 +2333,8 @@ func (p *FullStruct) fastWriteField17(buf []byte, binaryWriter bthrift.BinaryWri var length int for k, v := range v { length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += bthrift.Binary.WriteI32(buf[offset:], int32(v)) - } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.I32, length) offset += bthrift.Binary.WriteMapEnd(buf[offset:]) @@ -2335,11 +2353,8 @@ func (p *FullStruct) fastWriteField18(buf []byte, binaryWriter bthrift.BinaryWri var length int for k, v := range p.EnumMap { length++ - offset += bthrift.Binary.WriteI32(buf[offset:], k) - offset += bthrift.Binary.WriteI32(buf[offset:], int32(v)) - } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.I32, length) offset += bthrift.Binary.WriteMapEnd(buf[offset:]) @@ -2356,7 +2371,6 @@ func (p *FullStruct) fastWriteField19(buf []byte, binaryWriter bthrift.BinaryWri for _, v := range p.Strlist { length++ offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -2392,7 +2406,6 @@ func (p *FullStruct) fastWriteField22(buf []byte, binaryWriter bthrift.BinaryWri for _, v := range p.OptNilList { length++ offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -2410,7 +2423,6 @@ func (p *FullStruct) fastWriteField23(buf []byte, binaryWriter bthrift.BinaryWri for _, v := range p.NilList { length++ offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -2457,7 +2469,6 @@ func (p *FullStruct) fastWriteField26(buf []byte, binaryWriter bthrift.BinaryWri if p.IsSetOptStatus() { offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "opt_status", thrift.I32, 26) offset += bthrift.Binary.WriteI32(buf[offset:], int32(*p.OptStatus)) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) } return offset @@ -2471,9 +2482,7 @@ func (p *FullStruct) fastWriteField27(buf []byte, binaryWriter bthrift.BinaryWri var length int for k, v := range p.EnumKeyMap { length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(k)) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.STRUCT, length) @@ -2490,9 +2499,7 @@ func (p *FullStruct) fastWriteField28(buf []byte, binaryWriter bthrift.BinaryWri var length int for k, v := range p.Complex { length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(k)) - listBeginOffset := offset offset += bthrift.Binary.ListBeginLength(thrift.MAP, 0) var length int @@ -2503,9 +2510,7 @@ func (p *FullStruct) fastWriteField28(buf []byte, binaryWriter bthrift.BinaryWri var length int for k, v := range v { length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRUCT, length) @@ -2542,7 +2547,6 @@ func (p *FullStruct) fastWriteField29(buf []byte, binaryWriter bthrift.BinaryWri for _, v := range p.I64Set { length++ offset += bthrift.Binary.WriteI64(buf[offset:], v) - } bthrift.Binary.WriteSetBegin(buf[setBeginOffset:], thrift.I64, length) offset += bthrift.Binary.WriteSetEnd(buf[offset:]) @@ -2554,7 +2558,6 @@ func (p *FullStruct) fastWriteField30(buf []byte, binaryWriter bthrift.BinaryWri offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Int16", thrift.I16, 30) offset += bthrift.Binary.WriteI16(buf[offset:], p.Int16) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -2563,7 +2566,6 @@ func (p *FullStruct) fastWriteField31(buf []byte, binaryWriter bthrift.BinaryWri offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "isSet", thrift.BOOL, 31) offset += bthrift.Binary.WriteBool(buf[offset:], p.IsSet) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -2572,7 +2574,6 @@ func (p *FullStruct) field1Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Left", thrift.I32, 1) l += bthrift.Binary.I32Length(p.Left) - l += bthrift.Binary.FieldEndLength() return l } @@ -2582,7 +2583,6 @@ func (p *FullStruct) field2Length() int { if p.IsSetRight() { l += bthrift.Binary.FieldBeginLength("Right", thrift.I32, 2) l += bthrift.Binary.I32Length(p.Right) - l += bthrift.Binary.FieldEndLength() } return l @@ -2592,7 +2592,6 @@ func (p *FullStruct) field3Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Dummy", thrift.STRING, 3) l += bthrift.Binary.BinaryLengthNocopy([]byte(p.Dummy)) - l += bthrift.Binary.FieldEndLength() return l } @@ -2609,7 +2608,6 @@ func (p *FullStruct) field5Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("status", thrift.I32, 5) l += bthrift.Binary.I32Length(int32(p.Status)) - l += bthrift.Binary.FieldEndLength() return l } @@ -2618,7 +2616,6 @@ func (p *FullStruct) field6Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Str", thrift.STRING, 6) l += bthrift.Binary.StringLengthNocopy(p.Str) - l += bthrift.Binary.FieldEndLength() return l } @@ -2629,7 +2626,6 @@ func (p *FullStruct) field7Length() int { l += bthrift.Binary.ListBeginLength(thrift.I32, len(p.EnumList)) for _, v := range p.EnumList { l += bthrift.Binary.I32Length(int32(v)) - } l += bthrift.Binary.ListEndLength() l += bthrift.Binary.FieldEndLength() @@ -2644,9 +2640,7 @@ func (p *FullStruct) field8Length() int { for k, v := range p.Strmap { l += bthrift.Binary.I32Length(k) - l += bthrift.Binary.StringLengthNocopy(v) - } l += bthrift.Binary.MapEndLength() l += bthrift.Binary.FieldEndLength() @@ -2658,7 +2652,6 @@ func (p *FullStruct) field9Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Int64", thrift.I64, 9) l += bthrift.Binary.I64Length(p.Int64) - l += bthrift.Binary.FieldEndLength() return l } @@ -2695,7 +2688,6 @@ func (p *FullStruct) field12Length() int { for k, v := range p.StrLocalMap { l += bthrift.Binary.StringLengthNocopy(k) - l += v.BLength() } l += bthrift.Binary.MapEndLength() @@ -2733,11 +2725,9 @@ func (p *FullStruct) field16Length() int { for k, v := range p.NestMap { l += bthrift.Binary.StringLengthNocopy(k) - l += bthrift.Binary.ListBeginLength(thrift.STRING, len(v)) for _, v := range v { l += bthrift.Binary.StringLengthNocopy(v) - } l += bthrift.Binary.ListEndLength() } @@ -2755,9 +2745,7 @@ func (p *FullStruct) field17Length() int { for k, v := range v { l += bthrift.Binary.StringLengthNocopy(k) - l += bthrift.Binary.I32Length(int32(v)) - } l += bthrift.Binary.MapEndLength() } @@ -2773,9 +2761,7 @@ func (p *FullStruct) field18Length() int { for k, v := range p.EnumMap { l += bthrift.Binary.I32Length(k) - l += bthrift.Binary.I32Length(int32(v)) - } l += bthrift.Binary.MapEndLength() l += bthrift.Binary.FieldEndLength() @@ -2788,7 +2774,6 @@ func (p *FullStruct) field19Length() int { l += bthrift.Binary.ListBeginLength(thrift.STRING, len(p.Strlist)) for _, v := range p.Strlist { l += bthrift.Binary.StringLengthNocopy(v) - } l += bthrift.Binary.ListEndLength() l += bthrift.Binary.FieldEndLength() @@ -2820,7 +2805,6 @@ func (p *FullStruct) field22Length() int { l += bthrift.Binary.ListBeginLength(thrift.STRING, len(p.OptNilList)) for _, v := range p.OptNilList { l += bthrift.Binary.StringLengthNocopy(v) - } l += bthrift.Binary.ListEndLength() l += bthrift.Binary.FieldEndLength() @@ -2834,7 +2818,6 @@ func (p *FullStruct) field23Length() int { l += bthrift.Binary.ListBeginLength(thrift.STRING, len(p.NilList)) for _, v := range p.NilList { l += bthrift.Binary.StringLengthNocopy(v) - } l += bthrift.Binary.ListEndLength() l += bthrift.Binary.FieldEndLength() @@ -2872,7 +2855,6 @@ func (p *FullStruct) field26Length() int { if p.IsSetOptStatus() { l += bthrift.Binary.FieldBeginLength("opt_status", thrift.I32, 26) l += bthrift.Binary.I32Length(int32(*p.OptStatus)) - l += bthrift.Binary.FieldEndLength() } return l @@ -2885,7 +2867,6 @@ func (p *FullStruct) field27Length() int { for k, v := range p.EnumKeyMap { l += bthrift.Binary.I32Length(int32(k)) - l += v.BLength() } l += bthrift.Binary.MapEndLength() @@ -2900,14 +2881,12 @@ func (p *FullStruct) field28Length() int { for k, v := range p.Complex { l += bthrift.Binary.I32Length(int32(k)) - l += bthrift.Binary.ListBeginLength(thrift.MAP, len(v)) for _, v := range v { l += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRUCT, len(v)) for k, v := range v { l += bthrift.Binary.StringLengthNocopy(k) - l += v.BLength() } l += bthrift.Binary.MapEndLength() @@ -2947,7 +2926,6 @@ func (p *FullStruct) field30Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Int16", thrift.I16, 30) l += bthrift.Binary.I16Length(p.Int16) - l += bthrift.Binary.FieldEndLength() return l } @@ -2956,7 +2934,6 @@ func (p *FullStruct) field31Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("isSet", thrift.BOOL, 31) l += bthrift.Binary.BoolLength(p.IsSet) - l += bthrift.Binary.FieldEndLength() return l } @@ -3206,17 +3183,17 @@ func (p *MixedStruct) FastRead(buf []byte) (int, error) { } return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MixedStruct[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MixedStruct[fieldId]), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) RequiredFieldNotSetError: return offset, thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_MixedStruct[fieldId])) } @@ -3224,42 +3201,48 @@ RequiredFieldNotSetError: func (p *MixedStruct) FastReadField1(buf []byte) (int, error) { offset := 0 + var _field int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Left = v + _field = v } + p.Left = _field return offset, nil } func (p *MixedStruct) FastReadField3(buf []byte) (int, error) { offset := 0 + var _field []byte if v, l, err := bthrift.Binary.ReadBinary(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Dummy = []byte(v) + _field = []byte(v) } + p.Dummy = _field return offset, nil } func (p *MixedStruct) FastReadField6(buf []byte) (int, error) { offset := 0 + var _field string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Str = v + _field = v } + p.Str = _field return offset, nil } @@ -3271,7 +3254,7 @@ func (p *MixedStruct) FastReadField7(buf []byte) (int, error) { if err != nil { return offset, err } - p.EnumList = make([]HTTPStatus, 0, size) + _field := make([]HTTPStatus, 0, size) for i := 0; i < size; i++ { var _elem HTTPStatus if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -3283,27 +3266,30 @@ func (p *MixedStruct) FastReadField7(buf []byte) (int, error) { } - p.EnumList = append(p.EnumList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.EnumList = _field return offset, nil } func (p *MixedStruct) FastReadField9(buf []byte) (int, error) { offset := 0 + var _field int64 if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { return offset, err } else { offset += l - p.Int64 = v + _field = v } + p.Int64 = _field return offset, nil } @@ -3315,7 +3301,7 @@ func (p *MixedStruct) FastReadField10(buf []byte) (int, error) { if err != nil { return offset, err } - p.IntList = make([]int32, 0, size) + _field := make([]int32, 0, size) for i := 0; i < size; i++ { var _elem int32 if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -3327,13 +3313,14 @@ func (p *MixedStruct) FastReadField10(buf []byte) (int, error) { } - p.IntList = append(p.IntList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.IntList = _field return offset, nil } @@ -3345,22 +3332,25 @@ func (p *MixedStruct) FastReadField11(buf []byte) (int, error) { if err != nil { return offset, err } - p.LocalList = make([]*Local, 0, size) + _field := make([]*Local, 0, size) + values := make([]Local, size) for i := 0; i < size; i++ { - _elem := NewLocal() + _elem := &values[i] + _elem.InitDefault() if l, err := _elem.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.LocalList = append(p.LocalList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.LocalList = _field return offset, nil } @@ -3372,7 +3362,8 @@ func (p *MixedStruct) FastReadField12(buf []byte) (int, error) { if err != nil { return offset, err } - p.StrLocalMap = make(map[string]*Local, size) + _field := make(map[string]*Local, size) + values := make([]Local, size) for i := 0; i < size; i++ { var _key string if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { @@ -3383,20 +3374,23 @@ func (p *MixedStruct) FastReadField12(buf []byte) (int, error) { _key = v } - _val := NewLocal() + + _val := &values[i] + _val.InitDefault() if l, err := _val.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.StrLocalMap[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.StrLocalMap = _field return offset, nil } @@ -3408,7 +3402,7 @@ func (p *MixedStruct) FastReadField13(buf []byte) (int, error) { if err != nil { return offset, err } - p.NestList = make([][]int32, 0, size) + _field := make([][]int32, 0, size) for i := 0; i < size; i++ { _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) offset += l @@ -3435,52 +3429,50 @@ func (p *MixedStruct) FastReadField13(buf []byte) (int, error) { offset += l } - p.NestList = append(p.NestList, _elem) + _field = append(_field, _elem) } if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.NestList = _field return offset, nil } func (p *MixedStruct) FastReadField14(buf []byte) (int, error) { offset := 0 - - tmp := NewLocal() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewLocal() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.RequiredIns = tmp + p.RequiredIns = _field return offset, nil } func (p *MixedStruct) FastReadField20(buf []byte) (int, error) { offset := 0 - - tmp := NewLocal() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewLocal() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.OptionalIns = tmp + p.OptionalIns = _field return offset, nil } func (p *MixedStruct) FastReadField21(buf []byte) (int, error) { offset := 0 - - tmp := NewInner() - if l, err := tmp.FastRead(buf[offset:]); err != nil { + _field := NewInner() + if l, err := _field.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.AnotherInner = tmp + p.AnotherInner = _field return offset, nil } @@ -3492,7 +3484,8 @@ func (p *MixedStruct) FastReadField27(buf []byte) (int, error) { if err != nil { return offset, err } - p.EnumKeyMap = make(map[HTTPStatus]*Local, size) + _field := make(map[HTTPStatus]*Local, size) + values := make([]Local, size) for i := 0; i < size; i++ { var _key HTTPStatus if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { @@ -3503,20 +3496,23 @@ func (p *MixedStruct) FastReadField27(buf []byte) (int, error) { _key = HTTPStatus(v) } - _val := NewLocal() + + _val := &values[i] + _val.InitDefault() if l, err := _val.FastRead(buf[offset:]); err != nil { return offset, err } else { offset += l } - p.EnumKeyMap[_key] = _val + _field[_key] = _val } if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { return offset, err } else { offset += l } + p.EnumKeyMap = _field return offset, nil } @@ -3577,7 +3573,6 @@ func (p *MixedStruct) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWri offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Left", thrift.I32, 1) offset += bthrift.Binary.WriteI32(buf[offset:], p.Left) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -3586,7 +3581,6 @@ func (p *MixedStruct) fastWriteField3(buf []byte, binaryWriter bthrift.BinaryWri offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Dummy", thrift.STRING, 3) offset += bthrift.Binary.WriteBinaryNocopy(buf[offset:], binaryWriter, []byte(p.Dummy)) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -3595,7 +3589,6 @@ func (p *MixedStruct) fastWriteField6(buf []byte, binaryWriter bthrift.BinaryWri offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Str", thrift.STRING, 6) offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Str) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -3609,7 +3602,6 @@ func (p *MixedStruct) fastWriteField7(buf []byte, binaryWriter bthrift.BinaryWri for _, v := range p.EnumList { length++ offset += bthrift.Binary.WriteI32(buf[offset:], int32(v)) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -3621,7 +3613,6 @@ func (p *MixedStruct) fastWriteField9(buf []byte, binaryWriter bthrift.BinaryWri offset := 0 offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Int64", thrift.I64, 9) offset += bthrift.Binary.WriteI64(buf[offset:], p.Int64) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) return offset } @@ -3636,7 +3627,6 @@ func (p *MixedStruct) fastWriteField10(buf []byte, binaryWriter bthrift.BinaryWr for _, v := range p.IntList { length++ offset += bthrift.Binary.WriteI32(buf[offset:], v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -3669,9 +3659,7 @@ func (p *MixedStruct) fastWriteField12(buf []byte, binaryWriter bthrift.BinaryWr var length int for k, v := range p.StrLocalMap { length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRUCT, length) @@ -3694,7 +3682,6 @@ func (p *MixedStruct) fastWriteField13(buf []byte, binaryWriter bthrift.BinaryWr for _, v := range v { length++ offset += bthrift.Binary.WriteI32(buf[offset:], v) - } bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) offset += bthrift.Binary.WriteListEnd(buf[offset:]) @@ -3739,9 +3726,7 @@ func (p *MixedStruct) fastWriteField27(buf []byte, binaryWriter bthrift.BinaryWr var length int for k, v := range p.EnumKeyMap { length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(k)) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) } bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.STRUCT, length) @@ -3754,7 +3739,6 @@ func (p *MixedStruct) field1Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Left", thrift.I32, 1) l += bthrift.Binary.I32Length(p.Left) - l += bthrift.Binary.FieldEndLength() return l } @@ -3763,7 +3747,6 @@ func (p *MixedStruct) field3Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Dummy", thrift.STRING, 3) l += bthrift.Binary.BinaryLengthNocopy([]byte(p.Dummy)) - l += bthrift.Binary.FieldEndLength() return l } @@ -3772,7 +3755,6 @@ func (p *MixedStruct) field6Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Str", thrift.STRING, 6) l += bthrift.Binary.StringLengthNocopy(p.Str) - l += bthrift.Binary.FieldEndLength() return l } @@ -3783,7 +3765,6 @@ func (p *MixedStruct) field7Length() int { l += bthrift.Binary.ListBeginLength(thrift.I32, len(p.EnumList)) for _, v := range p.EnumList { l += bthrift.Binary.I32Length(int32(v)) - } l += bthrift.Binary.ListEndLength() l += bthrift.Binary.FieldEndLength() @@ -3794,7 +3775,6 @@ func (p *MixedStruct) field9Length() int { l := 0 l += bthrift.Binary.FieldBeginLength("Int64", thrift.I64, 9) l += bthrift.Binary.I64Length(p.Int64) - l += bthrift.Binary.FieldEndLength() return l } @@ -3831,7 +3811,6 @@ func (p *MixedStruct) field12Length() int { for k, v := range p.StrLocalMap { l += bthrift.Binary.StringLengthNocopy(k) - l += v.BLength() } l += bthrift.Binary.MapEndLength() @@ -3887,7 +3866,6 @@ func (p *MixedStruct) field27Length() int { for k, v := range p.EnumKeyMap { l += bthrift.Binary.I32Length(int32(k)) - l += v.BLength() } l += bthrift.Binary.MapEndLength() @@ -3938,15 +3916,15 @@ func (p *EmptyStruct) FastRead(buf []byte) (int, error) { return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } // for compatibility diff --git a/pkg/protocol/bthrift/test/kitex_gen/test/test.go b/pkg/protocol/bthrift/test/kitex_gen/test/test.go index 1aa7ca7c15..25355a9b17 100644 --- a/pkg/protocol/bthrift/test/kitex_gen/test/test.go +++ b/pkg/protocol/bthrift/test/kitex_gen/test/test.go @@ -1,4 +1,4 @@ -// Code generated by thriftgo (0.3.0). DO NOT EDIT. +// Code generated by thriftgo (0.3.13). DO NOT EDIT. package test @@ -7,7 +7,6 @@ import ( "database/sql" "database/sql/driver" "fmt" - "github.com/apache/thrift/lib/go/thrift" "github.com/cloudwego/thriftgo/generator/golang/extension/unknown" "strings" ) @@ -114,10 +113,7 @@ func NewInner() *Inner { } func (p *Inner) InitDefault() { - *p = Inner{ - - Num: 5, - } + p.Num = 5 } var Inner_Num_DEFAULT int32 = 5 @@ -196,15 +192,6 @@ func (p *Inner) CarryingUnknownFields() bool { return len(p._unknownFields) > 0 } -var fieldIDToName_Inner = map[int16]string{ - 1: "Num", - 2: "desc", - 3: "MapOfList", - 4: "MapOfEnumKey", - 5: "Byte1", - 6: "Double1", -} - func (p *Inner) IsSetNum() bool { return p.Num != Inner_Num_DEFAULT } @@ -229,425 +216,6 @@ func (p *Inner) IsSetDouble1() bool { return p.Double1 != nil } -func (p *Inner) Read(iprot thrift.TProtocol) (err error) { - var name string - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - name, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 2: - if fieldTypeId == thrift.STRING { - if err = p.ReadField2(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.MAP { - if err = p.ReadField3(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 4: - if fieldTypeId == thrift.MAP { - if err = p.ReadField4(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 5: - if fieldTypeId == thrift.BYTE { - if err = p.ReadField5(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 6: - if fieldTypeId == thrift.DOUBLE { - if err = p.ReadField6(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = p._unknownFields.Append(iprot, name, fieldTypeId, fieldId); err != nil { - goto UnknownFieldsAppendError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Inner[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -UnknownFieldsAppendError: - return thrift.PrependError(fmt.Sprintf("%T append unknown field(name:%s type:%d id:%d) error: ", p, name, fieldTypeId, fieldId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *Inner) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - p.Num = v - } - return nil -} - -func (p *Inner) ReadField2(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Desc = &v - } - return nil -} - -func (p *Inner) ReadField3(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.MapOfList = make(map[int64][]int64, size) - for i := 0; i < size; i++ { - var _key int64 - if v, err := iprot.ReadI64(); err != nil { - return err - } else { - _key = v - } - - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - _val := make([]int64, 0, size) - for i := 0; i < size; i++ { - var _elem int64 - if v, err := iprot.ReadI64(); err != nil { - return err - } else { - _elem = v - } - - _val = append(_val, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - - p.MapOfList[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *Inner) ReadField4(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.MapOfEnumKey = make(map[AEnum]int64, size) - for i := 0; i < size; i++ { - var _key AEnum - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _key = AEnum(v) - } - - var _val int64 - if v, err := iprot.ReadI64(); err != nil { - return err - } else { - _val = v - } - - p.MapOfEnumKey[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *Inner) ReadField5(iprot thrift.TProtocol) error { - if v, err := iprot.ReadByte(); err != nil { - return err - } else { - p.Byte1 = &v - } - return nil -} - -func (p *Inner) ReadField6(iprot thrift.TProtocol) error { - if v, err := iprot.ReadDouble(); err != nil { - return err - } else { - p.Double1 = &v - } - return nil -} - -func (p *Inner) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Inner"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField2(oprot); err != nil { - fieldId = 2 - goto WriteFieldError - } - if err = p.writeField3(oprot); err != nil { - fieldId = 3 - goto WriteFieldError - } - if err = p.writeField4(oprot); err != nil { - fieldId = 4 - goto WriteFieldError - } - if err = p.writeField5(oprot); err != nil { - fieldId = 5 - goto WriteFieldError - } - if err = p.writeField6(oprot); err != nil { - fieldId = 6 - goto WriteFieldError - } - - if err = p._unknownFields.Write(oprot); err != nil { - goto UnknownFieldsWriteError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -UnknownFieldsWriteError: - return thrift.PrependError(fmt.Sprintf("%T write unknown fields error: ", p), err) -} - -func (p *Inner) writeField1(oprot thrift.TProtocol) (err error) { - if p.IsSetNum() { - if err = oprot.WriteFieldBegin("Num", thrift.I32, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(p.Num); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *Inner) writeField2(oprot thrift.TProtocol) (err error) { - if p.IsSetDesc() { - if err = oprot.WriteFieldBegin("desc", thrift.STRING, 2); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(*p.Desc); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) -} - -func (p *Inner) writeField3(oprot thrift.TProtocol) (err error) { - if p.IsSetMapOfList() { - if err = oprot.WriteFieldBegin("MapOfList", thrift.MAP, 3); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.I64, thrift.LIST, len(p.MapOfList)); err != nil { - return err - } - for k, v := range p.MapOfList { - - if err := oprot.WriteI64(k); err != nil { - return err - } - - if err := oprot.WriteListBegin(thrift.I64, len(v)); err != nil { - return err - } - for _, v := range v { - if err := oprot.WriteI64(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) -} - -func (p *Inner) writeField4(oprot thrift.TProtocol) (err error) { - if p.IsSetMapOfEnumKey() { - if err = oprot.WriteFieldBegin("MapOfEnumKey", thrift.MAP, 4); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.I32, thrift.I64, len(p.MapOfEnumKey)); err != nil { - return err - } - for k, v := range p.MapOfEnumKey { - - if err := oprot.WriteI32(int32(k)); err != nil { - return err - } - - if err := oprot.WriteI64(v); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) -} - -func (p *Inner) writeField5(oprot thrift.TProtocol) (err error) { - if p.IsSetByte1() { - if err = oprot.WriteFieldBegin("Byte1", thrift.BYTE, 5); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteByte(*p.Byte1); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) -} - -func (p *Inner) writeField6(oprot thrift.TProtocol) (err error) { - if p.IsSetDouble1() { - if err = oprot.WriteFieldBegin("Double1", thrift.DOUBLE, 6); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteDouble(*p.Double1); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) -} - func (p *Inner) String() string { if p == nil { return "" @@ -758,6 +326,15 @@ func (p *Inner) Field6DeepEqual(src *float64) bool { return true } +var fieldIDToName_Inner = map[int16]string{ + 1: "Num", + 2: "desc", + 3: "MapOfList", + 4: "MapOfEnumKey", + 5: "Byte1", + 6: "Double1", +} + type Local struct { L int32 `thrift:"l,1" frugal:"1,default,i32" json:"l"` _unknownFields unknown.Fields @@ -768,7 +345,6 @@ func NewLocal() *Local { } func (p *Local) InitDefault() { - *p = Local{} } func (p *Local) GetL() (v int32) { @@ -782,131 +358,6 @@ func (p *Local) CarryingUnknownFields() bool { return len(p._unknownFields) > 0 } -var fieldIDToName_Local = map[int16]string{ - 1: "l", -} - -func (p *Local) Read(iprot thrift.TProtocol) (err error) { - var name string - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - name, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = p._unknownFields.Append(iprot, name, fieldTypeId, fieldId); err != nil { - goto UnknownFieldsAppendError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Local[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -UnknownFieldsAppendError: - return thrift.PrependError(fmt.Sprintf("%T append unknown field(name:%s type:%d id:%d) error: ", p, name, fieldTypeId, fieldId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *Local) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - p.L = v - } - return nil -} - -func (p *Local) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Local"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - - if err = p._unknownFields.Write(oprot); err != nil { - goto UnknownFieldsWriteError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -UnknownFieldsWriteError: - return thrift.PrependError(fmt.Sprintf("%T write unknown fields error: ", p), err) -} - -func (p *Local) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("l", thrift.I32, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(p.L); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - func (p *Local) String() string { if p == nil { return "" @@ -934,6 +385,10 @@ func (p *Local) Field1DeepEqual(src int32) bool { return true } +var fieldIDToName_Local = map[int16]string{ + 1: "l", +} + type FullStruct struct { Left int32 `thrift:"Left,1,required" frugal:"1,required,i32" json:"Left"` Right int32 `thrift:"Right,2,optional" frugal:"2,optional,i32" json:"Right,omitempty"` @@ -976,10 +431,7 @@ func NewFullStruct() *FullStruct { } func (p *FullStruct) InitDefault() { - *p = FullStruct{ - - Right: 3, - } + p.Right = 3 } func (p *FullStruct) GetLeft() (v int32) { @@ -1246,39 +698,6 @@ func (p *FullStruct) CarryingUnknownFields() bool { return len(p._unknownFields) > 0 } -var fieldIDToName_FullStruct = map[int16]string{ - 1: "Left", - 2: "Right", - 3: "Dummy", - 4: "InnerReq", - 5: "status", - 6: "Str", - 7: "enum_list", - 8: "Strmap", - 9: "Int64", - 10: "IntList", - 11: "localList", - 12: "StrLocalMap", - 13: "nestList", - 14: "required_ins", - 16: "nestMap", - 17: "nestMap2", - 18: "enum_map", - 19: "Strlist", - 20: "optional_ins", - 21: "AnotherInner", - 22: "opt_nil_list", - 23: "nil_list", - 24: "opt_nil_ins_list", - 25: "nil_ins_list", - 26: "opt_status", - 27: "enum_key_map", - 28: "complex", - 29: "i64Set", - 30: "Int16", - 31: "isSet", -} - func (p *FullStruct) IsSetRight() bool { return p.Right != FullStruct_Right_DEFAULT } @@ -1319,1869 +738,12 @@ func (p *FullStruct) IsSetOptStatus() bool { return p.OptStatus != nil } -func (p *FullStruct) Read(iprot thrift.TProtocol) (err error) { - var name string - var fieldTypeId thrift.TType - var fieldId int16 - var issetLeft bool = false - var issetRequiredIns bool = false - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError +func (p *FullStruct) String() string { + if p == nil { + return "" } - - for { - name, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - issetLeft = true - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 2: - if fieldTypeId == thrift.I32 { - if err = p.ReadField2(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.STRING { - if err = p.ReadField3(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 4: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField4(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 5: - if fieldTypeId == thrift.I32 { - if err = p.ReadField5(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 6: - if fieldTypeId == thrift.STRING { - if err = p.ReadField6(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 7: - if fieldTypeId == thrift.LIST { - if err = p.ReadField7(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 8: - if fieldTypeId == thrift.MAP { - if err = p.ReadField8(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 9: - if fieldTypeId == thrift.I64 { - if err = p.ReadField9(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 10: - if fieldTypeId == thrift.LIST { - if err = p.ReadField10(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 11: - if fieldTypeId == thrift.LIST { - if err = p.ReadField11(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 12: - if fieldTypeId == thrift.MAP { - if err = p.ReadField12(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 13: - if fieldTypeId == thrift.LIST { - if err = p.ReadField13(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 14: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField14(iprot); err != nil { - goto ReadFieldError - } - issetRequiredIns = true - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 16: - if fieldTypeId == thrift.MAP { - if err = p.ReadField16(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 17: - if fieldTypeId == thrift.LIST { - if err = p.ReadField17(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 18: - if fieldTypeId == thrift.MAP { - if err = p.ReadField18(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 19: - if fieldTypeId == thrift.LIST { - if err = p.ReadField19(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 20: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField20(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 21: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField21(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 22: - if fieldTypeId == thrift.LIST { - if err = p.ReadField22(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 23: - if fieldTypeId == thrift.LIST { - if err = p.ReadField23(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 24: - if fieldTypeId == thrift.LIST { - if err = p.ReadField24(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 25: - if fieldTypeId == thrift.LIST { - if err = p.ReadField25(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 26: - if fieldTypeId == thrift.I32 { - if err = p.ReadField26(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 27: - if fieldTypeId == thrift.MAP { - if err = p.ReadField27(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 28: - if fieldTypeId == thrift.MAP { - if err = p.ReadField28(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 29: - if fieldTypeId == thrift.SET { - if err = p.ReadField29(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 30: - if fieldTypeId == thrift.I16 { - if err = p.ReadField30(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 31: - if fieldTypeId == thrift.BOOL { - if err = p.ReadField31(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = p._unknownFields.Append(iprot, name, fieldTypeId, fieldId); err != nil { - goto UnknownFieldsAppendError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - if !issetLeft { - fieldId = 1 - goto RequiredFieldNotSetError - } - - if !issetRequiredIns { - fieldId = 14 - goto RequiredFieldNotSetError - } - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FullStruct[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -UnknownFieldsAppendError: - return thrift.PrependError(fmt.Sprintf("%T append unknown field(name:%s type:%d id:%d) error: ", p, name, fieldTypeId, fieldId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -RequiredFieldNotSetError: - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_FullStruct[fieldId])) -} - -func (p *FullStruct) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - p.Left = v - } - return nil -} - -func (p *FullStruct) ReadField2(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - p.Right = v - } - return nil -} - -func (p *FullStruct) ReadField3(iprot thrift.TProtocol) error { - if v, err := iprot.ReadBinary(); err != nil { - return err - } else { - p.Dummy = []byte(v) - } - return nil -} - -func (p *FullStruct) ReadField4(iprot thrift.TProtocol) error { - p.InnerReq = NewInner() - if err := p.InnerReq.Read(iprot); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField5(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - p.Status = HTTPStatus(v) - } - return nil -} - -func (p *FullStruct) ReadField6(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Str = v - } - return nil -} - -func (p *FullStruct) ReadField7(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.EnumList = make([]HTTPStatus, 0, size) - for i := 0; i < size; i++ { - var _elem HTTPStatus - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _elem = HTTPStatus(v) - } - - p.EnumList = append(p.EnumList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField8(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.Strmap = make(map[int32]string, size) - for i := 0; i < size; i++ { - var _key int32 - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _key = v - } - - var _val string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _val = v - } - - p.Strmap[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField9(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI64(); err != nil { - return err - } else { - p.Int64 = v - } - return nil -} - -func (p *FullStruct) ReadField10(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.IntList = make([]int32, 0, size) - for i := 0; i < size; i++ { - var _elem int32 - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _elem = v - } - - p.IntList = append(p.IntList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField11(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.LocalList = make([]*Local, 0, size) - for i := 0; i < size; i++ { - _elem := NewLocal() - if err := _elem.Read(iprot); err != nil { - return err - } - - p.LocalList = append(p.LocalList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField12(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.StrLocalMap = make(map[string]*Local, size) - for i := 0; i < size; i++ { - var _key string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } - _val := NewLocal() - if err := _val.Read(iprot); err != nil { - return err - } - - p.StrLocalMap[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField13(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.NestList = make([][]int32, 0, size) - for i := 0; i < size; i++ { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - _elem := make([]int32, 0, size) - for i := 0; i < size; i++ { - var _elem1 int32 - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _elem1 = v - } - - _elem = append(_elem, _elem1) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - - p.NestList = append(p.NestList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField14(iprot thrift.TProtocol) error { - p.RequiredIns = NewLocal() - if err := p.RequiredIns.Read(iprot); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField16(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.NestMap = make(map[string][]string, size) - for i := 0; i < size; i++ { - var _key string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } - - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - _val := make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _elem = v - } - - _val = append(_val, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - - p.NestMap[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField17(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.NestMap2 = make([]map[string]HTTPStatus, 0, size) - for i := 0; i < size; i++ { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - _elem := make(map[string]HTTPStatus, size) - for i := 0; i < size; i++ { - var _key string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } - - var _val HTTPStatus - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _val = HTTPStatus(v) - } - - _elem[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - - p.NestMap2 = append(p.NestMap2, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField18(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.EnumMap = make(map[int32]HTTPStatus, size) - for i := 0; i < size; i++ { - var _key int32 - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _key = v - } - - var _val HTTPStatus - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _val = HTTPStatus(v) - } - - p.EnumMap[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField19(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.Strlist = make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _elem = v - } - - p.Strlist = append(p.Strlist, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField20(iprot thrift.TProtocol) error { - p.OptionalIns = NewLocal() - if err := p.OptionalIns.Read(iprot); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField21(iprot thrift.TProtocol) error { - p.AnotherInner = NewInner() - if err := p.AnotherInner.Read(iprot); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField22(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.OptNilList = make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _elem = v - } - - p.OptNilList = append(p.OptNilList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField23(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.NilList = make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _elem = v - } - - p.NilList = append(p.NilList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField24(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.OptNilInsList = make([]*Inner, 0, size) - for i := 0; i < size; i++ { - _elem := NewInner() - if err := _elem.Read(iprot); err != nil { - return err - } - - p.OptNilInsList = append(p.OptNilInsList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField25(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.NilInsList = make([]*Inner, 0, size) - for i := 0; i < size; i++ { - _elem := NewInner() - if err := _elem.Read(iprot); err != nil { - return err - } - - p.NilInsList = append(p.NilInsList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField26(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - tmp := HTTPStatus(v) - p.OptStatus = &tmp - } - return nil -} - -func (p *FullStruct) ReadField27(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.EnumKeyMap = make(map[HTTPStatus]*Local, size) - for i := 0; i < size; i++ { - var _key HTTPStatus - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _key = HTTPStatus(v) - } - _val := NewLocal() - if err := _val.Read(iprot); err != nil { - return err - } - - p.EnumKeyMap[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField28(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.Complex = make(map[HTTPStatus][]map[string]*Local, size) - for i := 0; i < size; i++ { - var _key HTTPStatus - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _key = HTTPStatus(v) - } - - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - _val := make([]map[string]*Local, 0, size) - for i := 0; i < size; i++ { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - _elem := make(map[string]*Local, size) - for i := 0; i < size; i++ { - var _key1 string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key1 = v - } - _val1 := NewLocal() - if err := _val1.Read(iprot); err != nil { - return err - } - - _elem[_key1] = _val1 - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - - _val = append(_val, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - - p.Complex[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField29(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadSetBegin() - if err != nil { - return err - } - p.I64Set = make([]int64, 0, size) - for i := 0; i < size; i++ { - var _elem int64 - if v, err := iprot.ReadI64(); err != nil { - return err - } else { - _elem = v - } - - p.I64Set = append(p.I64Set, _elem) - } - if err := iprot.ReadSetEnd(); err != nil { - return err - } - return nil -} - -func (p *FullStruct) ReadField30(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI16(); err != nil { - return err - } else { - p.Int16 = v - } - return nil -} - -func (p *FullStruct) ReadField31(iprot thrift.TProtocol) error { - if v, err := iprot.ReadBool(); err != nil { - return err - } else { - p.IsSet = v - } - return nil -} - -func (p *FullStruct) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("FullStruct"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField2(oprot); err != nil { - fieldId = 2 - goto WriteFieldError - } - if err = p.writeField3(oprot); err != nil { - fieldId = 3 - goto WriteFieldError - } - if err = p.writeField4(oprot); err != nil { - fieldId = 4 - goto WriteFieldError - } - if err = p.writeField5(oprot); err != nil { - fieldId = 5 - goto WriteFieldError - } - if err = p.writeField6(oprot); err != nil { - fieldId = 6 - goto WriteFieldError - } - if err = p.writeField7(oprot); err != nil { - fieldId = 7 - goto WriteFieldError - } - if err = p.writeField8(oprot); err != nil { - fieldId = 8 - goto WriteFieldError - } - if err = p.writeField9(oprot); err != nil { - fieldId = 9 - goto WriteFieldError - } - if err = p.writeField10(oprot); err != nil { - fieldId = 10 - goto WriteFieldError - } - if err = p.writeField11(oprot); err != nil { - fieldId = 11 - goto WriteFieldError - } - if err = p.writeField12(oprot); err != nil { - fieldId = 12 - goto WriteFieldError - } - if err = p.writeField13(oprot); err != nil { - fieldId = 13 - goto WriteFieldError - } - if err = p.writeField14(oprot); err != nil { - fieldId = 14 - goto WriteFieldError - } - if err = p.writeField16(oprot); err != nil { - fieldId = 16 - goto WriteFieldError - } - if err = p.writeField17(oprot); err != nil { - fieldId = 17 - goto WriteFieldError - } - if err = p.writeField18(oprot); err != nil { - fieldId = 18 - goto WriteFieldError - } - if err = p.writeField19(oprot); err != nil { - fieldId = 19 - goto WriteFieldError - } - if err = p.writeField20(oprot); err != nil { - fieldId = 20 - goto WriteFieldError - } - if err = p.writeField21(oprot); err != nil { - fieldId = 21 - goto WriteFieldError - } - if err = p.writeField22(oprot); err != nil { - fieldId = 22 - goto WriteFieldError - } - if err = p.writeField23(oprot); err != nil { - fieldId = 23 - goto WriteFieldError - } - if err = p.writeField24(oprot); err != nil { - fieldId = 24 - goto WriteFieldError - } - if err = p.writeField25(oprot); err != nil { - fieldId = 25 - goto WriteFieldError - } - if err = p.writeField26(oprot); err != nil { - fieldId = 26 - goto WriteFieldError - } - if err = p.writeField27(oprot); err != nil { - fieldId = 27 - goto WriteFieldError - } - if err = p.writeField28(oprot); err != nil { - fieldId = 28 - goto WriteFieldError - } - if err = p.writeField29(oprot); err != nil { - fieldId = 29 - goto WriteFieldError - } - if err = p.writeField30(oprot); err != nil { - fieldId = 30 - goto WriteFieldError - } - if err = p.writeField31(oprot); err != nil { - fieldId = 31 - goto WriteFieldError - } - - if err = p._unknownFields.Write(oprot); err != nil { - goto UnknownFieldsWriteError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -UnknownFieldsWriteError: - return thrift.PrependError(fmt.Sprintf("%T write unknown fields error: ", p), err) -} - -func (p *FullStruct) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Left", thrift.I32, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(p.Left); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *FullStruct) writeField2(oprot thrift.TProtocol) (err error) { - if p.IsSetRight() { - if err = oprot.WriteFieldBegin("Right", thrift.I32, 2); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(p.Right); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) -} - -func (p *FullStruct) writeField3(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Dummy", thrift.STRING, 3); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteBinary([]byte(p.Dummy)); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) -} - -func (p *FullStruct) writeField4(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("InnerReq", thrift.STRUCT, 4); err != nil { - goto WriteFieldBeginError - } - if err := p.InnerReq.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) -} - -func (p *FullStruct) writeField5(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("status", thrift.I32, 5); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(int32(p.Status)); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) -} - -func (p *FullStruct) writeField6(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Str", thrift.STRING, 6); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Str); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) -} - -func (p *FullStruct) writeField7(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("enum_list", thrift.LIST, 7); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.I32, len(p.EnumList)); err != nil { - return err - } - for _, v := range p.EnumList { - if err := oprot.WriteI32(int32(v)); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 7 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) -} - -func (p *FullStruct) writeField8(oprot thrift.TProtocol) (err error) { - if p.IsSetStrmap() { - if err = oprot.WriteFieldBegin("Strmap", thrift.MAP, 8); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.I32, thrift.STRING, len(p.Strmap)); err != nil { - return err - } - for k, v := range p.Strmap { - - if err := oprot.WriteI32(k); err != nil { - return err - } - - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 8 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) -} - -func (p *FullStruct) writeField9(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Int64", thrift.I64, 9); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI64(p.Int64); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 9 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 9 end error: ", p), err) -} - -func (p *FullStruct) writeField10(oprot thrift.TProtocol) (err error) { - if p.IsSetIntList() { - if err = oprot.WriteFieldBegin("IntList", thrift.LIST, 10); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.I32, len(p.IntList)); err != nil { - return err - } - for _, v := range p.IntList { - if err := oprot.WriteI32(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 10 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 10 end error: ", p), err) -} - -func (p *FullStruct) writeField11(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("localList", thrift.LIST, 11); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRUCT, len(p.LocalList)); err != nil { - return err - } - for _, v := range p.LocalList { - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 11 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 11 end error: ", p), err) -} - -func (p *FullStruct) writeField12(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("StrLocalMap", thrift.MAP, 12); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRUCT, len(p.StrLocalMap)); err != nil { - return err - } - for k, v := range p.StrLocalMap { - - if err := oprot.WriteString(k); err != nil { - return err - } - - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 12 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 12 end error: ", p), err) -} - -func (p *FullStruct) writeField13(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("nestList", thrift.LIST, 13); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.LIST, len(p.NestList)); err != nil { - return err - } - for _, v := range p.NestList { - if err := oprot.WriteListBegin(thrift.I32, len(v)); err != nil { - return err - } - for _, v := range v { - if err := oprot.WriteI32(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 13 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err) -} - -func (p *FullStruct) writeField14(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("required_ins", thrift.STRUCT, 14); err != nil { - goto WriteFieldBeginError - } - if err := p.RequiredIns.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err) -} - -func (p *FullStruct) writeField16(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("nestMap", thrift.MAP, 16); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.STRING, thrift.LIST, len(p.NestMap)); err != nil { - return err - } - for k, v := range p.NestMap { - - if err := oprot.WriteString(k); err != nil { - return err - } - - if err := oprot.WriteListBegin(thrift.STRING, len(v)); err != nil { - return err - } - for _, v := range v { - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 16 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 16 end error: ", p), err) -} - -func (p *FullStruct) writeField17(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("nestMap2", thrift.LIST, 17); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.MAP, len(p.NestMap2)); err != nil { - return err - } - for _, v := range p.NestMap2 { - if err := oprot.WriteMapBegin(thrift.STRING, thrift.I32, len(v)); err != nil { - return err - } - for k, v := range v { - - if err := oprot.WriteString(k); err != nil { - return err - } - - if err := oprot.WriteI32(int32(v)); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 17 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 17 end error: ", p), err) -} - -func (p *FullStruct) writeField18(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("enum_map", thrift.MAP, 18); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.I32, thrift.I32, len(p.EnumMap)); err != nil { - return err - } - for k, v := range p.EnumMap { - - if err := oprot.WriteI32(k); err != nil { - return err - } - - if err := oprot.WriteI32(int32(v)); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 18 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 18 end error: ", p), err) -} - -func (p *FullStruct) writeField19(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Strlist", thrift.LIST, 19); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRING, len(p.Strlist)); err != nil { - return err - } - for _, v := range p.Strlist { - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 19 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 19 end error: ", p), err) -} - -func (p *FullStruct) writeField20(oprot thrift.TProtocol) (err error) { - if p.IsSetOptionalIns() { - if err = oprot.WriteFieldBegin("optional_ins", thrift.STRUCT, 20); err != nil { - goto WriteFieldBeginError - } - if err := p.OptionalIns.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 20 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 20 end error: ", p), err) -} - -func (p *FullStruct) writeField21(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("AnotherInner", thrift.STRUCT, 21); err != nil { - goto WriteFieldBeginError - } - if err := p.AnotherInner.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 21 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 21 end error: ", p), err) -} - -func (p *FullStruct) writeField22(oprot thrift.TProtocol) (err error) { - if p.IsSetOptNilList() { - if err = oprot.WriteFieldBegin("opt_nil_list", thrift.LIST, 22); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRING, len(p.OptNilList)); err != nil { - return err - } - for _, v := range p.OptNilList { - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 22 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 22 end error: ", p), err) -} - -func (p *FullStruct) writeField23(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("nil_list", thrift.LIST, 23); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRING, len(p.NilList)); err != nil { - return err - } - for _, v := range p.NilList { - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 23 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 23 end error: ", p), err) -} - -func (p *FullStruct) writeField24(oprot thrift.TProtocol) (err error) { - if p.IsSetOptNilInsList() { - if err = oprot.WriteFieldBegin("opt_nil_ins_list", thrift.LIST, 24); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRUCT, len(p.OptNilInsList)); err != nil { - return err - } - for _, v := range p.OptNilInsList { - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 24 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 24 end error: ", p), err) -} - -func (p *FullStruct) writeField25(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("nil_ins_list", thrift.LIST, 25); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRUCT, len(p.NilInsList)); err != nil { - return err - } - for _, v := range p.NilInsList { - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 25 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 25 end error: ", p), err) -} - -func (p *FullStruct) writeField26(oprot thrift.TProtocol) (err error) { - if p.IsSetOptStatus() { - if err = oprot.WriteFieldBegin("opt_status", thrift.I32, 26); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(int32(*p.OptStatus)); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 26 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 26 end error: ", p), err) -} - -func (p *FullStruct) writeField27(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("enum_key_map", thrift.MAP, 27); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.I32, thrift.STRUCT, len(p.EnumKeyMap)); err != nil { - return err - } - for k, v := range p.EnumKeyMap { - - if err := oprot.WriteI32(int32(k)); err != nil { - return err - } - - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 27 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 27 end error: ", p), err) -} - -func (p *FullStruct) writeField28(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("complex", thrift.MAP, 28); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.I32, thrift.LIST, len(p.Complex)); err != nil { - return err - } - for k, v := range p.Complex { - - if err := oprot.WriteI32(int32(k)); err != nil { - return err - } - - if err := oprot.WriteListBegin(thrift.MAP, len(v)); err != nil { - return err - } - for _, v := range v { - if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRUCT, len(v)); err != nil { - return err - } - for k, v := range v { - - if err := oprot.WriteString(k); err != nil { - return err - } - - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 28 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 28 end error: ", p), err) -} - -func (p *FullStruct) writeField29(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("i64Set", thrift.SET, 29); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteSetBegin(thrift.I64, len(p.I64Set)); err != nil { - return err - } - for i := 0; i < len(p.I64Set); i++ { - for j := i + 1; j < len(p.I64Set); j++ { - if func(tgt, src int64) bool { - if tgt != src { - return false - } - return true - }(p.I64Set[i], p.I64Set[j]) { - return thrift.PrependError("", fmt.Errorf("%T error writing set field: slice is not unique", p.I64Set[i])) - } - } - } - for _, v := range p.I64Set { - if err := oprot.WriteI64(v); err != nil { - return err - } - } - if err := oprot.WriteSetEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 29 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 29 end error: ", p), err) -} - -func (p *FullStruct) writeField30(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Int16", thrift.I16, 30); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI16(p.Int16); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 30 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 30 end error: ", p), err) -} - -func (p *FullStruct) writeField31(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("isSet", thrift.BOOL, 31); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteBool(p.IsSet); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 31 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 31 end error: ", p), err) -} - -func (p *FullStruct) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("FullStruct(%+v)", *p) -} + return fmt.Sprintf("FullStruct(%+v)", *p) +} func (p *FullStruct) DeepEqual(ano *FullStruct) bool { if p == ano { @@ -3630,6 +1192,39 @@ func (p *FullStruct) Field31DeepEqual(src bool) bool { return true } +var fieldIDToName_FullStruct = map[int16]string{ + 1: "Left", + 2: "Right", + 3: "Dummy", + 4: "InnerReq", + 5: "status", + 6: "Str", + 7: "enum_list", + 8: "Strmap", + 9: "Int64", + 10: "IntList", + 11: "localList", + 12: "StrLocalMap", + 13: "nestList", + 14: "required_ins", + 16: "nestMap", + 17: "nestMap2", + 18: "enum_map", + 19: "Strlist", + 20: "optional_ins", + 21: "AnotherInner", + 22: "opt_nil_list", + 23: "nil_list", + 24: "opt_nil_ins_list", + 25: "nil_ins_list", + 26: "opt_status", + 27: "enum_key_map", + 28: "complex", + 29: "i64Set", + 30: "Int16", + 31: "isSet", +} + type MixedStruct struct { Left int32 `thrift:"Left,1,required" frugal:"1,required,i32" json:"Left"` Dummy []byte `thrift:"Dummy,3" frugal:"3,default,binary" json:"Dummy"` @@ -3652,933 +1247,137 @@ func NewMixedStruct() *MixedStruct { } func (p *MixedStruct) InitDefault() { - *p = MixedStruct{} -} - -func (p *MixedStruct) GetLeft() (v int32) { - return p.Left -} - -func (p *MixedStruct) GetDummy() (v []byte) { - return p.Dummy -} - -func (p *MixedStruct) GetStr() (v string) { - return p.Str -} - -func (p *MixedStruct) GetEnumList() (v []HTTPStatus) { - return p.EnumList -} - -func (p *MixedStruct) GetInt64() (v int64) { - return p.Int64 -} - -var MixedStruct_IntList_DEFAULT []int32 - -func (p *MixedStruct) GetIntList() (v []int32) { - if !p.IsSetIntList() { - return MixedStruct_IntList_DEFAULT - } - return p.IntList -} - -func (p *MixedStruct) GetLocalList() (v []*Local) { - return p.LocalList -} - -func (p *MixedStruct) GetStrLocalMap() (v map[string]*Local) { - return p.StrLocalMap -} - -func (p *MixedStruct) GetNestList() (v [][]int32) { - return p.NestList -} - -var MixedStruct_RequiredIns_DEFAULT *Local - -func (p *MixedStruct) GetRequiredIns() (v *Local) { - if !p.IsSetRequiredIns() { - return MixedStruct_RequiredIns_DEFAULT - } - return p.RequiredIns -} - -var MixedStruct_OptionalIns_DEFAULT *Local - -func (p *MixedStruct) GetOptionalIns() (v *Local) { - if !p.IsSetOptionalIns() { - return MixedStruct_OptionalIns_DEFAULT - } - return p.OptionalIns -} - -var MixedStruct_AnotherInner_DEFAULT *Inner - -func (p *MixedStruct) GetAnotherInner() (v *Inner) { - if !p.IsSetAnotherInner() { - return MixedStruct_AnotherInner_DEFAULT - } - return p.AnotherInner -} - -func (p *MixedStruct) GetEnumKeyMap() (v map[HTTPStatus]*Local) { - return p.EnumKeyMap -} -func (p *MixedStruct) SetLeft(val int32) { - p.Left = val -} -func (p *MixedStruct) SetDummy(val []byte) { - p.Dummy = val -} -func (p *MixedStruct) SetStr(val string) { - p.Str = val -} -func (p *MixedStruct) SetEnumList(val []HTTPStatus) { - p.EnumList = val -} -func (p *MixedStruct) SetInt64(val int64) { - p.Int64 = val -} -func (p *MixedStruct) SetIntList(val []int32) { - p.IntList = val -} -func (p *MixedStruct) SetLocalList(val []*Local) { - p.LocalList = val -} -func (p *MixedStruct) SetStrLocalMap(val map[string]*Local) { - p.StrLocalMap = val -} -func (p *MixedStruct) SetNestList(val [][]int32) { - p.NestList = val -} -func (p *MixedStruct) SetRequiredIns(val *Local) { - p.RequiredIns = val -} -func (p *MixedStruct) SetOptionalIns(val *Local) { - p.OptionalIns = val -} -func (p *MixedStruct) SetAnotherInner(val *Inner) { - p.AnotherInner = val -} -func (p *MixedStruct) SetEnumKeyMap(val map[HTTPStatus]*Local) { - p.EnumKeyMap = val -} - -func (p *MixedStruct) CarryingUnknownFields() bool { - return len(p._unknownFields) > 0 -} - -var fieldIDToName_MixedStruct = map[int16]string{ - 1: "Left", - 3: "Dummy", - 6: "Str", - 7: "enum_list", - 9: "Int64", - 10: "IntList", - 11: "localList", - 12: "StrLocalMap", - 13: "nestList", - 14: "required_ins", - 20: "optional_ins", - 21: "AnotherInner", - 27: "enum_key_map", -} - -func (p *MixedStruct) IsSetIntList() bool { - return p.IntList != nil -} - -func (p *MixedStruct) IsSetRequiredIns() bool { - return p.RequiredIns != nil -} - -func (p *MixedStruct) IsSetOptionalIns() bool { - return p.OptionalIns != nil -} - -func (p *MixedStruct) IsSetAnotherInner() bool { - return p.AnotherInner != nil } -func (p *MixedStruct) Read(iprot thrift.TProtocol) (err error) { - var name string - var fieldTypeId thrift.TType - var fieldId int16 - var issetLeft bool = false - var issetRequiredIns bool = false - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - name, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - issetLeft = true - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.STRING { - if err = p.ReadField3(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 6: - if fieldTypeId == thrift.STRING { - if err = p.ReadField6(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 7: - if fieldTypeId == thrift.LIST { - if err = p.ReadField7(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 9: - if fieldTypeId == thrift.I64 { - if err = p.ReadField9(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 10: - if fieldTypeId == thrift.LIST { - if err = p.ReadField10(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 11: - if fieldTypeId == thrift.LIST { - if err = p.ReadField11(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 12: - if fieldTypeId == thrift.MAP { - if err = p.ReadField12(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 13: - if fieldTypeId == thrift.LIST { - if err = p.ReadField13(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 14: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField14(iprot); err != nil { - goto ReadFieldError - } - issetRequiredIns = true - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 20: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField20(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 21: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField21(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 27: - if fieldTypeId == thrift.MAP { - if err = p.ReadField27(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = p._unknownFields.Append(iprot, name, fieldTypeId, fieldId); err != nil { - goto UnknownFieldsAppendError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - if !issetLeft { - fieldId = 1 - goto RequiredFieldNotSetError - } - - if !issetRequiredIns { - fieldId = 14 - goto RequiredFieldNotSetError - } - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MixedStruct[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -UnknownFieldsAppendError: - return thrift.PrependError(fmt.Sprintf("%T append unknown field(name:%s type:%d id:%d) error: ", p, name, fieldTypeId, fieldId), err) +func (p *MixedStruct) GetLeft() (v int32) { + return p.Left +} -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -RequiredFieldNotSetError: - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_MixedStruct[fieldId])) +func (p *MixedStruct) GetDummy() (v []byte) { + return p.Dummy } -func (p *MixedStruct) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - p.Left = v - } - return nil +func (p *MixedStruct) GetStr() (v string) { + return p.Str } -func (p *MixedStruct) ReadField3(iprot thrift.TProtocol) error { - if v, err := iprot.ReadBinary(); err != nil { - return err - } else { - p.Dummy = []byte(v) - } - return nil +func (p *MixedStruct) GetEnumList() (v []HTTPStatus) { + return p.EnumList } -func (p *MixedStruct) ReadField6(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Str = v - } - return nil +func (p *MixedStruct) GetInt64() (v int64) { + return p.Int64 } -func (p *MixedStruct) ReadField7(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.EnumList = make([]HTTPStatus, 0, size) - for i := 0; i < size; i++ { - var _elem HTTPStatus - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _elem = HTTPStatus(v) - } +var MixedStruct_IntList_DEFAULT []int32 - p.EnumList = append(p.EnumList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err +func (p *MixedStruct) GetIntList() (v []int32) { + if !p.IsSetIntList() { + return MixedStruct_IntList_DEFAULT } - return nil + return p.IntList } -func (p *MixedStruct) ReadField9(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI64(); err != nil { - return err - } else { - p.Int64 = v - } - return nil +func (p *MixedStruct) GetLocalList() (v []*Local) { + return p.LocalList } -func (p *MixedStruct) ReadField10(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.IntList = make([]int32, 0, size) - for i := 0; i < size; i++ { - var _elem int32 - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _elem = v - } +func (p *MixedStruct) GetStrLocalMap() (v map[string]*Local) { + return p.StrLocalMap +} - p.IntList = append(p.IntList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - return nil +func (p *MixedStruct) GetNestList() (v [][]int32) { + return p.NestList } -func (p *MixedStruct) ReadField11(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.LocalList = make([]*Local, 0, size) - for i := 0; i < size; i++ { - _elem := NewLocal() - if err := _elem.Read(iprot); err != nil { - return err - } +var MixedStruct_RequiredIns_DEFAULT *Local - p.LocalList = append(p.LocalList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err +func (p *MixedStruct) GetRequiredIns() (v *Local) { + if !p.IsSetRequiredIns() { + return MixedStruct_RequiredIns_DEFAULT } - return nil + return p.RequiredIns } -func (p *MixedStruct) ReadField12(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.StrLocalMap = make(map[string]*Local, size) - for i := 0; i < size; i++ { - var _key string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } - _val := NewLocal() - if err := _val.Read(iprot); err != nil { - return err - } +var MixedStruct_OptionalIns_DEFAULT *Local - p.StrLocalMap[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err +func (p *MixedStruct) GetOptionalIns() (v *Local) { + if !p.IsSetOptionalIns() { + return MixedStruct_OptionalIns_DEFAULT } - return nil + return p.OptionalIns } -func (p *MixedStruct) ReadField13(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - p.NestList = make([][]int32, 0, size) - for i := 0; i < size; i++ { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - _elem := make([]int32, 0, size) - for i := 0; i < size; i++ { - var _elem1 int32 - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _elem1 = v - } - - _elem = append(_elem, _elem1) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } +var MixedStruct_AnotherInner_DEFAULT *Inner - p.NestList = append(p.NestList, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err +func (p *MixedStruct) GetAnotherInner() (v *Inner) { + if !p.IsSetAnotherInner() { + return MixedStruct_AnotherInner_DEFAULT } - return nil + return p.AnotherInner } -func (p *MixedStruct) ReadField14(iprot thrift.TProtocol) error { - p.RequiredIns = NewLocal() - if err := p.RequiredIns.Read(iprot); err != nil { - return err - } - return nil +func (p *MixedStruct) GetEnumKeyMap() (v map[HTTPStatus]*Local) { + return p.EnumKeyMap } - -func (p *MixedStruct) ReadField20(iprot thrift.TProtocol) error { - p.OptionalIns = NewLocal() - if err := p.OptionalIns.Read(iprot); err != nil { - return err - } - return nil +func (p *MixedStruct) SetLeft(val int32) { + p.Left = val } - -func (p *MixedStruct) ReadField21(iprot thrift.TProtocol) error { - p.AnotherInner = NewInner() - if err := p.AnotherInner.Read(iprot); err != nil { - return err - } - return nil +func (p *MixedStruct) SetDummy(val []byte) { + p.Dummy = val } - -func (p *MixedStruct) ReadField27(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.EnumKeyMap = make(map[HTTPStatus]*Local, size) - for i := 0; i < size; i++ { - var _key HTTPStatus - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _key = HTTPStatus(v) - } - _val := NewLocal() - if err := _val.Read(iprot); err != nil { - return err - } - - p.EnumKeyMap[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil +func (p *MixedStruct) SetStr(val string) { + p.Str = val } - -func (p *MixedStruct) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("MixedStruct"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField3(oprot); err != nil { - fieldId = 3 - goto WriteFieldError - } - if err = p.writeField6(oprot); err != nil { - fieldId = 6 - goto WriteFieldError - } - if err = p.writeField7(oprot); err != nil { - fieldId = 7 - goto WriteFieldError - } - if err = p.writeField9(oprot); err != nil { - fieldId = 9 - goto WriteFieldError - } - if err = p.writeField10(oprot); err != nil { - fieldId = 10 - goto WriteFieldError - } - if err = p.writeField11(oprot); err != nil { - fieldId = 11 - goto WriteFieldError - } - if err = p.writeField12(oprot); err != nil { - fieldId = 12 - goto WriteFieldError - } - if err = p.writeField13(oprot); err != nil { - fieldId = 13 - goto WriteFieldError - } - if err = p.writeField14(oprot); err != nil { - fieldId = 14 - goto WriteFieldError - } - if err = p.writeField20(oprot); err != nil { - fieldId = 20 - goto WriteFieldError - } - if err = p.writeField21(oprot); err != nil { - fieldId = 21 - goto WriteFieldError - } - if err = p.writeField27(oprot); err != nil { - fieldId = 27 - goto WriteFieldError - } - - if err = p._unknownFields.Write(oprot); err != nil { - goto UnknownFieldsWriteError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -UnknownFieldsWriteError: - return thrift.PrependError(fmt.Sprintf("%T write unknown fields error: ", p), err) +func (p *MixedStruct) SetEnumList(val []HTTPStatus) { + p.EnumList = val } - -func (p *MixedStruct) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Left", thrift.I32, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(p.Left); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +func (p *MixedStruct) SetInt64(val int64) { + p.Int64 = val } - -func (p *MixedStruct) writeField3(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Dummy", thrift.STRING, 3); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteBinary([]byte(p.Dummy)); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +func (p *MixedStruct) SetIntList(val []int32) { + p.IntList = val } - -func (p *MixedStruct) writeField6(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Str", thrift.STRING, 6); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Str); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) +func (p *MixedStruct) SetLocalList(val []*Local) { + p.LocalList = val } - -func (p *MixedStruct) writeField7(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("enum_list", thrift.LIST, 7); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.I32, len(p.EnumList)); err != nil { - return err - } - for _, v := range p.EnumList { - if err := oprot.WriteI32(int32(v)); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 7 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) +func (p *MixedStruct) SetStrLocalMap(val map[string]*Local) { + p.StrLocalMap = val } - -func (p *MixedStruct) writeField9(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Int64", thrift.I64, 9); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI64(p.Int64); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 9 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 9 end error: ", p), err) +func (p *MixedStruct) SetNestList(val [][]int32) { + p.NestList = val } - -func (p *MixedStruct) writeField10(oprot thrift.TProtocol) (err error) { - if p.IsSetIntList() { - if err = oprot.WriteFieldBegin("IntList", thrift.LIST, 10); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.I32, len(p.IntList)); err != nil { - return err - } - for _, v := range p.IntList { - if err := oprot.WriteI32(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 10 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 10 end error: ", p), err) +func (p *MixedStruct) SetRequiredIns(val *Local) { + p.RequiredIns = val } - -func (p *MixedStruct) writeField11(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("localList", thrift.LIST, 11); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRUCT, len(p.LocalList)); err != nil { - return err - } - for _, v := range p.LocalList { - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 11 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 11 end error: ", p), err) +func (p *MixedStruct) SetOptionalIns(val *Local) { + p.OptionalIns = val } - -func (p *MixedStruct) writeField12(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("StrLocalMap", thrift.MAP, 12); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRUCT, len(p.StrLocalMap)); err != nil { - return err - } - for k, v := range p.StrLocalMap { - - if err := oprot.WriteString(k); err != nil { - return err - } - - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 12 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 12 end error: ", p), err) +func (p *MixedStruct) SetAnotherInner(val *Inner) { + p.AnotherInner = val } - -func (p *MixedStruct) writeField13(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("nestList", thrift.LIST, 13); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.LIST, len(p.NestList)); err != nil { - return err - } - for _, v := range p.NestList { - if err := oprot.WriteListBegin(thrift.I32, len(v)); err != nil { - return err - } - for _, v := range v { - if err := oprot.WriteI32(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 13 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err) +func (p *MixedStruct) SetEnumKeyMap(val map[HTTPStatus]*Local) { + p.EnumKeyMap = val } -func (p *MixedStruct) writeField14(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("required_ins", thrift.STRUCT, 14); err != nil { - goto WriteFieldBeginError - } - if err := p.RequiredIns.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err) +func (p *MixedStruct) CarryingUnknownFields() bool { + return len(p._unknownFields) > 0 } -func (p *MixedStruct) writeField20(oprot thrift.TProtocol) (err error) { - if p.IsSetOptionalIns() { - if err = oprot.WriteFieldBegin("optional_ins", thrift.STRUCT, 20); err != nil { - goto WriteFieldBeginError - } - if err := p.OptionalIns.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 20 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 20 end error: ", p), err) +func (p *MixedStruct) IsSetIntList() bool { + return p.IntList != nil } -func (p *MixedStruct) writeField21(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("AnotherInner", thrift.STRUCT, 21); err != nil { - goto WriteFieldBeginError - } - if err := p.AnotherInner.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 21 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 21 end error: ", p), err) +func (p *MixedStruct) IsSetRequiredIns() bool { + return p.RequiredIns != nil } -func (p *MixedStruct) writeField27(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("enum_key_map", thrift.MAP, 27); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.I32, thrift.STRUCT, len(p.EnumKeyMap)); err != nil { - return err - } - for k, v := range p.EnumKeyMap { - - if err := oprot.WriteI32(int32(k)); err != nil { - return err - } +func (p *MixedStruct) IsSetOptionalIns() bool { + return p.OptionalIns != nil +} - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 27 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 27 end error: ", p), err) +func (p *MixedStruct) IsSetAnotherInner() bool { + return p.AnotherInner != nil } func (p *MixedStruct) String() string { @@ -4770,6 +1569,22 @@ func (p *MixedStruct) Field27DeepEqual(src map[HTTPStatus]*Local) bool { return true } +var fieldIDToName_MixedStruct = map[int16]string{ + 1: "Left", + 3: "Dummy", + 6: "Str", + 7: "enum_list", + 9: "Int64", + 10: "IntList", + 11: "localList", + 12: "StrLocalMap", + 13: "nestList", + 14: "required_ins", + 20: "optional_ins", + 21: "AnotherInner", + 27: "enum_key_map", +} + type EmptyStruct struct { _unknownFields unknown.Fields } @@ -4779,89 +1594,12 @@ func NewEmptyStruct() *EmptyStruct { } func (p *EmptyStruct) InitDefault() { - *p = EmptyStruct{} } func (p *EmptyStruct) CarryingUnknownFields() bool { return len(p._unknownFields) > 0 } -var fieldIDToName_EmptyStruct = map[int16]string{} - -func (p *EmptyStruct) Read(iprot thrift.TProtocol) (err error) { - var name string - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - name, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - default: - if err = p._unknownFields.Append(iprot, name, fieldTypeId, fieldId); err != nil { - goto UnknownFieldsAppendError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -UnknownFieldsAppendError: - return thrift.PrependError(fmt.Sprintf("%T append unknown field(name:%s type:%d id:%d) error: ", p, name, fieldTypeId, fieldId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *EmptyStruct) Write(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteStructBegin("EmptyStruct"); err != nil { - goto WriteStructBeginError - } - if p != nil { - - if err = p._unknownFields.Write(oprot); err != nil { - goto UnknownFieldsWriteError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -UnknownFieldsWriteError: - return thrift.PrependError(fmt.Sprintf("%T write unknown fields error: ", p), err) -} - func (p *EmptyStruct) String() string { if p == nil { return "" @@ -4877,3 +1615,5 @@ func (p *EmptyStruct) DeepEqual(ano *EmptyStruct) bool { } return true } + +var fieldIDToName_EmptyStruct = map[int16]string{} diff --git a/pkg/protocol/bthrift/test/unknown_test.go b/pkg/protocol/bthrift/test/unknown_test.go index 1a24416924..beeae56d84 100644 --- a/pkg/protocol/bthrift/test/unknown_test.go +++ b/pkg/protocol/bthrift/test/unknown_test.go @@ -24,8 +24,6 @@ import ( tt "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/protocol/bthrift" "github.com/cloudwego/kitex/pkg/protocol/bthrift/test/kitex_gen/test" - "github.com/cloudwego/kitex/pkg/remote" - codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" ) var fullReq *test.FullStruct @@ -109,24 +107,6 @@ func TestOnlyUnknownField(t *testing.T) { tt.Assert(t, writeL == l) tt.Assert(t, bytes.Equal(buf, unknownBuf)) - // thrift read/write without fast api - trans := remote.NewReaderWriterBuffer(-1) - prot := codecThrift.NewBinaryProtocol(trans) - err = fullReq.Write(prot) - tt.Assert(t, err == nil) - unknown1 := &test.EmptyStruct{} - err = unknown1.Read(prot) - tt.Assert(t, err == nil) - tt.Assert(t, unknown.BLength() == unknown1.BLength()) - trans = remote.NewReaderWriterBuffer(-1) - prot = codecThrift.NewBinaryProtocol(trans) - err = unknown1.Write(prot) - tt.Assert(t, err == nil) - unknown1 = &test.EmptyStruct{} - err = unknown1.Read(prot) - tt.Assert(t, err == nil) - tt.Assert(t, unknown.BLength() == unknown1.BLength()) - // test get unknown fields fields, err := bthrift.GetUnknownFields(unknown) tt.Assert(t, err == nil) @@ -161,24 +141,6 @@ func TestPartialUnknownField(t *testing.T) { tt.Assert(t, err == nil) tt.Assert(t, ll == unknownL) tt.Assert(t, compare1.DeepEqual(compare)) - - // thrift read/write without fast api - trans := remote.NewReaderWriterBuffer(-1) - prot := codecThrift.NewBinaryProtocol(trans) - err = fullReq.Write(prot) - tt.Assert(t, err == nil) - unknown1 := &test.MixedStruct{} - err = unknown1.Read(prot) - tt.Assert(t, err == nil) - tt.Assert(t, unknown.BLength() == unknown1.BLength()) - trans = remote.NewReaderWriterBuffer(-1) - prot = codecThrift.NewBinaryProtocol(trans) - err = unknown1.Write(prot) - tt.Assert(t, err == nil) - unknown1 = &test.MixedStruct{} - err = unknown1.Read(prot) - tt.Assert(t, err == nil) - tt.Assert(t, unknown.BLength() == unknown1.BLength()) } func TestNoUnknownField(t *testing.T) { diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 537894b97f..290f0ef1ae 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -51,7 +51,7 @@ func init() { if err := versions.RegisterMinDepVersion( &versions.MinDepVersion{ RefPath: "github.com/cloudwego/kitex", - Version: "v0.9.0", + Version: "v0.11.0", }, ); err != nil { log.Warn(err) diff --git a/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go b/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go index 0552bd5cd5..f866c490c9 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go +++ b/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go @@ -29,7 +29,7 @@ var ( _ = (*bytes.Buffer)(nil) {{- UseStdLibrary "bytes"}} _ = (*strings.Builder)(nil) {{- UseStdLibrary "strings"}} _ = reflect.Type(nil) {{- UseStdLibrary "reflect"}} - _ = thrift.TProtocol(nil) {{- UseStdLibrary "thrift"}} + _ = thrift.TProtocol(nil) {{- UseLib (ImportPathTo "pkg/protocol/bthrift/apache") "thrift"}} {{- if GenerateFastAPIs}} {{- UseLib (ImportPathTo "pkg/protocol/bthrift") ""}} _ = bthrift.BinaryWriter(nil) diff --git a/tool/internal_pkg/pluginmode/thriftgo/patcher.go b/tool/internal_pkg/pluginmode/thriftgo/patcher.go index d86fe37f72..ce00548fb8 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/patcher.go +++ b/tool/internal_pkg/pluginmode/thriftgo/patcher.go @@ -386,8 +386,15 @@ func getBashPath() string { func (p *patcher) extractLocalLibs(imports []util.Import) []util.Import { ret := make([]util.Import, 0) prefix := p.module + "/" + kitexPkgPath := generator.ImportPathTo("pkg") // remove std libs and thrift to prevent duplicate import. for _, v := range imports { + if strings.HasPrefix(prefix, kitexPkgPath) { + // fix bad the case like: `undefined: bthrift.KitexUnusedProtection` + // when we generate code in kitex repo. + // we may never ref to other generate code in kitex repo, if do fix me. + continue + } // local packages if strings.HasPrefix(v.Path, prefix) { ret = append(ret, v) diff --git a/tool/internal_pkg/pluginmode/thriftgo/struct_tpl.go b/tool/internal_pkg/pluginmode/thriftgo/struct_tpl.go index 31b2ef32d9..7dd9fad917 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/struct_tpl.go +++ b/tool/internal_pkg/pluginmode/thriftgo/struct_tpl.go @@ -145,19 +145,19 @@ func (p *{{$TypeName}}) FastRead(buf []byte) (int, error) { {{- end}} return offset, nil ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) {{- if gt (len .Fields) 0}} ReadFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_{{$TypeName}}[fieldId]), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_{{$TypeName}}[fieldId]), err) {{- end}} SkipFieldError: - return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) {{- if $NeedRequiredFieldNotSetError }} RequiredFieldNotSetError: return offset, thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_{{$TypeName}}[fieldId])) diff --git a/version.go b/version.go index 93de9aca6d..96c9e79792 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package kitex // Name and Version info of this framework, used for statistics and debug const ( Name = "Kitex" - Version = "v0.10.1" + Version = "v0.11.0" ) From a99173f6c78c6883fc22998d2a6bbb15479d2c94 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Tue, 9 Jul 2024 14:19:48 +0800 Subject: [PATCH 09/70] chore: fixed undefined KitexUnusedProtection (#1428) --- tool/internal_pkg/pluginmode/thriftgo/patcher.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tool/internal_pkg/pluginmode/thriftgo/patcher.go b/tool/internal_pkg/pluginmode/thriftgo/patcher.go index ce00548fb8..ea9e263b9d 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/patcher.go +++ b/tool/internal_pkg/pluginmode/thriftgo/patcher.go @@ -389,7 +389,7 @@ func (p *patcher) extractLocalLibs(imports []util.Import) []util.Import { kitexPkgPath := generator.ImportPathTo("pkg") // remove std libs and thrift to prevent duplicate import. for _, v := range imports { - if strings.HasPrefix(prefix, kitexPkgPath) { + if strings.HasPrefix(v.Path, kitexPkgPath) { // fix bad the case like: `undefined: bthrift.KitexUnusedProtection` // when we generate code in kitex repo. // we may never ref to other generate code in kitex repo, if do fix me. From d163f5b8e3d05707f46a758a2a1f5537d8d10f72 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Thu, 11 Jul 2024 11:10:06 +0800 Subject: [PATCH 10/70] test: works without apache code (#1429) --- internal/mocks/thrift/gen.sh | 2 +- internal/mocks/thrift/test.go | 954 +------------------- internal/mocks/thrift/utils.go | 80 ++ pkg/generic/map_test/generic_init.go | 8 +- pkg/remote/codec/thrift/binary_protocol.go | 16 +- pkg/remote/codec/thrift/thrift_data_test.go | 12 +- pkg/remote/codec/thrift/thrift_test.go | 12 +- pkg/utils/thrift_test.go | 12 +- 8 files changed, 151 insertions(+), 945 deletions(-) create mode 100644 internal/mocks/thrift/utils.go diff --git a/internal/mocks/thrift/gen.sh b/internal/mocks/thrift/gen.sh index df178c759b..dd4a4915a0 100755 --- a/internal/mocks/thrift/gen.sh +++ b/internal/mocks/thrift/gen.sh @@ -1,5 +1,5 @@ #! /bin/bash -kitex -module github.com/cloudwego/kitex -gen-path .. ./test.thrift +kitex -thrift no_default_serdes -module github.com/cloudwego/kitex -gen-path .. ./test.thrift rm -rf ./mock # not in use, rm it diff --git a/internal/mocks/thrift/test.go b/internal/mocks/thrift/test.go index 6eba70afe6..ccc8a6f27b 100644 --- a/internal/mocks/thrift/test.go +++ b/internal/mocks/thrift/test.go @@ -14,14 +14,13 @@ * limitations under the License. */ -// Code generated by thriftgo (0.3.13). DO NOT EDIT. +// Code generated by thriftgo (0.3.14). DO NOT EDIT. package thrift import ( "context" "fmt" - "github.com/apache/thrift/lib/go/thrift" "strings" ) @@ -59,260 +58,11 @@ func (p *MockReq) SetStrList(val []string) { p.StrList = val } -var fieldIDToName_MockReq = map[int16]string{ - 1: "Msg", - 2: "strMap", - 3: "strList", -} - -func (p *MockReq) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.STRING { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - case 2: - if fieldTypeId == thrift.MAP { - if err = p.ReadField2(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - case 3: - if fieldTypeId == thrift.LIST { - if err = p.ReadField3(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockReq[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockReq) ReadField1(iprot thrift.TProtocol) error { - - var _field string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _field = v - } - p.Msg = _field - return nil -} -func (p *MockReq) ReadField2(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - _field := make(map[string]string, size) - for i := 0; i < size; i++ { - var _key string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } - - var _val string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _val = v - } - - _field[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - p.StrMap = _field - return nil -} -func (p *MockReq) ReadField3(iprot thrift.TProtocol) error { - _, size, err := iprot.ReadListBegin() - if err != nil { - return err - } - _field := make([]string, 0, size) - for i := 0; i < size; i++ { - - var _elem string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _elem = v - } - - _field = append(_field, _elem) - } - if err := iprot.ReadListEnd(); err != nil { - return err - } - p.StrList = _field - return nil -} - -func (p *MockReq) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("MockReq"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField2(oprot); err != nil { - fieldId = 2 - goto WriteFieldError - } - if err = p.writeField3(oprot); err != nil { - fieldId = 3 - goto WriteFieldError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockReq) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Msg", thrift.STRING, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Msg); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *MockReq) writeField2(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("strMap", thrift.MAP, 2); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.StrMap)); err != nil { - return err - } - for k, v := range p.StrMap { - if err := oprot.WriteString(k); err != nil { - return err - } - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) -} - -func (p *MockReq) writeField3(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("strList", thrift.LIST, 3); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteListBegin(thrift.STRING, len(p.StrList)); err != nil { - return err - } - for _, v := range p.StrList { - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) -} - func (p *MockReq) String() string { if p == nil { return "" } return fmt.Sprintf("MockReq(%+v)", *p) - } func (p *MockReq) DeepEqual(ano *MockReq) bool { @@ -367,6 +117,12 @@ func (p *MockReq) Field3DeepEqual(src []string) bool { return true } +var fieldIDToName_MockReq = map[int16]string{ + 1: "Msg", + 2: "strMap", + 3: "strList", +} + type Exception struct { Code int32 `thrift:"code,1" frugal:"1,default,i32" json:"code"` Msg string `thrift:"msg,255" frugal:"255,default,string" json:"msg"` @@ -393,170 +149,11 @@ func (p *Exception) SetMsg(val string) { p.Msg = val } -var fieldIDToName_Exception = map[int16]string{ - 1: "code", - 255: "msg", -} - -func (p *Exception) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - case 255: - if fieldTypeId == thrift.STRING { - if err = p.ReadField255(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Exception[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *Exception) ReadField1(iprot thrift.TProtocol) error { - - var _field int32 - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - _field = v - } - p.Code = _field - return nil -} -func (p *Exception) ReadField255(iprot thrift.TProtocol) error { - - var _field string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _field = v - } - p.Msg = _field - return nil -} - -func (p *Exception) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Exception"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField255(oprot); err != nil { - fieldId = 255 - goto WriteFieldError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *Exception) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("code", thrift.I32, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(p.Code); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *Exception) writeField255(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("msg", thrift.STRING, 255); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Msg); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 255 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 255 end error: ", p), err) -} - func (p *Exception) String() string { if p == nil { return "" } return fmt.Sprintf("Exception(%+v)", *p) - } func (p *Exception) Error() string { return p.String() @@ -592,6 +189,11 @@ func (p *Exception) Field255DeepEqual(src string) bool { return true } +var fieldIDToName_Exception = map[int16]string{ + 1: "code", + 255: "msg", +} + type Mock interface { Test(ctx context.Context, req *MockReq) (r string, err error) @@ -621,130 +223,15 @@ func (p *MockTestArgs) SetReq(val *MockReq) { p.Req = val } -var fieldIDToName_MockTestArgs = map[int16]string{ - 1: "req", -} - func (p *MockTestArgs) IsSetReq() bool { return p.Req != nil } -func (p *MockTestArgs) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestArgs[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockTestArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewMockReq() - if err := _field.Read(iprot); err != nil { - return err - } - p.Req = _field - return nil -} - -func (p *MockTestArgs) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Test_args"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockTestArgs) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { - goto WriteFieldBeginError - } - if err := p.Req.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - func (p *MockTestArgs) String() string { if p == nil { return "" } return fmt.Sprintf("MockTestArgs(%+v)", *p) - } func (p *MockTestArgs) DeepEqual(ano *MockTestArgs) bool { @@ -767,6 +254,10 @@ func (p *MockTestArgs) Field1DeepEqual(src *MockReq) bool { return true } +var fieldIDToName_MockTestArgs = map[int16]string{ + 1: "req", +} + type MockTestResult struct { Success *string `thrift:"success,0,optional" frugal:"0,optional,string" json:"success,omitempty"` } @@ -790,135 +281,15 @@ func (p *MockTestResult) SetSuccess(x interface{}) { p.Success = x.(*string) } -var fieldIDToName_MockTestResult = map[int16]string{ - 0: "success", -} - func (p *MockTestResult) IsSetSuccess() bool { return p.Success != nil } -func (p *MockTestResult) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 0: - if fieldTypeId == thrift.STRING { - if err = p.ReadField0(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestResult[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockTestResult) ReadField0(iprot thrift.TProtocol) error { - - var _field *string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _field = &v - } - p.Success = _field - return nil -} - -func (p *MockTestResult) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Test_result"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField0(oprot); err != nil { - fieldId = 0 - goto WriteFieldError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockTestResult) writeField0(oprot thrift.TProtocol) (err error) { - if p.IsSetSuccess() { - if err = oprot.WriteFieldBegin("success", thrift.STRING, 0); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(*p.Success); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 0 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) -} - func (p *MockTestResult) String() string { if p == nil { return "" } return fmt.Sprintf("MockTestResult(%+v)", *p) - } func (p *MockTestResult) DeepEqual(ano *MockTestResult) bool { @@ -946,6 +317,10 @@ func (p *MockTestResult) Field0DeepEqual(src *string) bool { return true } +var fieldIDToName_MockTestResult = map[int16]string{ + 0: "success", +} + type MockExceptionTestArgs struct { Req *MockReq `thrift:"req,1" frugal:"1,default,MockReq" json:"req"` } @@ -969,130 +344,15 @@ func (p *MockExceptionTestArgs) SetReq(val *MockReq) { p.Req = val } -var fieldIDToName_MockExceptionTestArgs = map[int16]string{ - 1: "req", -} - func (p *MockExceptionTestArgs) IsSetReq() bool { return p.Req != nil } -func (p *MockExceptionTestArgs) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestArgs[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockExceptionTestArgs) ReadField1(iprot thrift.TProtocol) error { - _field := NewMockReq() - if err := _field.Read(iprot); err != nil { - return err - } - p.Req = _field - return nil -} - -func (p *MockExceptionTestArgs) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("ExceptionTest_args"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockExceptionTestArgs) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("req", thrift.STRUCT, 1); err != nil { - goto WriteFieldBeginError - } - if err := p.Req.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - func (p *MockExceptionTestArgs) String() string { if p == nil { return "" } return fmt.Sprintf("MockExceptionTestArgs(%+v)", *p) - } func (p *MockExceptionTestArgs) DeepEqual(ano *MockExceptionTestArgs) bool { @@ -1115,6 +375,10 @@ func (p *MockExceptionTestArgs) Field1DeepEqual(src *MockReq) bool { return true } +var fieldIDToName_MockExceptionTestArgs = map[int16]string{ + 1: "req", +} + type MockExceptionTestResult struct { Success *string `thrift:"success,0,optional" frugal:"0,optional,string" json:"success,omitempty"` Err *Exception `thrift:"err,1,optional" frugal:"1,optional,Exception" json:"err,omitempty"` @@ -1151,11 +415,6 @@ func (p *MockExceptionTestResult) SetErr(val *Exception) { p.Err = val } -var fieldIDToName_MockExceptionTestResult = map[int16]string{ - 0: "success", - 1: "err", -} - func (p *MockExceptionTestResult) IsSetSuccess() bool { return p.Success != nil } @@ -1164,166 +423,11 @@ func (p *MockExceptionTestResult) IsSetErr() bool { return p.Err != nil } -func (p *MockExceptionTestResult) Read(iprot thrift.TProtocol) (err error) { - - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 0: - if fieldTypeId == thrift.STRING { - if err = p.ReadField0(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - case 1: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestResult[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *MockExceptionTestResult) ReadField0(iprot thrift.TProtocol) error { - - var _field *string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _field = &v - } - p.Success = _field - return nil -} -func (p *MockExceptionTestResult) ReadField1(iprot thrift.TProtocol) error { - _field := NewException() - if err := _field.Read(iprot); err != nil { - return err - } - p.Err = _field - return nil -} - -func (p *MockExceptionTestResult) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("ExceptionTest_result"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField0(oprot); err != nil { - fieldId = 0 - goto WriteFieldError - } - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *MockExceptionTestResult) writeField0(oprot thrift.TProtocol) (err error) { - if p.IsSetSuccess() { - if err = oprot.WriteFieldBegin("success", thrift.STRING, 0); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(*p.Success); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 0 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 0 end error: ", p), err) -} - -func (p *MockExceptionTestResult) writeField1(oprot thrift.TProtocol) (err error) { - if p.IsSetErr() { - if err = oprot.WriteFieldBegin("err", thrift.STRUCT, 1); err != nil { - goto WriteFieldBeginError - } - if err := p.Err.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - func (p *MockExceptionTestResult) String() string { if p == nil { return "" } return fmt.Sprintf("MockExceptionTestResult(%+v)", *p) - } func (p *MockExceptionTestResult) DeepEqual(ano *MockExceptionTestResult) bool { @@ -1360,3 +464,13 @@ func (p *MockExceptionTestResult) Field1DeepEqual(src *Exception) bool { } return true } + +var fieldIDToName_MockExceptionTestResult = map[int16]string{ + 0: "success", + 1: "err", +} + +// exceptions of methods in Mock. +var ( + _ error = (*Exception)(nil) +) diff --git a/internal/mocks/thrift/utils.go b/internal/mocks/thrift/utils.go new file mode 100644 index 0000000000..6c4cabe874 --- /dev/null +++ b/internal/mocks/thrift/utils.go @@ -0,0 +1,80 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "io" + + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" +) + +// ApacheCodecAdapter converts a fastcodec struct to apache codec +type ApacheCodecAdapter struct { + p bthrift.ThriftFastCodec +} + +// Write implements thrift.TStruct +func (p ApacheCodecAdapter) Write(tp thrift.TProtocol) error { + b := make([]byte, p.p.BLength()) + b = b[:p.p.FastWriteNocopy(b, nil)] + trans := tp.Transport() + if t, ok := trans.(remoteByteBuffer); ok { + // remote.ByteBuffer not always implement io.Writer ... + // can only use WriteBinary + _, err := t.WriteBinary(b) + return err + } + _, err := tp.Transport().Write(b) + return err +} + +type remoteByteBuffer interface { + ReadableLen() (n int) + Next(n int) (p []byte, err error) + WriteBinary(b []byte) (n int, err error) +} + +// Read implements thrift.TStruct +func (p ApacheCodecAdapter) Read(tp thrift.TProtocol) error { + var err error + var b []byte + trans := tp.Transport() + if t, ok := trans.(remoteByteBuffer); ok { + // remote.ByteBuffer not always implement io.Reader ... + // can only use Next() + b, err = t.Next(t.ReadableLen()) + } else { + n := trans.RemainingBytes() + b = make([]byte, n) + _, err = io.ReadFull(trans, b) + } + if err == nil { + _, err = p.p.FastRead(b) + } + return err +} + +// ToApacheCodec converts a bthrift.ThriftFastCodec to thrift.TStruct +func ToApacheCodec(p bthrift.ThriftFastCodec) thrift.TStruct { + return ApacheCodecAdapter{p: p} +} + +// UnpackApacheCodec unpacks ToApacheCodec +func UnpackApacheCodec(v interface{}) interface{} { + return v.(ApacheCodecAdapter).p +} diff --git a/pkg/generic/map_test/generic_init.go b/pkg/generic/map_test/generic_init.go index dd5d466025..aae9490f27 100644 --- a/pkg/generic/map_test/generic_init.go +++ b/pkg/generic/map_test/generic_init.go @@ -204,16 +204,16 @@ func serviceInfo() *serviceinfo.ServiceInfo { } func newMockTestArgs() interface{} { - return kt.NewMockTestArgs() + return kt.ToApacheCodec(kt.NewMockTestArgs()) } func newMockTestResult() interface{} { - return kt.NewMockTestResult() + return kt.ToApacheCodec(kt.NewMockTestResult()) } func testHandler(ctx context.Context, handler, arg, result interface{}) error { - realArg := arg.(*kt.MockTestArgs) - realResult := result.(*kt.MockTestResult) + realArg := kt.UnpackApacheCodec(arg).(*kt.MockTestArgs) + realResult := kt.UnpackApacheCodec(result).(*kt.MockTestResult) success, err := handler.(kt.Mock).Test(ctx, realArg.Req) if err != nil { return err diff --git a/pkg/remote/codec/thrift/binary_protocol.go b/pkg/remote/codec/thrift/binary_protocol.go index eae7608faa..bc25cf1929 100644 --- a/pkg/remote/codec/thrift/binary_protocol.go +++ b/pkg/remote/codec/thrift/binary_protocol.go @@ -495,10 +495,22 @@ func (p *BinaryProtocol) Skip(fieldType thrift.TType) (err error) { return thrift.SkipDefaultDepth(p, fieldType) } +// ttransportByteBuffer ... +// for exposing remote.ByteBuffer via p.Transport(), +// mainly for testing purpose, see internal/mocks/thrift/utils.go +type ttransportByteBuffer struct { + remote.ByteBuffer +} + +func (ttransportByteBuffer) Close() error { panic("not implemented") } +func (ttransportByteBuffer) Flush(ctx context.Context) (err error) { panic("not implemented") } +func (ttransportByteBuffer) IsOpen() bool { panic("not implemented") } +func (ttransportByteBuffer) Open() error { panic("not implemented") } +func (ttransportByteBuffer) RemainingBytes() uint64 { panic("not implemented") } + // Transport ... func (p *BinaryProtocol) Transport() thrift.TTransport { - // not support - return nil + return ttransportByteBuffer{p.trans} } // ByteBuffer ... diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index 7ba4a5ae57..ec6a14c1cb 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -49,7 +49,7 @@ func TestMarshalBasicThriftData(t *testing.T) { t.Run("valid-data", func(t *testing.T) { transport := thrift.NewTMemoryBufferLen(1024) tProt := thrift.NewTBinaryProtocol(transport, true, true) - err := marshalBasicThriftData(context.Background(), tProt, mockReq, "", -1) + err := marshalBasicThriftData(context.Background(), tProt, mocks.ToApacheCodec(mockReq), "", -1) test.Assert(t, err == nil, err) result := transport.Bytes() test.Assert(t, reflect.DeepEqual(result, mockReqThrift), result) @@ -68,7 +68,7 @@ func TestMarshalThriftData(t *testing.T) { test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf) }) t.Run("BasicCodec", func(t *testing.T) { - buf, err := MarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), mockReq) + buf, err := MarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), mocks.ToApacheCodec(mockReq)) test.Assert(t, err == nil, err) test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf) }) @@ -79,19 +79,19 @@ func Test_decodeBasicThriftData(t *testing.T) { t.Run("empty-input", func(t *testing.T) { req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{})) - err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) + err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, mocks.ToApacheCodec(req)) test.Assert(t, err != nil, err) }) t.Run("invalid-input", func(t *testing.T) { req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{0xff})) - err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) + err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, mocks.ToApacheCodec(req)) test.Assert(t, err != nil, err) }) t.Run("normal-input", func(t *testing.T) { req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) - err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, req) + err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, mocks.ToApacheCodec(req)) checkDecodeResult(t, err, req) }) } @@ -116,7 +116,7 @@ func TestUnmarshalThriftData(t *testing.T) { }) t.Run("BasicCodec", func(t *testing.T) { req := &mocks.MockReq{} - err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), "mock", mockReqThrift, req) + err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), "mock", mockReqThrift, mocks.ToApacheCodec(req)) checkDecodeResult(t, err, req) }) // FrugalCodec: in thrift_frugal_amd64_test.go: TestUnmarshalThriftDataFrugal diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index 408bdfddda..d4c9f0eb9b 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -162,8 +162,8 @@ func BenchmarkNormalParallel(b *testing.B) { test.Assert(b, err == nil, err) // compare Req Arg - sendReq := (sendMsg.Data()).(*mt.MockTestArgs).Req - recvReq := (recvMsg.Data()).(*mt.MockTestArgs).Req + sendReq := mt.UnpackApacheCodec(sendMsg.Data()).(*mt.MockTestArgs).Req + recvReq := mt.UnpackApacheCodec(recvMsg.Data()).(*mt.MockTestArgs).Req test.Assert(b, sendReq.Msg == recvReq.Msg) test.Assert(b, len(sendReq.StrList) == len(recvReq.StrList)) test.Assert(b, len(sendReq.StrMap) == len(recvReq.StrMap)) @@ -280,7 +280,7 @@ func initSendMsg(tp transport.Protocol) remote.Message { _args.Req = prepareReq() ink := rpcinfo.NewInvocation("", "mock") ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) - msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Client) + msg := remote.NewMessage(mt.ToApacheCodec(&_args), svcInfo, ri, remote.Call, remote.Client) msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) return msg } @@ -289,13 +289,13 @@ func initRecvMsg() remote.Message { var _args mt.MockTestArgs ink := rpcinfo.NewInvocation("", "mock") ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) - msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Server) + msg := remote.NewMessage(mt.ToApacheCodec(&_args), svcInfo, ri, remote.Call, remote.Server) return msg } func compare(t *testing.T, sendMsg, recvMsg remote.Message) { - sendReq := (sendMsg.Data()).(*mt.MockTestArgs).Req - recvReq := (recvMsg.Data()).(*mt.MockTestArgs).Req + sendReq := mt.UnpackApacheCodec(sendMsg.Data()).(*mt.MockTestArgs).Req + recvReq := mt.UnpackApacheCodec(recvMsg.Data()).(*mt.MockTestArgs).Req test.Assert(t, sendReq.Msg == recvReq.Msg) test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList)) test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap)) diff --git a/pkg/utils/thrift_test.go b/pkg/utils/thrift_test.go index d0139f75b5..18d7632b9b 100644 --- a/pkg/utils/thrift_test.go +++ b/pkg/utils/thrift_test.go @@ -39,12 +39,12 @@ func TestRPCCodec(t *testing.T) { args1.Req = req1 // encode - buf, err := rc.Encode("mockMethod", thrift.CALL, 100, args1) + buf, err := rc.Encode("mockMethod", thrift.CALL, 100, mt.ToApacheCodec(args1)) test.Assert(t, err == nil, err) var argsDecode1 mt.MockTestArgs // decode - method, seqID, err := rc.Decode(buf, &argsDecode1) + method, seqID, err := rc.Decode(buf, mt.ToApacheCodec(&argsDecode1)) test.Assert(t, err == nil) test.Assert(t, method == "mockMethod") @@ -65,12 +65,12 @@ func TestRPCCodec(t *testing.T) { args2 := mt.NewMockTestArgs() args2.Req = req2 // encode - buf, err = rc.Encode("mockMethod1", thrift.CALL, 101, args2) + buf, err = rc.Encode("mockMethod1", thrift.CALL, 101, mt.ToApacheCodec(args2)) test.Assert(t, err == nil, err) // decode var argsDecode2 mt.MockTestArgs - method, seqID, err = rc.Decode(buf, &argsDecode2) + method, seqID, err = rc.Decode(buf, mt.ToApacheCodec(&argsDecode2)) test.Assert(t, err == nil, err) test.Assert(t, method == "mockMethod1") @@ -95,11 +95,11 @@ func TestSerializer(t *testing.T) { args := mt.NewMockTestArgs() args.Req = req - b, err := rc.Serialize(args) + b, err := rc.Serialize(mt.ToApacheCodec(args)) test.Assert(t, err == nil, err) var args2 mt.MockTestArgs - err = rc.Deserialize(&args2, b) + err = rc.Deserialize(mt.ToApacheCodec(&args2), b) test.Assert(t, err == nil, err) test.Assert(t, args2.Req.Msg == req.Msg) From fdfdfe969510ba545956411af22c90f6c0ee4c3c Mon Sep 17 00:00:00 2001 From: Guangming Luo Date: Thu, 11 Jul 2024 15:48:17 +0800 Subject: [PATCH 11/70] chore: update CI version and readme community (#1431) --- .github/workflows/pr-check.yml | 6 +++--- .github/workflows/tests.yml | 10 +++++----- README.md | 3 +-- README_cn.md | 3 +-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index c156a606d0..6258952ab4 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -23,7 +23,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: "1.22" - uses: actions/cache@v3 with: @@ -39,7 +39,7 @@ jobs: reporter: github-pr-review # Report all results. filter_mode: nofilter - # Exit with 1 when it find at least one finding. + # Exit with 1 when it finds at least one finding. fail_on_error: true # Set staticcheck flags staticcheck_flags: -checks=inherit,-SA1029 @@ -51,7 +51,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: "1.22" - name: Golangci Lint # https://golangci-lint.run/ diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 16f4b6c180..3521ae8fe8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: '1.20' + go-version: "1.20" - name: Unit Test run: go test -gcflags=-l -race -covermode=atomic -coverprofile=coverage.txt ./... - name: Scenario Tests @@ -31,14 +31,14 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: '1.18' + go-version: "1.22" - name: Benchmark run: go test -gcflags='all=-N -l' -bench=. -benchmem -run=none ./... compatibility-test: strategy: matrix: - go: [ 1.17, 1.18, 1.19.12, 1.20.7, 1.21, 1.22 ] + go: [ "1.17", "1.18", "1.19", "1.20", "1.21", "1.22" ] os: [ X64, ARM64 ] runs-on: ${{ matrix.os }} steps: @@ -57,7 +57,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: '1.19' + go-version: "1.22" - name: Prepare run: | go install github.com/cloudwego/thriftgo@main @@ -85,6 +85,6 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: "1.20" + go-version: "1.22" - name: Windows compatibility test run: go test -run=^$ ./... diff --git a/README.md b/README.md index 5921da32c5..c0b2b63876 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,6 @@ English | [中文](README_cn.md) [![ClosedIssue](https://img.shields.io/github/issues-closed/cloudwego/kitex)](https://github.com/cloudwego/kitex/issues?q=is%3Aissue+is%3Aclosed) ![Stars](https://img.shields.io/github/stars/cloudwego/kitex) ![Forks](https://img.shields.io/github/forks/cloudwego/kitex) -[![Slack](https://img.shields.io/badge/slack-join_chat-success.svg?logo=slack)](https://cloudwego.slack.com/join/shared_invite/zt-tmcbzewn-UjXMF3ZQsPhl7W3tEDZboA) Kitex [kaɪt'eks] is a **high-performance** and **strong-extensibility** Golang RPC framework that helps developers build microservices. If the performance and extensibility are the main concerns when you develop microservices, Kitex can be a good choice. @@ -114,7 +113,7 @@ Kitex is distributed under the [Apache License, version 2.0](https://github.com/ - Email: [conduct@cloudwego.io](conduct@cloudwego.io) - How to become a member: [COMMUNITY MEMBERSHIP](https://github.com/cloudwego/community/blob/main/COMMUNITY_MEMBERSHIP.md) - Issues: [Issues](https://github.com/cloudwego/kitex/issues) -- Slack: Join our CloudWeGo community [Slack Channel](https://join.slack.com/t/cloudwego/shared_invite/zt-tmcbzewn-UjXMF3ZQsPhl7W3tEDZboA). +- Discord: Join community with [Discord Channel](https://discord.gg/jceZSE7DsW). - Lark: Scan the QR code below with [Lark](https://www.larksuite.com/zh_cn/download) to join our CloudWeGo/kitex user group. ![LarkGroup](images/lark_group.png) diff --git a/README_cn.md b/README_cn.md index 02271e5f5b..055681e386 100644 --- a/README_cn.md +++ b/README_cn.md @@ -10,7 +10,6 @@ [![ClosedIssue](https://img.shields.io/github/issues-closed/cloudwego/kitex)](https://github.com/cloudwego/kitex/issues?q=is%3Aissue+is%3Aclosed) ![Stars](https://img.shields.io/github/stars/cloudwego/kitex) ![Forks](https://img.shields.io/github/forks/cloudwego/kitex) -[![Slack](https://img.shields.io/badge/slack-join_chat-success.svg?logo=slack)](https://cloudwego.slack.com/join/shared_invite/zt-tmcbzewn-UjXMF3ZQsPhl7W3tEDZboA) Kitex[kaɪt'eks] 字节跳动内部的 Golang 微服务 RPC 框架,具有**高性能**、**强可扩展**的特点,在字节内部已广泛使用。如今越来越多的微服务选择使用 Golang,如果对微服务性能有要求,又希望定制扩展融入自己的治理体系,Kitex 会是一个不错的选择。 @@ -112,7 +111,7 @@ Kitex 基于[Apache License 2.0](LICENSE) 许可证,其依赖的三方组件 - Email: conduct@cloudwego.io - 如何成为 member: [COMMUNITY MEMBERSHIP](https://github.com/cloudwego/community/blob/main/COMMUNITY_MEMBERSHIP.md) - Issues: [Issues](https://github.com/cloudwego/kitex/issues) -- Slack: 加入我们的 [Slack 频道](https://join.slack.com/t/cloudwego/shared_invite/zt-tmcbzewn-UjXMF3ZQsPhl7W3tEDZboA) +- Discord: 加入我们的 [Discord 频道](https://discord.gg/jceZSE7DsW) - 飞书用户群([注册飞书](https://www.feishu.cn/)后扫码进群) ![LarkGroup](images/lark_group_cn.png) From ad6559a298f53e70120921b24dc9c02a5d1f8dca Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Sat, 13 Jul 2024 02:49:53 +0800 Subject: [PATCH 12/70] refactor: new generic interface without thrift apache (#1434) --- internal/mocks/thrift/utils.go | 27 ++++----------- pkg/remote/codec/thrift/binary_protocol.go | 2 +- pkg/remote/codec/thrift/thrift.go | 23 +++++++++---- pkg/remote/codec/thrift/thrift_data.go | 30 +++++++++------- pkg/remote/trans/netpoll/bytebuf.go | 34 +++---------------- .../trans/netpoll/http_client_handler_test.go | 9 +++-- 6 files changed, 53 insertions(+), 72 deletions(-) diff --git a/internal/mocks/thrift/utils.go b/internal/mocks/thrift/utils.go index 6c4cabe874..13a79862d3 100644 --- a/internal/mocks/thrift/utils.go +++ b/internal/mocks/thrift/utils.go @@ -17,6 +17,7 @@ package thrift import ( + "errors" "io" "github.com/cloudwego/kitex/pkg/protocol/bthrift" @@ -32,37 +33,21 @@ type ApacheCodecAdapter struct { func (p ApacheCodecAdapter) Write(tp thrift.TProtocol) error { b := make([]byte, p.p.BLength()) b = b[:p.p.FastWriteNocopy(b, nil)] - trans := tp.Transport() - if t, ok := trans.(remoteByteBuffer); ok { - // remote.ByteBuffer not always implement io.Writer ... - // can only use WriteBinary - _, err := t.WriteBinary(b) - return err - } _, err := tp.Transport().Write(b) return err } -type remoteByteBuffer interface { - ReadableLen() (n int) - Next(n int) (p []byte, err error) - WriteBinary(b []byte) (n int, err error) -} - // Read implements thrift.TStruct func (p ApacheCodecAdapter) Read(tp thrift.TProtocol) error { var err error var b []byte trans := tp.Transport() - if t, ok := trans.(remoteByteBuffer); ok { - // remote.ByteBuffer not always implement io.Reader ... - // can only use Next() - b, err = t.Next(t.ReadableLen()) - } else { - n := trans.RemainingBytes() - b = make([]byte, n) - _, err = io.ReadFull(trans, b) + n := trans.RemainingBytes() + if int64(n) < 0 { + return errors.New("unknown buffer len") } + b = make([]byte, n) + _, err = io.ReadFull(trans, b) if err == nil { _, err = p.p.FastRead(b) } diff --git a/pkg/remote/codec/thrift/binary_protocol.go b/pkg/remote/codec/thrift/binary_protocol.go index bc25cf1929..b2775b1c2f 100644 --- a/pkg/remote/codec/thrift/binary_protocol.go +++ b/pkg/remote/codec/thrift/binary_protocol.go @@ -506,7 +506,7 @@ func (ttransportByteBuffer) Close() error { panic("not func (ttransportByteBuffer) Flush(ctx context.Context) (err error) { panic("not implemented") } func (ttransportByteBuffer) IsOpen() bool { panic("not implemented") } func (ttransportByteBuffer) Open() error { panic("not implemented") } -func (ttransportByteBuffer) RemainingBytes() uint64 { panic("not implemented") } +func (p ttransportByteBuffer) RemainingBytes() uint64 { return uint64(p.ReadableLen()) } // Transport ... func (p *BinaryProtocol) Transport() thrift.TTransport { diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index 9a9bab0b02..47fb17efef 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "io" "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" @@ -239,11 +240,6 @@ func (c thriftCodec) Name() string { return serviceinfo.Thrift.String() } -// MessageWriterWithMethodWithContext write to thrift.TProtocol -type MessageWriterWithMethodWithContext interface { - Write(ctx context.Context, method string, oprot thrift.TProtocol) error -} - // MessageWriter write to thrift.TProtocol type MessageWriter interface { Write(oprot thrift.TProtocol) error @@ -254,9 +250,24 @@ type MessageReader interface { Read(oprot thrift.TProtocol) error } +type genericWriter interface { // used by pkg/generic + Write(ctx context.Context, method string, w io.Writer) error +} + +type genericReader interface { // used by pkg/generic + Read(ctx context.Context, method string, dataLen int, r io.Reader) error +} + +// MessageWriterWithMethodWithContext write to thrift.TProtocol +// TODO(marina.sakai): remove it after we use the new genericWriter interface +type MessageWriterWithMethodWithContext interface { + Write(ctx context.Context, method string, oprot thrift.TProtocol) error +} + // MessageReaderWithMethodWithContext read from thrift.TProtocol with method +// TODO(marina.sakai): remove it after we use the new genericReader interface type MessageReaderWithMethodWithContext interface { - Read(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error + Read(ctx context.Context, method string, dataLen int, iprot thrift.TProtocol) error } // ThriftMsgFastCodec ... diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index 7d42784d7a..b15028b3dc 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -83,6 +83,7 @@ func verifyMarshalBasicThriftDataType(data interface{}) error { switch data.(type) { case MessageWriter: case MessageWriterWithMethodWithContext: + case genericWriter: default: return errEncodeMismatchMsgType } @@ -92,18 +93,20 @@ func verifyMarshalBasicThriftDataType(data interface{}) error { // marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId) // It uses the old thrift way which is much slower than FastCodec and Frugal func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data interface{}, method string, rpcRole remote.RPCRole) error { + var err error switch msg := data.(type) { case MessageWriter: - if err := msg.Write(tProt); err != nil { - return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) - } + err = msg.Write(tProt) case MessageWriterWithMethodWithContext: - if err := msg.Write(ctx, method, tProt); err != nil { - return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) - } + err = msg.Write(ctx, method, tProt) + case genericWriter: + err = msg.Write(ctx, method, tProt.Transport()) default: return errEncodeMismatchMsgType } + if err != nil { + return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) + } return nil } @@ -227,6 +230,7 @@ func verifyUnmarshalBasicThriftDataType(data interface{}) error { switch data.(type) { case MessageReader: case MessageReaderWithMethodWithContext: + case genericReader: default: return errDecodeMismatchMsgType } @@ -238,17 +242,17 @@ func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method s var err error switch t := data.(type) { case MessageReader: - if err = t.Read(tProt); err != nil { - return remote.NewTransError(remote.ProtocolError, err) - } + err = t.Read(tProt) case MessageReaderWithMethodWithContext: - // methodName is necessary for generic calls to methodInfo from serviceInfo - if err = t.Read(ctx, method, dataLen, tProt); err != nil { - return remote.NewTransError(remote.ProtocolError, err) - } + err = t.Read(ctx, method, dataLen, tProt) + case genericReader: + err = t.Read(ctx, method, dataLen, tProt.Transport()) default: return errDecodeMismatchMsgType } + if err != nil { + return remote.NewTransError(remote.ProtocolError, err) + } return nil } diff --git a/pkg/remote/trans/netpoll/bytebuf.go b/pkg/remote/trans/netpoll/bytebuf.go index 736060da97..80612c7e97 100644 --- a/pkg/remote/trans/netpoll/bytebuf.go +++ b/pkg/remote/trans/netpoll/bytebuf.go @@ -18,7 +18,6 @@ package netpoll import ( "errors" - "io" "sync" "github.com/cloudwego/netpoll" @@ -36,11 +35,6 @@ func init() { func NewReaderByteBuffer(r netpoll.Reader) remote.ByteBuffer { bytebuf := bytebufPool.Get().(*netpollByteBuffer) bytebuf.reader = r - // TODO(wangtieju): fix me when netpoll support netpoll.Reader - // and LinkBuffer not support io.Reader, type assertion would fail when r is from NewBuffer - if ir, ok := r.(io.Reader); ok { - bytebuf.ioReader = ir - } bytebuf.status = remote.BitReadable bytebuf.readSize = 0 return bytebuf @@ -50,11 +44,6 @@ func NewReaderByteBuffer(r netpoll.Reader) remote.ByteBuffer { func NewWriterByteBuffer(w netpoll.Writer) remote.ByteBuffer { bytebuf := bytebufPool.Get().(*netpollByteBuffer) bytebuf.writer = w - // TODO(wangtieju): fix me when netpoll support netpoll.Writer - // and LinkBuffer not support io.Reader, type assertion would fail when w is from NewBuffer - if iw, ok := w.(io.Writer); ok { - bytebuf.ioWriter = iw - } bytebuf.status = remote.BitWritable return bytebuf } @@ -64,12 +53,6 @@ func NewReaderWriterByteBuffer(rw netpoll.ReadWriter) remote.ByteBuffer { bytebuf := bytebufPool.Get().(*netpollByteBuffer) bytebuf.writer = rw bytebuf.reader = rw - // TODO(wangtieju): fix me when netpoll support netpoll.ReadWriter - // and LinkBuffer not support io.ReadWriter, type assertion would fail when rw is from NewBuffer - if irw, ok := rw.(io.ReadWriter); ok { - bytebuf.ioReader = irw - bytebuf.ioWriter = irw - } bytebuf.status = remote.BitWritable | remote.BitReadable return bytebuf } @@ -81,8 +64,6 @@ func newNetpollByteBuffer() interface{} { type netpollByteBuffer struct { writer netpoll.Writer reader netpoll.Reader - ioReader io.Reader - ioWriter io.Writer status int readSize int } @@ -130,10 +111,9 @@ func (b *netpollByteBuffer) Read(p []byte) (n int, err error) { if b.status&remote.BitReadable == 0 { return -1, errors.New("unreadable buffer, cannot support Read") } - if b.ioReader != nil { - return b.ioReader.Read(p) - } - return -1, errors.New("ioReader is nil") + rb, err := b.reader.Next(len(p)) + b.readSize += len(rb) + return copy(p, rb), err } // ReadString is a more efficient way to read string than Next. @@ -188,10 +168,8 @@ func (b *netpollByteBuffer) Write(p []byte) (n int, err error) { if b.status&remote.BitWritable == 0 { return -1, errors.New("unwritable buffer, cannot support Write") } - if b.ioWriter != nil { - return b.ioWriter.Write(p) - } - return -1, errors.New("ioWriter is nil") + wb, err := b.writer.Malloc(len(p)) + return copy(wb, p), err } // WriteString is a more efficient way to write string, using the unsafe method to convert the string to []byte. @@ -268,8 +246,6 @@ func (b *netpollByteBuffer) Release(e error) (err error) { func (b *netpollByteBuffer) zero() { b.writer = nil b.reader = nil - b.ioReader = nil - b.ioWriter = nil b.status = 0 b.readSize = 0 } diff --git a/pkg/remote/trans/netpoll/http_client_handler_test.go b/pkg/remote/trans/netpoll/http_client_handler_test.go index 403df6a17e..df4f90a832 100644 --- a/pkg/remote/trans/netpoll/http_client_handler_test.go +++ b/pkg/remote/trans/netpoll/http_client_handler_test.go @@ -17,6 +17,7 @@ package netpoll import ( + "bytes" "context" "net/http" "strings" @@ -56,7 +57,11 @@ func init() { // TestHTTPWrite test http_client_handler Write return err func TestHTTPWrite(t *testing.T) { // 1. prepare mock data - conn := &MockNetpollConn{} + conn := &MockNetpollConn{ + WriterFunc: func() netpoll.Writer { + return netpoll.NewWriter(&bytes.Buffer{}) + }, + } rwTimeout := time.Second cfg := rpcinfo.NewRPCConfig() rpcinfo.AsMutableRPCConfig(cfg).SetReadWriteTimeout(rwTimeout) @@ -70,8 +75,8 @@ func TestHTTPWrite(t *testing.T) { // 2. test ctx, err := httpCilTransHdlr.Write(ctx, conn, msg) // check ctx/err not nil + test.Assert(t, err == nil, err) test.Assert(t, ctx != nil) - test.Assert(t, err != nil) } // TestHTTPRead test http_client_handler Read return err From 2ef26ea7688bd5aabb2fea7fedabdb9587b14957 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Tue, 23 Jul 2024 14:20:12 +0800 Subject: [PATCH 13/70] fix(generic): fix payload length check for http generic (#1442) --- go.mod | 6 +----- pkg/generic/thrift/http.go | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index b6db0d6bbf..c941c1eca1 100644 --- a/go.mod +++ b/go.mod @@ -34,20 +34,16 @@ require ( github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dlclark/regexp2 v1.10.0 // indirect + github.com/dlclark/regexp2 v1.11.0 // indirect github.com/fatih/structtag v1.2.0 // indirect github.com/golang/protobuf v1.5.2 // indirect - github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect github.com/iancoleman/strcase v0.2.0 // indirect - github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/oleiade/lane v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect - github.com/smartystreets/goconvey v1.6.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/pkg/generic/thrift/http.go b/pkg/generic/thrift/http.go index b14e674981..1311fdb7ec 100644 --- a/pkg/generic/thrift/http.go +++ b/pkg/generic/thrift/http.go @@ -133,7 +133,7 @@ func (r *ReadHTTPResponse) SetDynamicGo(convOpts *conv.Options) { // Read ... func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { // fallback logic - if !r.dynamicgoEnabled { + if !r.dynamicgoEnabled || dataLen == 0 { return r.originalRead(ctx, method, in) } tProt, ok := in.(*cthrift.BinaryProtocol) From 68599a0f21b8be2eb95e6b941570a8d6aee0a4a8 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Thu, 25 Jul 2024 20:20:22 +0800 Subject: [PATCH 14/70] chore(ci): disable cache for lint and staticchecks (#1451) --- .github/workflows/pr-check.yml | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index c156a606d0..5fa0fe214d 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -19,18 +19,15 @@ jobs: staticcheck: runs-on: [ self-hosted, X64 ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 - with: - go-version: 1.19 - - - uses: actions/cache@v3 + uses: actions/setup-go@v5 with: - path: ~/go/pkg/mod - key: reviewdog-${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - reviewdog-${{ runner.os }}-go- + go-version: stable + # For self-hosted, the cache path is shared across projects + # and it works well without the cache of github actions + # Enable it if we're going to use Github only + cache: false - uses: reviewdog/action-staticcheck@v1 with: @@ -47,14 +44,18 @@ jobs: lint: runs-on: [ self-hosted, X64 ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: - go-version: 1.19 + go-version: stable + # for self-hosted, the cache path is shared across projects + # and it works well without the cache of github actions + # Enable it if we're going to use Github only + cache: false - name: Golangci Lint # https://golangci-lint.run/ - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: version: latest From bbc66875a9f282a3a6ed2d7f638935a0e91ee2d0 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 26 Jul 2024 10:41:42 +0800 Subject: [PATCH 15/70] refactor(test): perf optimize and log loc correct (#1455) --- internal/test/assert.go | 51 ++++++++----- internal/test/assert_test.go | 105 ++++++++++++++++++++++++++ pkg/remote/connpool/long_pool_test.go | 11 ++- 3 files changed, 141 insertions(+), 26 deletions(-) create mode 100644 internal/test/assert_test.go diff --git a/internal/test/assert.go b/internal/test/assert.go index aca379ab1f..71f3fe6d52 100644 --- a/internal/test/assert.go +++ b/internal/test/assert.go @@ -27,10 +27,10 @@ type testingTB interface { // Assert asserts cond is true, otherwise fails the test. func Assert(t testingTB, cond bool, val ...interface{}) { - t.Helper() if !cond { + t.Helper() if len(val) > 0 { - val = append([]interface{}{"assertion failed:"}, val...) + val = append([]interface{}{"assertion failed: "}, val...) t.Fatal(val...) } else { t.Fatal("assertion failed") @@ -40,42 +40,53 @@ func Assert(t testingTB, cond bool, val ...interface{}) { // Assertf asserts cond is true, otherwise fails the test. func Assertf(t testingTB, cond bool, format string, val ...interface{}) { - t.Helper() if !cond { + t.Helper() t.Fatalf(format, val...) } } // DeepEqual asserts a and b are deep equal, otherwise fails the test. func DeepEqual(t testingTB, a, b interface{}) { - t.Helper() if !reflect.DeepEqual(a, b) { + t.Helper() t.Fatalf("assertion failed: %v != %v", a, b) } } // Panic asserts fn should panic and recover it, otherwise fails the test. func Panic(t testingTB, fn func()) { - t.Helper() - defer func() { - if err := recover(); err == nil { - t.Fatal("assertion failed: did not panic") - } + hasPanic := false + func() { + defer func() { + if err := recover(); err != nil { + hasPanic = true + } + }() + fn() }() - fn() + if !hasPanic { + t.Helper() + t.Fatal("assertion failed: did not panic") + } } // PanicAt asserts fn should panic and recover it, otherwise fails the test. The expect function can be provided to do further examination of the error. func PanicAt(t testingTB, fn func(), expect func(err interface{}) bool) { - t.Helper() - defer func() { - if err := recover(); err == nil { - t.Fatal("assertion failed: did not panic") - } else { - if expect != nil && !expect(err) { - t.Fatal("assertion failed: panic but not expected") - } - } + var err interface{} + func() { + defer func() { + err = recover() + }() + fn() }() - fn() + if err == nil { + t.Helper() + t.Fatal("assertion failed: did not panic") + return + } + if expect != nil && !expect(err) { + t.Helper() + t.Fatal("assertion failed: panic but not expected") + } } diff --git a/internal/test/assert_test.go b/internal/test/assert_test.go new file mode 100644 index 0000000000..f60b3f1664 --- /dev/null +++ b/internal/test/assert_test.go @@ -0,0 +1,105 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "fmt" + "testing" +) + +type mockTesting struct { + t *testing.T + + expect0 string + expect1 string + + helper bool +} + +func (m *mockTesting) Reset() { + m.expect0 = "" + m.expect1 = "" + m.helper = false +} + +func (m *mockTesting) ExpectFatal(args ...interface{}) { + m.expect0 = fmt.Sprint(args...) +} + +func (m *mockTesting) ExpectFatalf(format string, args ...interface{}) { + m.expect1 = fmt.Sprintf(format, args...) +} + +func (m *mockTesting) Fatal(args ...interface{}) { + t := m.t + t.Helper() + if !m.helper { + t.Fatal("need to call Helper before calling Fatal") + } + if s := fmt.Sprint(args...); s != m.expect0 { + t.Fatalf("got %q expect %q", s, m.expect0) + } +} + +func (m *mockTesting) Fatalf(format string, args ...interface{}) { + t := m.t + t.Helper() + if !m.helper { + t.Fatal("need to call Helper before calling Fatalf") + } + if s := fmt.Sprintf(format, args...); s != m.expect1 { + t.Fatalf("got %q expect %q", s, m.expect1) + } +} + +func (m *mockTesting) Helper() { m.helper = true } + +func TestAssert(t *testing.T) { + m := &mockTesting{t: t} + + m.Reset() + m.ExpectFatal("assertion failed") + Assert(m, false) + + m.Reset() + m.ExpectFatal("assertion failed: hello") + Assert(m, false, "hello") + + m.Reset() + m.ExpectFatalf("assert: %s", "hello") + Assertf(m, false, "assert: %s", "hello") + + m.Reset() + m.ExpectFatalf("assertion failed: 1 != 2") + DeepEqual(m, 1, 2) + + m.Reset() + m.ExpectFatal("") + Panic(m, func() { panic("hello") }) + + m.Reset() + m.ExpectFatal("assertion failed: did not panic") + Panic(m, func() {}) + + m.Reset() + m.ExpectFatal("assertion failed: did not panic") + PanicAt(m, func() {}, func(err interface{}) bool { return true }) + + m.Reset() + m.ExpectFatal("assertion failed: panic but not expected") + PanicAt(m, func() { panic("hello") }, func(err interface{}) bool { return false }) +} diff --git a/pkg/remote/connpool/long_pool_test.go b/pkg/remote/connpool/long_pool_test.go index e086ab7365..34472aa725 100644 --- a/pkg/remote/connpool/long_pool_test.go +++ b/pkg/remote/connpool/long_pool_test.go @@ -567,7 +567,7 @@ func TestLongConnPoolCloseOnIdleTimeout(t *testing.T) { p := newLongPoolForTest(0, 2, 5, idleTime) defer p.Close() - var closed bool + var closed uint32 // use atomic to fix data race issue d := mocksremote.NewMockDialer(ctrl) d.EXPECT().DialTimeout(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(network, address string, timeout time.Duration) (net.Conn, error) { na := utils.NewNetAddr(network, address) @@ -575,10 +575,9 @@ func TestLongConnPoolCloseOnIdleTimeout(t *testing.T) { conn.EXPECT().IsActive().Return(true).AnyTimes() conn.EXPECT().RemoteAddr().Return(na).AnyTimes() conn.EXPECT().Close().DoAndReturn(func() error { - if closed { + if !atomic.CompareAndSwapUint32(&closed, 0, 1) { return errors.New("connection already closed") } - closed = true return nil }).AnyTimes() return conn, nil @@ -589,18 +588,18 @@ func TestLongConnPoolCloseOnIdleTimeout(t *testing.T) { c, err := p.Get(context.TODO(), "tcp", addr, opt) test.Assert(t, err == nil) - test.Assert(t, !closed) + test.Assert(t, atomic.LoadUint32(&closed) == 0) err = p.Put(c) test.Assert(t, err == nil) - test.Assert(t, !closed) + test.Assert(t, atomic.LoadUint32(&closed) == 0) time.Sleep(idleTime * 3) c2, err := p.Get(context.TODO(), "tcp", addr, opt) test.Assert(t, err == nil) test.Assert(t, c != c2) - test.Assert(t, closed) // the first connection should be closed + test.Assert(t, atomic.LoadUint32(&closed) == 1) // the first connection should be closed } func TestLongConnPoolCloseOnClean(t *testing.T) { From fe79ac4448751d1ff45c13f4b2863c2ecf61f8d6 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 26 Jul 2024 11:49:47 +0800 Subject: [PATCH 16/70] chore(ci): speed up multiple ci processes 8min -> 1min (#1454) * rm unused codeconv, it didn't work as expected due to quota and user experience * don't use cache for self-hosted runners * cache for github hosted runners --- .codecov.yml | 4 --- .github/workflows/tests.yml | 29 +++++++++----------- pkg/loadbalance/weighted_balancer_test.go | 4 +-- pkg/loadbalance/weighted_round_robin_test.go | 3 +- pkg/mem/span_test.go | 4 +-- 5 files changed, 19 insertions(+), 25 deletions(-) delete mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml deleted file mode 100644 index 3748260a10..0000000000 --- a/.codecov.yml +++ /dev/null @@ -1,4 +0,0 @@ -ignore: - - "images/.*" - - "tool" - - "internal/mocks" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 16f4b6c180..ae4eaf6bdd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -6,13 +6,11 @@ jobs: unit-scenario-test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: '1.20' - - name: Unit Test - run: go test -gcflags=-l -race -covermode=atomic -coverprofile=coverage.txt ./... - name: Scenario Tests run: | cd .. @@ -21,19 +19,17 @@ jobs: cd kitex-tests ./run.sh ${{github.workspace}} cd ${{github.workspace}} - - name: Codecov - run: bash <(curl -s https://codecov.io/bash) benchmark-test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: '1.18' - name: Benchmark - run: go test -gcflags='all=-N -l' -bench=. -benchmem -run=none ./... + run: go test -bench=. -benchmem -run=none ./... compatibility-test: strategy: @@ -42,20 +38,21 @@ jobs: os: [ X64, ARM64 ] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} + cache: false # don't use cache for self-hosted runners - name: Unit Test - run: go test -gcflags=-l -race -covermode=atomic ./... + run: go test -race -covermode=atomic ./... codegen-test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: '1.19' - name: Prepare @@ -81,9 +78,9 @@ jobs: windows-test: runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: "1.20" - name: Windows compatibility test diff --git a/pkg/loadbalance/weighted_balancer_test.go b/pkg/loadbalance/weighted_balancer_test.go index 57f74b87e6..6167a09c9f 100644 --- a/pkg/loadbalance/weighted_balancer_test.go +++ b/pkg/loadbalance/weighted_balancer_test.go @@ -139,7 +139,7 @@ func TestWeightedPicker_Next(t *testing.T) { weight := ins.Weight() weightSum += weight } - n := 10000000 + n := 1000000 pickedStat := map[int]int{} for i := 0; i < n; i++ { picker := balancer.GetPicker(discovery.Result{ @@ -160,7 +160,7 @@ func TestWeightedPicker_Next(t *testing.T) { expect := float64(weight) / float64(weightSum) * float64(n) actual := float64(pickedStat[weight]) delta := math.Abs(expect - actual) - test.Assertf(t, delta/expect < 0.01, "delta(%f)/expect(%f) = %f", delta, expect, delta/expect) + test.Assertf(t, delta/expect < 0.05, "delta(%f)/expect(%f) = %f", delta, expect, delta/expect) } // weightSum = 0 diff --git a/pkg/loadbalance/weighted_round_robin_test.go b/pkg/loadbalance/weighted_round_robin_test.go index 6a63eb8e04..2f43dfb3b0 100644 --- a/pkg/loadbalance/weighted_round_robin_test.go +++ b/pkg/loadbalance/weighted_round_robin_test.go @@ -18,6 +18,7 @@ package loadbalance import ( "context" + "runtime" "sync" "testing" @@ -93,7 +94,7 @@ func TestWeightedRoundRobinPickerLargeInstances(t *testing.T) { insList = append(insList, discovery.NewInstance("tcp", "nbalance", 100, nil)) picker := newWeightedRoundRobinPicker(insList) - concurrency := wrrVNodesBatchSize * 2 + concurrency := 2 + runtime.GOMAXPROCS(0)*2 round := len(insList) var wg sync.WaitGroup for c := 0; c < concurrency; c++ { diff --git a/pkg/mem/span_test.go b/pkg/mem/span_test.go index 1fb4f3d88f..d24d581ba8 100644 --- a/pkg/mem/span_test.go +++ b/pkg/mem/span_test.go @@ -60,7 +60,7 @@ func TestSpan(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - for i := 0; i < 1024; i++ { + for i := 0; i < 128; i++ { buf := []byte("123") test.DeepEqual(t, bc.Copy(buf), buf) @@ -92,7 +92,7 @@ func TestSpanCache(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - for i := 0; i < 1024; i++ { + for i := 0; i < 128; i++ { buf := []byte("123") test.DeepEqual(t, bc.Copy(buf), buf) From 095cf94e8009d97d8601893475a26effb8844e46 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 26 Jul 2024 14:06:51 +0800 Subject: [PATCH 17/70] chore: pick and fix conflict commits from develop branch (#1457) Co-authored-by: Jayant --- .codecov.yml | 4 -- .github/workflows/pr-check.yml | 29 ++++++++------- .github/workflows/tests.yml | 31 +++++++--------- pkg/loadbalance/weighted_balancer_test.go | 4 +- pkg/loadbalance/weighted_round_robin_test.go | 3 +- pkg/mem/span_test.go | 4 +- pkg/protocol/bthrift/binary.go | 30 +++++---------- pkg/protocol/bthrift/binary_test.go | 39 -------------------- pkg/protocol/bthrift/interface.go | 5 --- 9 files changed, 45 insertions(+), 104 deletions(-) delete mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml deleted file mode 100644 index 3748260a10..0000000000 --- a/.codecov.yml +++ /dev/null @@ -1,4 +0,0 @@ -ignore: - - "images/.*" - - "tool" - - "internal/mocks" diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 6258952ab4..eca1ceb026 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -19,18 +19,15 @@ jobs: staticcheck: runs-on: [ self-hosted, X64 ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 - with: - go-version: "1.22" - - - uses: actions/cache@v3 + uses: actions/setup-go@v5 with: - path: ~/go/pkg/mod - key: reviewdog-${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - reviewdog-${{ runner.os }}-go- + go-version: stable + # For self-hosted, the cache path is shared across projects + # and it works well without the cache of github actions + # Enable it if we're going to use Github only + cache: false - uses: reviewdog/action-staticcheck@v1 with: @@ -47,14 +44,18 @@ jobs: lint: runs-on: [ self-hosted, X64 ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: - go-version: "1.22" + go-version: stable + # for self-hosted, the cache path is shared across projects + # and it works well without the cache of github actions + # Enable it if we're going to use Github only + cache: false - name: Golangci Lint # https://golangci-lint.run/ - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: version: latest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3521ae8fe8..fbe78933c2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -6,13 +6,11 @@ jobs: unit-scenario-test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: - go-version: "1.20" - - name: Unit Test - run: go test -gcflags=-l -race -covermode=atomic -coverprofile=coverage.txt ./... + go-version: '1.20' - name: Scenario Tests run: | cd .. @@ -21,19 +19,17 @@ jobs: cd kitex-tests ./run.sh ${{github.workspace}} cd ${{github.workspace}} - - name: Codecov - run: bash <(curl -s https://codecov.io/bash) benchmark-test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: "1.22" - name: Benchmark - run: go test -gcflags='all=-N -l' -bench=. -benchmem -run=none ./... + run: go test -bench=. -benchmem -run=none ./... compatibility-test: strategy: @@ -42,20 +38,21 @@ jobs: os: [ X64, ARM64 ] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} + cache: false # don't use cache for self-hosted runners - name: Unit Test - run: go test -gcflags=-l -race -covermode=atomic ./... + run: go test -race -covermode=atomic ./... codegen-test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: "1.22" - name: Prepare @@ -81,9 +78,9 @@ jobs: windows-test: runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: "1.22" - name: Windows compatibility test diff --git a/pkg/loadbalance/weighted_balancer_test.go b/pkg/loadbalance/weighted_balancer_test.go index 57f74b87e6..6167a09c9f 100644 --- a/pkg/loadbalance/weighted_balancer_test.go +++ b/pkg/loadbalance/weighted_balancer_test.go @@ -139,7 +139,7 @@ func TestWeightedPicker_Next(t *testing.T) { weight := ins.Weight() weightSum += weight } - n := 10000000 + n := 1000000 pickedStat := map[int]int{} for i := 0; i < n; i++ { picker := balancer.GetPicker(discovery.Result{ @@ -160,7 +160,7 @@ func TestWeightedPicker_Next(t *testing.T) { expect := float64(weight) / float64(weightSum) * float64(n) actual := float64(pickedStat[weight]) delta := math.Abs(expect - actual) - test.Assertf(t, delta/expect < 0.01, "delta(%f)/expect(%f) = %f", delta, expect, delta/expect) + test.Assertf(t, delta/expect < 0.05, "delta(%f)/expect(%f) = %f", delta, expect, delta/expect) } // weightSum = 0 diff --git a/pkg/loadbalance/weighted_round_robin_test.go b/pkg/loadbalance/weighted_round_robin_test.go index 6a63eb8e04..2f43dfb3b0 100644 --- a/pkg/loadbalance/weighted_round_robin_test.go +++ b/pkg/loadbalance/weighted_round_robin_test.go @@ -18,6 +18,7 @@ package loadbalance import ( "context" + "runtime" "sync" "testing" @@ -93,7 +94,7 @@ func TestWeightedRoundRobinPickerLargeInstances(t *testing.T) { insList = append(insList, discovery.NewInstance("tcp", "nbalance", 100, nil)) picker := newWeightedRoundRobinPicker(insList) - concurrency := wrrVNodesBatchSize * 2 + concurrency := 2 + runtime.GOMAXPROCS(0)*2 round := len(insList) var wg sync.WaitGroup for c := 0; c < concurrency; c++ { diff --git a/pkg/mem/span_test.go b/pkg/mem/span_test.go index 1fb4f3d88f..d24d581ba8 100644 --- a/pkg/mem/span_test.go +++ b/pkg/mem/span_test.go @@ -60,7 +60,7 @@ func TestSpan(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - for i := 0; i < 1024; i++ { + for i := 0; i < 128; i++ { buf := []byte("123") test.DeepEqual(t, bc.Copy(buf), buf) @@ -92,7 +92,7 @@ func TestSpanCache(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - for i := 0; i < 1024; i++ { + for i := 0; i < 128; i++ { buf := []byte("123") test.DeepEqual(t, bc.Copy(buf), buf) diff --git a/pkg/protocol/bthrift/binary.go b/pkg/protocol/bthrift/binary.go index e10dc2ecf8..8db034e1dd 100644 --- a/pkg/protocol/bthrift/binary.go +++ b/pkg/protocol/bthrift/binary.go @@ -30,28 +30,19 @@ import ( var ( // Binary protocol for bthrift. - Binary binaryProtocol - _ BTProtocol = binaryProtocol{} + Binary binaryProtocol + _ BTProtocol = binaryProtocol{} + spanCache = mem.NewSpanCache(1024 * 1024) + spanCacheEnable bool = false ) -var allocator Allocator - const binaryInplaceThreshold = 4096 // 4k type binaryProtocol struct{} // SetSpanCache enable/disable binary protocol bytes/string allocator func SetSpanCache(enable bool) { - if enable { - SetAllocator(mem.NewSpanCache(1024 * 1024)) - } else { - SetAllocator(nil) - } -} - -// SetAllocator set binary protocol bytes/string allocator. -func SetAllocator(alloc Allocator) { - allocator = alloc + spanCacheEnable = enable } func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID thrift.TMessageType, seqid int32) int { @@ -479,9 +470,9 @@ func (binaryProtocol) ReadString(buf []byte) (value string, length int, err erro if size < 0 || int(size) > len(buf) { return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadString] the string size greater than buf length") } - alloc := allocator - if alloc != nil { - value = sliceByteToString(alloc.Copy(buf[length : length+int(size)])) + if spanCacheEnable { + data := spanCache.Copy(buf[length : length+int(size)]) + value = sliceByteToString(data) } else { value = string(buf[length : length+int(size)]) } @@ -500,9 +491,8 @@ func (binaryProtocol) ReadBinary(buf []byte) (value []byte, length int, err erro if size < 0 || size > len(buf) { return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadBinary] the binary size greater than buf length") } - alloc := allocator - if alloc != nil { - value = alloc.Copy(buf[length : length+size]) + if spanCacheEnable { + value = spanCache.Copy(buf[length : length+size]) } else { value = make([]byte, size) copy(value, buf[length:length+size]) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index 8e395bb5cd..a0754bcd55 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -291,24 +291,6 @@ func TestWriteAndReadString(t *testing.T) { test.Assert(t, v == "kitex") } -// TestWriteAndReadStringWithSpanCache test binary WriteString and ReadString with spanCache allocator -func TestWriteAndReadStringWithSpanCache(t *testing.T) { - buf := make([]byte, 128) - exceptWs := "000000056b69746578" - exceptSize := 9 - wn := Binary.WriteString(buf, "kitex") - ws := fmt.Sprintf("%x", buf[:wn]) - test.Assert(t, wn == exceptSize, wn, exceptSize) - test.Assert(t, ws == exceptWs, ws, exceptWs) - - SetSpanCache(true) - v, length, err := Binary.ReadString(buf) - test.Assert(t, nil == err) - test.Assert(t, exceptSize == length) - test.Assert(t, v == "kitex") - SetSpanCache(false) -} - // TestWriteAndReadBinary test binary WriteBinary and ReadBinary func TestWriteAndReadBinary(t *testing.T) { buf := make([]byte, 128) @@ -328,27 +310,6 @@ func TestWriteAndReadBinary(t *testing.T) { } } -// TestWriteAndReadBinaryWithSpanCache test binary WriteBinary and ReadBinary with spanCache allocator -func TestWriteAndReadBinaryWithSpanCache(t *testing.T) { - buf := make([]byte, 128) - exceptWs := "000000056b69746578" - exceptSize := 9 - val := []byte("kitex") - wn := Binary.WriteBinary(buf, val) - ws := fmt.Sprintf("%x", buf[:wn]) - test.Assert(t, wn == exceptSize, wn, exceptSize) - test.Assert(t, ws == exceptWs, ws, exceptWs) - - SetSpanCache(true) - v, length, err := Binary.ReadBinary(buf) - test.Assert(t, nil == err) - test.Assert(t, exceptSize == length) - for i := 0; i < len(v); i++ { - test.Assert(t, val[i] == v[i]) - } - SetSpanCache(false) -} - // TestWriteStringNocopy test binary WriteStringNocopy with small content func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index 62bbb23dc6..06965813bc 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -103,8 +103,3 @@ type BTProtocol interface { ReadBinary(buf []byte) (value []byte, length int, err error) Skip(buf []byte, fieldType thrift.TType) (length int, err error) } - -type Allocator interface { - Make(n int) []byte - Copy(buf []byte) (p []byte) -} From dd79a788a8d93de6d9d39bd158680f98fcb2be0d Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 26 Jul 2024 15:21:51 +0800 Subject: [PATCH 18/70] chore(ci): optimized bench tests. it takes <1m now (#1461) --- .github/workflows/tests.yml | 4 ++- pkg/loadbalance/consist.go | 10 ------- pkg/loadbalance/consist_test.go | 48 ++++++++++++++++----------------- 3 files changed, 26 insertions(+), 36 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fbe78933c2..9f2d820fbb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,7 +29,9 @@ jobs: with: go-version: "1.22" - name: Benchmark - run: go test -bench=. -benchmem -run=none ./... + # we only use this CI to verify bench code works + # setting benchtime=100ms is saving our time... + run: go test -bench=. -benchmem -run=none ./... -benchtime=100ms compatibility-test: strategy: diff --git a/pkg/loadbalance/consist.go b/pkg/loadbalance/consist.go index df56102e93..52c8965a0c 100644 --- a/pkg/loadbalance/consist.go +++ b/pkg/loadbalance/consist.go @@ -29,16 +29,6 @@ import ( "github.com/cloudwego/kitex/pkg/utils" ) -/* - Benchmark results with different instance numbers when weight = 10 and virtual factor = 100: - BenchmarkNewConsistPicker_NoCache/10ins-16 6565 160670 ns/op 164750 B/op 5 allocs/op - BenchmarkNewConsistPicker_NoCache/100ins-16 571 1914666 ns/op 1611803 B/op 6 allocs/op - BenchmarkNewConsistPicker_NoCache/1000ins-16 45 23485916 ns/op 16067720 B/op 10 allocs/op - BenchmarkNewConsistPicker_NoCache/10000ins-16 4 251160920 ns/op 160405632 B/op 41 allocs/op - - When there's 10000 instances which weight = 10 and virtual factor = 100, the time need to build is 251 ms. -*/ - /* type hints for sync.Map: consistBalancer -> sync.Map[entry.CacheKey]*consistInfo diff --git a/pkg/loadbalance/consist_test.go b/pkg/loadbalance/consist_test.go index 816583babe..2d046d34ba 100644 --- a/pkg/loadbalance/consist_test.go +++ b/pkg/loadbalance/consist_test.go @@ -19,11 +19,10 @@ package loadbalance import ( "context" "fmt" - "math/rand" "strconv" - "strings" "testing" "time" + "unsafe" "github.com/bytedance/gopkg/lang/fastrand" @@ -53,15 +52,23 @@ func getRandomKey(ctx context.Context, request interface{}) string { return key } -func getRandomString(length int) string { - var resBuilder strings.Builder - resBuilder.Grow(length) - corpus := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - rand.Seed(time.Now().UnixNano() + int64(100)) - for i := 0; i < length; i++ { - resBuilder.WriteByte(corpus[rand.Intn(len(corpus))]) +func randRead(b []byte) { + p := unsafe.Pointer(&b[0]) + seed := uint64(time.Now().UnixNano()) + i := 0 + for ; i <= len(b)-8; i += 8 { + if i != 0 { + p = unsafe.Add(p, 8) + } + *((*uint64)(p)) = seed + seed <<= 13 // xorshift64 + seed >>= 7 + seed <<= 17 + } + for ; i < len(b); i++ { + b[i] = byte(seed) + seed >>= 8 } - return resBuilder.String() } func newTestConsistentHashOption() ConsistentHashOption { @@ -324,7 +331,7 @@ func BenchmarkNewConsistPicker_NoCache(bb *testing.B) { balancer := NewConsistBalancer(newTestConsistentHashOption()) ctx := context.Background() - for i := 0; i < 4; i++ { + for i := 0; i < 3; i++ { // when n=10000 it costs ~3 to run... fix me. bb.Run(fmt.Sprintf("%dins", n), func(b *testing.B) { inss := makeNInstances(n, 10) e := discovery.Result{ @@ -379,15 +386,6 @@ func BenchmarkNewConsistPicker(bb *testing.B) { } } -// BenchmarkConsistPicker_RandomDistributionKey -// BenchmarkConsistPicker_RandomDistributionKey/10ins -// BenchmarkConsistPicker_RandomDistributionKey/10ins-12 2417481 508.9 ns/op 48 B/op 1 allocs/op -// BenchmarkConsistPicker_RandomDistributionKey/100ins -// BenchmarkConsistPicker_RandomDistributionKey/100ins-12 2140726 534.6 ns/op 48 B/op 1 allocs/op -// BenchmarkConsistPicker_RandomDistributionKey/1000ins -// BenchmarkConsistPicker_RandomDistributionKey/1000ins-12 2848216. 407.7 ns/op 48 B/op 1 allocs/op -// BenchmarkConsistPicker_RandomDistributionKey/10000ins -// BenchmarkConsistPicker_RandomDistributionKey/10000ins-12 2701766 492.7 ns/op 48 B/op 1 allocs/op func BenchmarkConsistPicker_RandomDistributionKey(bb *testing.B) { n := 10 balancer := NewConsistBalancer(NewConsistentHashOption(getRandomKey)) @@ -400,17 +398,17 @@ func BenchmarkConsistPicker_RandomDistributionKey(bb *testing.B) { CacheKey: "test", Instances: inss, } + buf := make([]byte, 30) + randRead(buf) + s := *(*string)(unsafe.Pointer(&buf)) picker := balancer.GetPicker(e) - ctx := context.WithValue(context.Background(), keyCtxKey, getRandomString(30)) + ctx := context.WithValue(context.Background(), keyCtxKey, s) picker.Next(ctx, nil) picker.(internal.Reusable).Recycle() b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - b.Logf("round %d", i) - b.StopTimer() - ctx = context.WithValue(context.Background(), keyCtxKey, getRandomString(30)) - b.StartTimer() + randRead(buf) // it changes the data of `s` in ctx picker := balancer.GetPicker(e) picker.Next(ctx, nil) if r, ok := picker.(internal.Reusable); ok { From 84d18231663fb0c812249831a0ae7e45eaad3a98 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 26 Jul 2024 15:43:06 +0800 Subject: [PATCH 19/70] chore(test): fix xorshift64 in consist_test.go (#1462) --- pkg/loadbalance/consist_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/loadbalance/consist_test.go b/pkg/loadbalance/consist_test.go index 2d046d34ba..7fdc92e54d 100644 --- a/pkg/loadbalance/consist_test.go +++ b/pkg/loadbalance/consist_test.go @@ -61,9 +61,9 @@ func randRead(b []byte) { p = unsafe.Add(p, 8) } *((*uint64)(p)) = seed - seed <<= 13 // xorshift64 - seed >>= 7 - seed <<= 17 + seed ^= seed << 13 // xorshift64 + seed ^= seed >> 7 + seed ^= seed << 17 } for ; i < len(b); i++ { b[i] = byte(seed) From 2596bd88fa12497dafecfdc8a9518aacde7b41f5 Mon Sep 17 00:00:00 2001 From: Joway Date: Fri, 26 Jul 2024 17:19:21 +0800 Subject: [PATCH 20/70] chore: fix grpc keepalive test by start server responsiblly (#1463) --- .licenserc.yaml | 2 + .../trans/nphttp2/grpc/keepalive_test.go | 12 +++-- .../trans/nphttp2/grpc/transport_test.go | 50 ++++++++----------- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/.licenserc.yaml b/.licenserc.yaml index f04bf297f9..b4796e6be7 100644 --- a/.licenserc.yaml +++ b/.licenserc.yaml @@ -23,7 +23,9 @@ header: - pkg/remote/trans/nphttp2/grpc/http2_server.go - pkg/remote/trans/nphttp2/grpc/http_util.go - pkg/remote/trans/nphttp2/grpc/keepalive.go + - pkg/remote/trans/nphttp2/grpc/keepalive_test.go - pkg/remote/trans/nphttp2/grpc/transport.go + - pkg/remote/trans/nphttp2/grpc/transport_test.go - pkg/remote/trans/nphttp2/metadata/metadata.go - pkg/remote/trans/nphttp2/status/status.go - pkg/remote/codec/protobuf/error.pb.go diff --git a/pkg/remote/trans/nphttp2/grpc/keepalive_test.go b/pkg/remote/trans/nphttp2/grpc/keepalive_test.go index e827e7e569..965cc0c462 100644 --- a/pkg/remote/trans/nphttp2/grpc/keepalive_test.go +++ b/pkg/remote/trans/nphttp2/grpc/keepalive_test.go @@ -248,11 +248,13 @@ func TestKeepaliveServerWithResponsiveClient(t *testing.T) { // logic is running even without any active streams. func TestKeepaliveClientClosesUnresponsiveServer(t *testing.T) { connCh := make(chan net.Conn, 1) + exitCh := make(chan struct{}) + defer func() { close(exitCh) }() client := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: ClientKeepalive{ Time: 250 * time.Millisecond, Timeout: 250 * time.Millisecond, PermitWithoutStream: true, - }}, connCh) + }}, connCh, exitCh) if client == nil { t.Fatalf("setUpWithNoPingServer failed, return nil client") } @@ -282,10 +284,12 @@ func TestKeepaliveClientClosesUnresponsiveServer(t *testing.T) { // active streams, and therefore the transport stays open. func TestKeepaliveClientOpenWithUnresponsiveServer(t *testing.T) { connCh := make(chan net.Conn, 1) + exitCh := make(chan struct{}) + defer func() { close(exitCh) }() client := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: ClientKeepalive{ Time: 250 * time.Millisecond, Timeout: 250 * time.Millisecond, - }}, connCh) + }}, connCh, exitCh) if client == nil { t.Fatalf("setUpWithNoPingServer failed, return nil client") } @@ -313,10 +317,12 @@ func TestKeepaliveClientOpenWithUnresponsiveServer(t *testing.T) { // transport even when there is an active stream. func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { connCh := make(chan net.Conn, 1) + exitCh := make(chan struct{}) + defer func() { close(exitCh) }() client := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: ClientKeepalive{ Time: 250 * time.Millisecond, Timeout: 250 * time.Millisecond, - }}, connCh) + }}, connCh, exitCh) if client == nil { t.Fatalf("setUpWithNoPingServer failed, return nil client") } diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 86570067af..636598bc76 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -465,55 +465,47 @@ func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hTy return server, ct.(*http2Client) } -func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) *http2Client { +func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn, exitCh chan struct{}) *http2Client { lis, err := net.Listen("tcp", "localhost:0") if err != nil { - fmt.Printf("Failed to listen: %v", err) - return nil + t.Fatalf("Failed to listen: %v", err) + } + // Launch a non responsive server and save the conn. + eventLoop, err := netpoll.NewEventLoop( + func(ctx context.Context, connection netpoll.Connection) error { return nil }, + netpoll.WithOnConnect(func(ctx context.Context, connection netpoll.Connection) context.Context { + connCh <- connection.(netpoll.Conn) + t.Logf("event loop on connect: %s", connection.RemoteAddr().String()) + return ctx + }), + ) + if err != nil { + t.Fatalf("Create netpoll event-loop failed: %v", err) } go func() { - exitCh := make(chan struct{}, 1) - // Launch a non responsive server. - eventLoop, err := netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error { - defer lis.Close() - connCh <- connection.(net.Conn) - exitCh <- struct{}{} - return nil - }) - if err != nil { - fmt.Printf("Create netpoll event-loop failed") - } - go func() { err = eventLoop.Serve(lis) if err != nil { - fmt.Printf("netpoll server exit failed, err=%v", err) + t.Errorf("netpoll server exit failed, err=%v", err) + return } }() - - select { - case <-exitCh: - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if err := eventLoop.Shutdown(ctx); err != nil { - fmt.Printf("netpoll server exit failed, err=%v", err) - } - default: - } + <-exitCh + // shutdown will called lis.Close() + _ = eventLoop.Shutdown(context.Background()) }() conn, err := netpoll.NewDialer().DialTimeout("tcp", lis.Addr().String(), time.Second) if err != nil { - fmt.Printf("Failed to dial: %v", err) + t.Fatalf("Failed to dial: %v", err) } tr, err := NewClientTransport(context.Background(), conn.(netpoll.Connection), copts, "mockDestService", func(GoAwayReason) {}, func() {}) if err != nil { // Server clean-up. - lis.Close() if conn, ok := <-connCh; ok { conn.Close() } - fmt.Printf("Failed to dial: %v", err) + t.Fatalf("Failed to dial: %v", err) } return tr.(*http2Client) } From 03c4c8ce6ed439f5b294b7cf9e7970214af04278 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Mon, 29 Jul 2024 14:31:01 +0800 Subject: [PATCH 21/70] refactor: deprecate bthrift, use cloudwego/gopkg (#1441) * tool: generates code for using cloudwego/gopkg * internal/mocks: updated thrift using the latest tool * pkg/utils/fastthrift: moved to cloudwego/gopkg * pkg/remote/codec/thrift: uses skipdecoder of cloudwego/gopkg * pkg/remote/codec/thrift: add fastcodec as a fallback, and always use fastcodec for Ex * pkg/generic: uses cloudwego/gopkg for pkg/protocol/bthrift: * type BinaryWriter = gopkgthrift.NocopyWriter * Removed ThriftFastCodec, moved to cloudwego/gopkg before releasing * Removed bthrift/exception.go, moved to cloudwego/gopkg before releasing * Removed bthrift/test, only for unknownfields testing * bthrift/unknown: moved to cloudwego/gopkg --- client/middlewares_test.go | 9 +- go.mod | 5 +- go.sum | 11 +- internal/mocks/thrift/k-test.go | 484 +- internal/mocks/thrift/utils.go | 27 +- pkg/generic/binary_test/generic_init.go | 28 +- pkg/generic/binary_test/generic_test.go | 32 +- pkg/generic/binarythrift_codec_test.go | 7 +- pkg/generic/thrift/base.go | 3 +- pkg/generic/thrift/http.go | 21 +- pkg/protocol/bthrift/exception.go | 231 - pkg/protocol/bthrift/exception_test.go | 102 - pkg/protocol/bthrift/interface.go | 14 +- pkg/protocol/bthrift/test/gen.sh | 4 - .../bthrift/test/kitex_gen/test/k-consts.go | 4 - .../bthrift/test/kitex_gen/test/k-test.go | 3955 ----------------- .../bthrift/test/kitex_gen/test/test.go | 1619 ------- pkg/protocol/bthrift/test/test.thrift | 93 - pkg/protocol/bthrift/test/unknown_test.go | 249 -- pkg/protocol/bthrift/unknown.go | 392 +- pkg/protocol/bthrift/unknown_test.go | 31 + pkg/remote/codec/thrift/skip_decoder.go | 167 - pkg/remote/codec/thrift/skip_decoder_test.go | 106 - pkg/remote/codec/thrift/thrift.go | 113 +- pkg/remote/codec/thrift/thrift_data.go | 81 +- pkg/remote/codec/thrift/thrift_data_test.go | 26 +- pkg/remote/codec/thrift/thrift_frugal.go | 32 +- pkg/remote/codec/thrift/thrift_frugal_test.go | 15 +- pkg/remote/codec/thrift/thrift_others.go | 10 - pkg/remote/codec/thrift/thrift_test.go | 111 +- pkg/remote/trans/netpollmux/control_frame.go | 28 +- pkg/utils/fastthrift/fastthrift.go | 63 +- pkg/utils/fastthrift/fastthrift_test.go | 49 - pkg/utils/thrift.go | 42 +- tool/internal_pkg/pluginmode/thriftgo/ast.go | 35 +- .../pluginmode/thriftgo/file_tpl.go | 6 +- .../pluginmode/thriftgo/struct_tpl.go | 178 +- 37 files changed, 606 insertions(+), 7777 deletions(-) delete mode 100644 pkg/protocol/bthrift/exception.go delete mode 100644 pkg/protocol/bthrift/exception_test.go delete mode 100755 pkg/protocol/bthrift/test/gen.sh delete mode 100644 pkg/protocol/bthrift/test/kitex_gen/test/k-consts.go delete mode 100644 pkg/protocol/bthrift/test/kitex_gen/test/k-test.go delete mode 100644 pkg/protocol/bthrift/test/kitex_gen/test/test.go delete mode 100644 pkg/protocol/bthrift/test/test.thrift delete mode 100644 pkg/protocol/bthrift/test/unknown_test.go create mode 100644 pkg/protocol/bthrift/unknown_test.go delete mode 100644 pkg/remote/codec/thrift/skip_decoder.go delete mode 100644 pkg/remote/codec/thrift/skip_decoder_test.go diff --git a/client/middlewares_test.go b/client/middlewares_test.go index 75b096be00..ae0caba38a 100644 --- a/client/middlewares_test.go +++ b/client/middlewares_test.go @@ -24,6 +24,8 @@ import ( "github.com/golang/mock/gomock" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/internal/mocks" mocksdiscovery "github.com/cloudwego/kitex/internal/mocks/discovery" mocksproxy "github.com/cloudwego/kitex/internal/mocks/proxy" @@ -32,7 +34,6 @@ import ( "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/event" "github.com/cloudwego/kitex/pkg/kerrors" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" "github.com/cloudwego/kitex/pkg/proxy" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" @@ -140,14 +141,14 @@ func TestDefaultErrorHandler(t *testing.T) { reqCtx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) // Test TApplicationException - err := DefaultClientErrorHandler(context.Background(), bthrift.NewApplicationException(100, "mock")) + err := DefaultClientErrorHandler(context.Background(), thrift.NewApplicationException(100, "mock")) test.Assert(t, err.Error() == "remote or network error[remote]: mock", err.Error()) - var te *bthrift.ApplicationException + var te *thrift.ApplicationException ok := errors.As(err, &te) test.Assert(t, ok) test.Assert(t, te.TypeID() == 100) // Test TApplicationException with remote addr - err = ClientErrorHandlerWithAddr(reqCtx, bthrift.NewApplicationException(100, "mock")) + err = ClientErrorHandlerWithAddr(reqCtx, thrift.NewApplicationException(100, "mock")) test.Assert(t, err.Error() == "remote or network error[remote-"+tcpAddrStr+"]: mock", err.Error()) ok = errors.As(err, &te) test.Assert(t, ok) diff --git a/go.mod b/go.mod index c5248a544e..438de7dae9 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,13 @@ go 1.17 require ( github.com/apache/thrift v0.13.0 - github.com/bytedance/gopkg v0.0.0-20240514070511-01b2cbcf35e1 + github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8 github.com/bytedance/sonic v1.11.8 github.com/cloudwego/configmanager v0.2.2 github.com/cloudwego/dynamicgo v0.2.9 github.com/cloudwego/fastpb v0.0.4 github.com/cloudwego/frugal v0.1.15 + github.com/cloudwego/gopkg v0.0.0-20240722090221-969ae87c75ac github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 @@ -18,7 +19,7 @@ require ( github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 github.com/jhump/protoreflect v1.8.2 github.com/json-iterator/go v1.1.12 - github.com/stretchr/testify v1.8.2 + github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.9.3 golang.org/x/net v0.17.0 golang.org/x/sync v0.1.0 diff --git a/go.sum b/go.sum index a276e43f67..3c40301dd4 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,8 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= -github.com/bytedance/gopkg v0.0.0-20240514070511-01b2cbcf35e1 h1:rT7Mm6uUpHeZQzfs2v0Mlj0SL02CzyVi+EB7VYPM/z4= -github.com/bytedance/gopkg v0.0.0-20240514070511-01b2cbcf35e1/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8 h1:rDwLxYTMoKHaw4cS0bQhaTZnkXp5e6ediCggGcRD/CA= +github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.11.8 h1:Zw/j1KfiS+OYTi9lyB3bb0CFxPJVkM17k1wyDG32LRA= github.com/bytedance/sonic v1.11.8/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= @@ -36,6 +36,8 @@ github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4m github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= github.com/cloudwego/frugal v0.1.15 h1:LC55UJKhQPMFVjDPbE+LJcF7etZjSx6uokG1tk0wPK0= github.com/cloudwego/frugal v0.1.15/go.mod h1:26kU1r18vA8vRg12c66XPDlfv1GQHDbE1RpusipXfcI= +github.com/cloudwego/gopkg v0.0.0-20240722090221-969ae87c75ac h1:B7iK0zQ34wJkmNixXDHMHB+WrZJYadTAJSJkM21RZ6U= +github.com/cloudwego/gopkg v0.0.0-20240722090221-969ae87c75ac/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/iasm v0.0.9/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= @@ -149,6 +151,7 @@ github.com/ruudk/golang-pdf417 v0.0.0-20201230142125-a7e3863a1245/go.mod h1:pQAZ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= @@ -156,8 +159,10 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.9.3 h1:hqzS9wAHMO+KVBBkLxYdkEeeFHuqr95GfClRLKlgK0E= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= diff --git a/internal/mocks/thrift/k-test.go b/internal/mocks/thrift/k-test.go index f21df2427a..b1b37ae2b9 100644 --- a/internal/mocks/thrift/k-test.go +++ b/internal/mocks/thrift/k-test.go @@ -24,8 +24,7 @@ import ( "reflect" "strings" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift" ) // unused protection @@ -34,8 +33,7 @@ var ( _ = (*bytes.Buffer)(nil) _ = (*strings.Builder)(nil) _ = reflect.Type(nil) - _ = thrift.TProtocol(nil) - _ = bthrift.BinaryWriter(nil) + _ = thrift.STOP ) func (p *MockReq) FastRead(buf []byte) (int, error) { @@ -44,14 +42,8 @@ func (p *MockReq) FastRead(buf []byte) (int, error) { var l int var fieldTypeId thrift.TType var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - for { - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -68,7 +60,7 @@ func (p *MockReq) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError @@ -82,7 +74,7 @@ func (p *MockReq) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError @@ -96,58 +88,39 @@ func (p *MockReq) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError } return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockReq[fieldId]), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockReq[fieldId]), err) SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) } func (p *MockReq) FastReadField1(buf []byte) (int, error) { offset := 0 var _field string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - _field = v - } p.Msg = _field return offset, nil @@ -156,7 +129,7 @@ func (p *MockReq) FastReadField1(buf []byte) (int, error) { func (p *MockReq) FastReadField2(buf []byte) (int, error) { offset := 0 - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) offset += l if err != nil { return offset, err @@ -164,32 +137,23 @@ func (p *MockReq) FastReadField2(buf []byte) (int, error) { _field := make(map[string]string, size) for i := 0; i < size; i++ { var _key string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - _key = v - } var _val string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - _val = v - } _field[_key] = _val } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } p.StrMap = _field return offset, nil } @@ -197,7 +161,7 @@ func (p *MockReq) FastReadField2(buf []byte) (int, error) { func (p *MockReq) FastReadField3(buf []byte) (int, error) { offset := 0 - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) offset += l if err != nil { return offset, err @@ -205,22 +169,15 @@ func (p *MockReq) FastReadField3(buf []byte) (int, error) { _field := make([]string, 0, size) for i := 0; i < size; i++ { var _elem string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - _elem = v - } _field = append(_field, _elem) } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } p.StrList = _field return offset, nil } @@ -230,104 +187,92 @@ func (p *MockReq) FastWrite(buf []byte) int { return 0 } -func (p *MockReq) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockReq) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "MockReq") if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) - offset += p.fastWriteField2(buf[offset:], binaryWriter) - offset += p.fastWriteField3(buf[offset:], binaryWriter) + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset } func (p *MockReq) BLength() int { l := 0 - l += bthrift.Binary.StructBeginLength("MockReq") if p != nil { l += p.field1Length() l += p.field2Length() l += p.field3Length() } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l } -func (p *MockReq) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockReq) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Msg", thrift.STRING, 1) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Msg) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, p.Msg) return offset } -func (p *MockReq) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockReq) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "strMap", thrift.MAP, 2) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.MAP, 2) mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRING, 0) + offset += thrift.Binary.MapBeginLength() var length int for k, v := range p.StrMap { length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, k) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRING, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRING, length) return offset } -func (p *MockReq) fastWriteField3(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockReq) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "strList", thrift.LIST, 3) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 3) listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRING, 0) + offset += thrift.Binary.ListBeginLength() var length int for _, v := range p.StrList { length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) return offset } func (p *MockReq) field1Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("Msg", thrift.STRING, 1) - l += bthrift.Binary.StringLengthNocopy(p.Msg) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(p.Msg) return l } func (p *MockReq) field2Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("strMap", thrift.MAP, 2) - l += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRING, len(p.StrMap)) + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() for k, v := range p.StrMap { + _, _ = k, v - l += bthrift.Binary.StringLengthNocopy(k) - l += bthrift.Binary.StringLengthNocopy(v) + l += thrift.Binary.StringLengthNocopy(k) + l += thrift.Binary.StringLengthNocopy(v) } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() return l } func (p *MockReq) field3Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("strList", thrift.LIST, 3) - l += bthrift.Binary.ListBeginLength(thrift.STRING, len(p.StrList)) + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() for _, v := range p.StrList { - l += bthrift.Binary.StringLengthNocopy(v) + _ = v + l += thrift.Binary.StringLengthNocopy(v) } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() return l } @@ -337,14 +282,8 @@ func (p *Exception) FastRead(buf []byte) (int, error) { var l int var fieldTypeId thrift.TType var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - for { - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -361,7 +300,7 @@ func (p *Exception) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError @@ -375,58 +314,39 @@ func (p *Exception) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError } return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Exception[fieldId]), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Exception[fieldId]), err) SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) } func (p *Exception) FastReadField1(buf []byte) (int, error) { offset := 0 var _field int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - _field = v - } p.Code = _field return offset, nil @@ -436,13 +356,11 @@ func (p *Exception) FastReadField255(buf []byte) (int, error) { offset := 0 var _field string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - _field = v - } p.Msg = _field return offset, nil @@ -453,59 +371,51 @@ func (p *Exception) FastWrite(buf []byte) int { return 0 } -func (p *Exception) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Exception) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Exception") if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) - offset += p.fastWriteField255(buf[offset:], binaryWriter) + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField255(buf[offset:], w) } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset } func (p *Exception) BLength() int { l := 0 - l += bthrift.Binary.StructBeginLength("Exception") if p != nil { l += p.field1Length() l += p.field255Length() } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l } -func (p *Exception) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Exception) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "code", thrift.I32, 1) - offset += bthrift.Binary.WriteI32(buf[offset:], p.Code) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 1) + offset += thrift.Binary.WriteI32(buf[offset:], p.Code) return offset } -func (p *Exception) fastWriteField255(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Exception) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "msg", thrift.STRING, 255) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Msg) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 255) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, p.Msg) return offset } func (p *Exception) field1Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("code", thrift.I32, 1) - l += bthrift.Binary.I32Length(p.Code) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() return l } func (p *Exception) field255Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("msg", thrift.STRING, 255) - l += bthrift.Binary.StringLengthNocopy(p.Msg) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(p.Msg) return l } @@ -515,14 +425,8 @@ func (p *MockTestArgs) FastRead(buf []byte) (int, error) { var l int var fieldTypeId thrift.TType var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - for { - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -539,45 +443,28 @@ func (p *MockTestArgs) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError } return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestArgs[fieldId]), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestArgs[fieldId]), err) SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) } func (p *MockTestArgs) FastReadField1(buf []byte) (int, error) { @@ -597,41 +484,35 @@ func (p *MockTestArgs) FastWrite(buf []byte) int { return 0 } -func (p *MockTestArgs) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockTestArgs) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Test_args") if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) + offset += p.fastWriteField1(buf[offset:], w) } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset } func (p *MockTestArgs) BLength() int { l := 0 - l += bthrift.Binary.StructBeginLength("Test_args") if p != nil { l += p.field1Length() } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l } -func (p *MockTestArgs) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockTestArgs) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "req", thrift.STRUCT, 1) - offset += p.Req.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 1) + offset += p.Req.FastWriteNocopy(buf[offset:], w) return offset } func (p *MockTestArgs) field1Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("req", thrift.STRUCT, 1) + l += thrift.Binary.FieldBeginLength() l += p.Req.BLength() - l += bthrift.Binary.FieldEndLength() return l } @@ -641,14 +522,8 @@ func (p *MockTestResult) FastRead(buf []byte) (int, error) { var l int var fieldTypeId thrift.TType var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - for { - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -665,57 +540,39 @@ func (p *MockTestResult) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError } return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestResult[fieldId]), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockTestResult[fieldId]), err) SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) } func (p *MockTestResult) FastReadField0(buf []byte) (int, error) { offset := 0 var _field *string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l _field = &v - } p.Success = _field return offset, nil @@ -726,34 +583,29 @@ func (p *MockTestResult) FastWrite(buf []byte) int { return 0 } -func (p *MockTestResult) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockTestResult) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Test_result") if p != nil { - offset += p.fastWriteField0(buf[offset:], binaryWriter) + offset += p.fastWriteField0(buf[offset:], w) } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset } func (p *MockTestResult) BLength() int { l := 0 - l += bthrift.Binary.StructBeginLength("Test_result") if p != nil { l += p.field0Length() } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l } -func (p *MockTestResult) fastWriteField0(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockTestResult) fastWriteField0(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetSuccess() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "success", thrift.STRING, 0) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, *p.Success) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 0) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Success) } return offset } @@ -761,9 +613,8 @@ func (p *MockTestResult) fastWriteField0(buf []byte, binaryWriter bthrift.Binary func (p *MockTestResult) field0Length() int { l := 0 if p.IsSetSuccess() { - l += bthrift.Binary.FieldBeginLength("success", thrift.STRING, 0) - l += bthrift.Binary.StringLengthNocopy(*p.Success) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Success) } return l } @@ -774,14 +625,8 @@ func (p *MockExceptionTestArgs) FastRead(buf []byte) (int, error) { var l int var fieldTypeId thrift.TType var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - for { - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -798,45 +643,28 @@ func (p *MockExceptionTestArgs) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError } return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestArgs[fieldId]), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestArgs[fieldId]), err) SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) } func (p *MockExceptionTestArgs) FastReadField1(buf []byte) (int, error) { @@ -856,41 +684,35 @@ func (p *MockExceptionTestArgs) FastWrite(buf []byte) int { return 0 } -func (p *MockExceptionTestArgs) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockExceptionTestArgs) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "ExceptionTest_args") if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) + offset += p.fastWriteField1(buf[offset:], w) } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset } func (p *MockExceptionTestArgs) BLength() int { l := 0 - l += bthrift.Binary.StructBeginLength("ExceptionTest_args") if p != nil { l += p.field1Length() } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l } -func (p *MockExceptionTestArgs) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockExceptionTestArgs) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "req", thrift.STRUCT, 1) - offset += p.Req.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 1) + offset += p.Req.FastWriteNocopy(buf[offset:], w) return offset } func (p *MockExceptionTestArgs) field1Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("req", thrift.STRUCT, 1) + l += thrift.Binary.FieldBeginLength() l += p.Req.BLength() - l += bthrift.Binary.FieldEndLength() return l } @@ -900,14 +722,8 @@ func (p *MockExceptionTestResult) FastRead(buf []byte) (int, error) { var l int var fieldTypeId thrift.TType var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - for { - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -924,7 +740,7 @@ func (p *MockExceptionTestResult) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError @@ -938,57 +754,39 @@ func (p *MockExceptionTestResult) FastRead(buf []byte) (int, error) { goto ReadFieldError } } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError } return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestResult[fieldId]), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MockExceptionTestResult[fieldId]), err) SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) } func (p *MockExceptionTestResult) FastReadField0(buf []byte) (int, error) { offset := 0 var _field *string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l _field = &v - } p.Success = _field return offset, nil @@ -1011,46 +809,40 @@ func (p *MockExceptionTestResult) FastWrite(buf []byte) int { return 0 } -func (p *MockExceptionTestResult) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockExceptionTestResult) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "ExceptionTest_result") if p != nil { - offset += p.fastWriteField0(buf[offset:], binaryWriter) - offset += p.fastWriteField1(buf[offset:], binaryWriter) + offset += p.fastWriteField0(buf[offset:], w) + offset += p.fastWriteField1(buf[offset:], w) } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset } func (p *MockExceptionTestResult) BLength() int { l := 0 - l += bthrift.Binary.StructBeginLength("ExceptionTest_result") if p != nil { l += p.field0Length() l += p.field1Length() } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l } -func (p *MockExceptionTestResult) fastWriteField0(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockExceptionTestResult) fastWriteField0(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetSuccess() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "success", thrift.STRING, 0) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, *p.Success) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 0) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Success) } return offset } -func (p *MockExceptionTestResult) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *MockExceptionTestResult) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetErr() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "err", thrift.STRUCT, 1) - offset += p.Err.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 1) + offset += p.Err.FastWriteNocopy(buf[offset:], w) } return offset } @@ -1058,9 +850,8 @@ func (p *MockExceptionTestResult) fastWriteField1(buf []byte, binaryWriter bthri func (p *MockExceptionTestResult) field0Length() int { l := 0 if p.IsSetSuccess() { - l += bthrift.Binary.FieldBeginLength("success", thrift.STRING, 0) - l += bthrift.Binary.StringLengthNocopy(*p.Success) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Success) } return l } @@ -1068,9 +859,8 @@ func (p *MockExceptionTestResult) field0Length() int { func (p *MockExceptionTestResult) field1Length() int { l := 0 if p.IsSetErr() { - l += bthrift.Binary.FieldBeginLength("err", thrift.STRUCT, 1) + l += thrift.Binary.FieldBeginLength() l += p.Err.BLength() - l += bthrift.Binary.FieldEndLength() } return l } diff --git a/internal/mocks/thrift/utils.go b/internal/mocks/thrift/utils.go index 13a79862d3..5d080bc513 100644 --- a/internal/mocks/thrift/utils.go +++ b/internal/mocks/thrift/utils.go @@ -20,25 +20,26 @@ import ( "errors" "io" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift" + + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // ApacheCodecAdapter converts a fastcodec struct to apache codec type ApacheCodecAdapter struct { - p bthrift.ThriftFastCodec + p thrift.FastCodec } -// Write implements thrift.TStruct -func (p ApacheCodecAdapter) Write(tp thrift.TProtocol) error { +// Write implements athrift.TStruct +func (p ApacheCodecAdapter) Write(tp athrift.TProtocol) error { b := make([]byte, p.p.BLength()) b = b[:p.p.FastWriteNocopy(b, nil)] _, err := tp.Transport().Write(b) return err } -// Read implements thrift.TStruct -func (p ApacheCodecAdapter) Read(tp thrift.TProtocol) error { +// Read implements athrift.TStruct +func (p ApacheCodecAdapter) Read(tp athrift.TProtocol) error { var err error var b []byte trans := tp.Transport() @@ -54,12 +55,16 @@ func (p ApacheCodecAdapter) Read(tp thrift.TProtocol) error { return err } -// ToApacheCodec converts a bthrift.ThriftFastCodec to thrift.TStruct -func ToApacheCodec(p bthrift.ThriftFastCodec) thrift.TStruct { +// ToApacheCodec converts a thrift.FastCodec to athrift.TStruct +func ToApacheCodec(p thrift.FastCodec) athrift.TStruct { return ApacheCodecAdapter{p: p} } -// UnpackApacheCodec unpacks ToApacheCodec +// UnpackApacheCodec unpacks the value returned by `ToApacheCodec` func UnpackApacheCodec(v interface{}) interface{} { - return v.(ApacheCodecAdapter).p + a, ok := v.(ApacheCodecAdapter) + if ok { + return a.p + } + return v } diff --git a/pkg/generic/binary_test/generic_init.go b/pkg/generic/binary_test/generic_init.go index 6b5ae69fe2..05aece2460 100644 --- a/pkg/generic/binary_test/generic_init.go +++ b/pkg/generic/binary_test/generic_init.go @@ -19,22 +19,21 @@ package test import ( "context" - "encoding/binary" "errors" "fmt" "net" "time" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/genericclient" kt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/kerrors" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/transmeta" - "github.com/cloudwego/kitex/pkg/utils/fastthrift" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/genericserver" "github.com/cloudwego/kitex/transport" @@ -102,7 +101,7 @@ func (g *GenericServiceMockImpl) GenericCall(ctx context.Context, method string, buf := request.([]byte) var args2 kt.MockTestArgs - mth, seqID, err := fastthrift.UnmarshalMsg(buf, &args2) + mth, seqID, err := thrift.UnmarshalFastMsg(buf, &args2) if err != nil { return nil, err } @@ -115,7 +114,7 @@ func (g *GenericServiceMockImpl) GenericCall(ctx context.Context, method string, result := kt.NewMockTestResult() result.Success = &resp - buf, err = fastthrift.MarshalMsg(mth, fastthrift.REPLY, seqID, result) + buf, err = thrift.MarshalFastMsg(mth, thrift.REPLY, seqID, result) return buf, err } @@ -190,16 +189,11 @@ func (m *MockImpl) ExceptionTest(ctx context.Context, req *kt.MockReq) (r string } func genBinaryResp(method string) []byte { - idx := 0 - buf := make([]byte, 12+len(method)+len(respMsg)) - binary.BigEndian.PutUint32(buf, thrift.VERSION_1) - idx += 4 - binary.BigEndian.PutUint32(buf[idx:idx+4], uint32(len(method))) - idx += 4 - copy(buf[idx:idx+len(method)], method) - idx += len(method) - binary.BigEndian.PutUint32(buf[idx:idx+4], 100) - idx += 4 - copy(buf[idx:idx+len(respMsg)], respMsg) - return buf + // no idea for respMsg part, it's not binary protocol. + // DO NOT TOUCH IT or you may need to change the tests as well + n := thrift.Binary.MessageBeginLength(method, 0, 0) + len(respMsg) + b := make([]byte, 0, n) + b = thrift.Binary.AppendMessageBegin(b, method, 0, 100) + b = append(b, respMsg...) + return b } diff --git a/pkg/generic/binary_test/generic_test.go b/pkg/generic/binary_test/generic_test.go index 57d0c34b56..edfffd6fd5 100644 --- a/pkg/generic/binary_test/generic_test.go +++ b/pkg/generic/binary_test/generic_test.go @@ -18,7 +18,6 @@ package test import ( "context" - "encoding/binary" "net" "runtime" "runtime/debug" @@ -26,6 +25,8 @@ import ( "testing" "time" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/callopt" "github.com/cloudwego/kitex/client/genericclient" @@ -33,8 +34,6 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/kerrors" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/utils/fastthrift" "github.com/cloudwego/kitex/server" ) @@ -111,7 +110,7 @@ func rawThriftBinaryMockReq(t *testing.T) { args.Req = req // encode - buf, err := fastthrift.MarshalMsg("Test", fastthrift.CALL, 100, args) + buf, err := thrift.MarshalFastMsg("Test", thrift.CALL, 100, args) test.Assert(t, err == nil, err) resp, err := cli.GenericCall(context.Background(), "Test", buf) @@ -120,7 +119,7 @@ func rawThriftBinaryMockReq(t *testing.T) { // decode buf = resp.([]byte) var result kt.MockTestResult - method, seqID, err := fastthrift.UnmarshalMsg(buf, &result) + method, seqID, err := thrift.UnmarshalFastMsg(buf, &result) test.Assert(t, err == nil, err) test.Assert(t, method == "Test", method) test.Assert(t, seqID != 100, seqID) @@ -147,7 +146,7 @@ func rawThriftBinary2NormalServer(t *testing.T) { args.Req = req // encode - buf, err := fastthrift.MarshalMsg("Test", fastthrift.CALL, 100, args) + buf, err := thrift.MarshalFastMsg("Test", thrift.CALL, 100, args) test.Assert(t, err == nil, err) resp, err := cli.GenericCall(context.Background(), "Test", buf, callopt.WithRPCTimeout(100*time.Second)) @@ -156,7 +155,7 @@ func rawThriftBinary2NormalServer(t *testing.T) { // decode buf = resp.([]byte) var result kt.MockTestResult - method, seqID, err := fastthrift.UnmarshalMsg(buf, &result) + method, seqID, err := thrift.UnmarshalFastMsg(buf, &result) test.Assert(t, err == nil, err) test.Assert(t, method == "Test", method) // seqID会在kitex中覆盖,避免TTHeader和Payload codec 不一致问题 @@ -184,18 +183,13 @@ func initMockServer(handler kt.Mock) server.Server { } func genBinaryReqBuf(method string) []byte { - idx := 0 - buf := make([]byte, 12+len(method)+len(reqMsg)) - binary.BigEndian.PutUint32(buf, thrift.VERSION_1) - idx += 4 - binary.BigEndian.PutUint32(buf[idx:idx+4], uint32(len(method))) - idx += 4 - copy(buf[idx:idx+len(method)], method) - idx += len(method) - binary.BigEndian.PutUint32(buf[idx:idx+4], 100) - idx += 4 - copy(buf[idx:idx+len(reqMsg)], reqMsg) - return buf + // no idea for reqMsg part, it's not binary protocol. + // DO NOT TOUCH IT or you may need to change the tests as well + n := thrift.Binary.MessageBeginLength(method, 0, 0) + len(reqMsg) + b := make([]byte, 0, n) + b = thrift.Binary.AppendMessageBegin(b, method, 0, 100) + b = append(b, reqMsg...) + return b } func TestBinaryThriftGenericClientClose(t *testing.T) { diff --git a/pkg/generic/binarythrift_codec_test.go b/pkg/generic/binarythrift_codec_test.go index 0a211d3994..1c79a1df20 100644 --- a/pkg/generic/binarythrift_codec_test.go +++ b/pkg/generic/binarythrift_codec_test.go @@ -20,12 +20,13 @@ import ( "context" "testing" + "github.com/cloudwego/gopkg/protocol/thrift" + kt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/utils/fastthrift" ) func TestBinaryThriftCodec(t *testing.T) { @@ -33,7 +34,7 @@ func TestBinaryThriftCodec(t *testing.T) { args := kt.NewMockTestArgs() args.Req = req // encode - buf, err := fastthrift.MarshalMsg("mock", fastthrift.CALL, 100, args) + buf, err := thrift.MarshalFastMsg("mock", thrift.CALL, 100, args) test.Assert(t, err == nil, err) btc := &binaryThriftCodec{thriftCodec} @@ -90,7 +91,7 @@ func TestBinaryThriftCodec(t *testing.T) { test.Assert(t, seqID == 1, seqID) var req2 kt.MockTestArgs - method, seqID2, err2 := fastthrift.UnmarshalMsg(reqBuf, &req2) + method, seqID2, err2 := thrift.UnmarshalFastMsg(reqBuf, &req2) test.Assert(t, err2 == nil, err) test.Assert(t, seqID2 == 1, seqID) test.Assert(t, method == "mock", method) diff --git a/pkg/generic/thrift/base.go b/pkg/generic/thrift/base.go index 0139ca24c3..5eb02672a3 100644 --- a/pkg/generic/thrift/base.go +++ b/pkg/generic/thrift/base.go @@ -19,7 +19,8 @@ package thrift import ( "fmt" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" + bthrift "github.com/cloudwego/gopkg/protocol/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) diff --git a/pkg/generic/thrift/http.go b/pkg/generic/thrift/http.go index 1311fdb7ec..f72df178d3 100644 --- a/pkg/generic/thrift/http.go +++ b/pkg/generic/thrift/http.go @@ -24,11 +24,11 @@ import ( "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" dthrift "github.com/cloudwego/dynamicgo/thrift" + "github.com/cloudwego/gopkg/protocol/thrift" jsoniter "github.com/json-iterator/go" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" ) @@ -79,7 +79,7 @@ func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.O } // originalWrite ... -func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error { +func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out athrift.TProtocol, msg interface{}, requestBase *Base) error { req := msg.(*descriptor.HTTPRequest) if req.Body == nil && len(req.RawBody) != 0 { if err := customJson.Unmarshal(req.RawBody, &req.Body); err != nil { @@ -131,7 +131,7 @@ func (r *ReadHTTPResponse) SetDynamicGo(convOpts *conv.Options) { } // Read ... -func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { +func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in athrift.TProtocol) (interface{}, error) { // fallback logic if !r.dynamicgoEnabled || dataLen == 0 { return r.originalRead(ctx, method, in) @@ -140,18 +140,13 @@ func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient boo if !ok { return nil, perrors.NewProtocolErrorWithMsg("TProtocol should be BinaryProtocol") } - sName, err := in.ReadStructBegin() - if err != nil { - return nil, err - } - sBeginLen := bthrift.Binary.StructBeginLength(sName) // TODO: support exception field - fName, typeId, id, err := in.ReadFieldBegin() + _, _, id, err := in.ReadFieldBegin() if err != nil { return nil, err } - fBeginLen := bthrift.Binary.FieldBeginLength(fName, typeId, id) - transBuf, err := tProt.ByteBuffer().ReadBinary(dataLen - sBeginLen - fBeginLen) + fBeginLen := thrift.Binary.FieldBeginLength() + transBuf, err := tProt.ByteBuffer().ReadBinary(dataLen - fBeginLen) if err != nil { return nil, err } @@ -181,7 +176,7 @@ func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient boo return resp, nil } -func (r *ReadHTTPResponse) originalRead(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) { +func (r *ReadHTTPResponse) originalRead(ctx context.Context, method string, in athrift.TProtocol) (interface{}, error) { fnDsc, err := r.svc.LookupFunctionByMethod(method) if err != nil { return nil, err diff --git a/pkg/protocol/bthrift/exception.go b/pkg/protocol/bthrift/exception.go deleted file mode 100644 index 598a47ec9d..0000000000 --- a/pkg/protocol/bthrift/exception.go +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package bthrift - -import ( - "errors" - "fmt" - - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -// ApplicationException is for replacing apache.TApplicationException -// it implements ThriftFastCodec interface. -type ApplicationException struct { - t int32 - m string -} - -// check interface only. TO BE REMOVED in the future -var _ thrift.TApplicationException = &ApplicationException{} - -// NewApplicationException creates an ApplicationException instance -func NewApplicationException(t int32, msg string) *ApplicationException { - return &ApplicationException{t: t, m: msg} -} - -// Msg ... -func (e *ApplicationException) Msg() string { return e.m } - -// TypeID ... -func (e *ApplicationException) TypeID() int32 { return e.t } - -// TypeId ... for apache ApplicationException compatibility -func (e *ApplicationException) TypeId() int32 { return e.t } - -// BLength returns the len of encoded buffer. -func (e *ApplicationException) BLength() int { - // Msg Field: 1 (type) + 2 (id) + 4(strlen) + len(m) - // Type Field: 1 (type) + 2 (id) + 4(ex type) - // STOP: 1 byte - return (1 + 2 + 4 + len(e.m)) + (1 + 2 + 4) + 1 -} - -// FastRead ... -func (e *ApplicationException) FastRead(b []byte) (off int, err error) { - for i := 0; i < 2; i++ { - _, tp, id, l, err := Binary.ReadFieldBegin(b[off:]) - if err != nil { - return 0, err - } - off += l - switch { - case id == 1 && tp == thrift.STRING: // Msg - e.m, l, err = Binary.ReadString(b[off:]) - case id == 2 && tp == thrift.I32: // TypeID - e.t, l, err = Binary.ReadI32(b[off:]) - default: - l, err = Binary.Skip(b, tp) - } - if err != nil { - return 0, err - } - off += l - } - v, l, err := Binary.ReadByte(b[off:]) - if err != nil { - return 0, err - } - if v != thrift.STOP { - return 0, fmt.Errorf("expects thrift.STOP, found: %d", v) - } - off += l - return off, nil -} - -// FastWrite ... -func (e *ApplicationException) FastWrite(b []byte) (off int) { - off += Binary.WriteFieldBegin(b[off:], "", thrift.STRING, 1) - off += Binary.WriteString(b[off:], e.m) - off += Binary.WriteFieldBegin(b[off:], "", thrift.I32, 2) - off += Binary.WriteI32(b[off:], e.t) - off += Binary.WriteByte(b[off:], thrift.STOP) - return off -} - -// FastWriteNocopy ... XXX: we deprecated XXXNocopy, simply using FastWrite is OK. -func (e *ApplicationException) FastWriteNocopy(b []byte, binaryWriter BinaryWriter) int { - return e.FastWrite(b) -} - -// Read implements Read interface of TStruct -// it only supports binary protocol. -// Deprecated: use FastRead instead -func (e *ApplicationException) Read(in thrift.TProtocol) error { - for { - _, ttype, id, err := in.ReadFieldBegin() - if err != nil { - return err - } - if ttype == thrift.STOP { - break - } - switch { - case id == 1 && ttype == thrift.STRING: - e.m, err = in.ReadString() - if err != nil { - return err - } - case id == 2 && ttype == thrift.I32: - e.t, err = in.ReadI32() - if err != nil { - return err - } - default: - if err = thrift.SkipDefaultDepth(in, ttype); err != nil { - return err - } - } - } - return nil -} - -// Write implements Write interface of TStruct -// it only supports binary protocol. -// Deprecated: use FastWrite instead -func (e *ApplicationException) Write(out thrift.TProtocol) error { - if err := out.WriteFieldBegin("message", thrift.STRING, 1); err != nil { - return err - } - if err := out.WriteString(e.m); err != nil { - return err - } - if err := out.WriteFieldBegin("type", thrift.I32, 2); err != nil { - return err - } - if err := out.WriteI32(e.t); err != nil { - return err - } - return out.WriteFieldStop() -} - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/exception.go -var defaultApplicationExceptionMessage = map[int32]string{ - thrift.UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception", - thrift.UNKNOWN_METHOD: "unknown method", - thrift.INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type", - thrift.WRONG_METHOD_NAME: "wrong method name", - thrift.BAD_SEQUENCE_ID: "bad sequence ID", - thrift.MISSING_RESULT: "missing result", - thrift.INTERNAL_ERROR: "unknown internal error", - thrift.PROTOCOL_ERROR: "unknown protocol error", - thrift.INVALID_TRANSFORM: "Invalid transform", - thrift.INVALID_PROTOCOL: "Invalid protocol", - thrift.UNSUPPORTED_CLIENT_TYPE: "Unsupported client type", -} - -// Error implements apache.Exception -func (e *ApplicationException) Error() string { - if e.m != "" { - return e.m - } - if m, ok := defaultApplicationExceptionMessage[e.t]; ok { - return m - } - return fmt.Sprintf("unknown exception type [%d]", e.t) -} - -// TransportException is for replacing apache.TransportException -// it implements ThriftFastCodec interface. -type TransportException struct { - ApplicationException // same implementation ... -} - -// NewTransportException ... -func NewTransportException(t int32, m string) *TransportException { - ret := TransportException{} - ret.t = t - ret.m = m - return &ret -} - -// ProtocolException is for replacing apache.ProtocolException -// it implements ThriftFastCodec interface. -type ProtocolException struct { - ApplicationException // same implementation ... -} - -// NewTransportException ... -func NewProtocolException(t int32, m string) *ProtocolException { - ret := ProtocolException{} - ret.t = t - ret.m = m - return &ret -} - -// Generic Thrift exception with TypeId method -type tException interface { - Error() string - TypeId() int32 -} - -// Prepends additional information to an error without losing the Thrift exception interface -func PrependError(prepend string, err error) error { - if t, ok := err.(*TransportException); ok { - return NewTransportException(t.TypeID(), prepend+t.Error()) - } - if t, ok := err.(*ProtocolException); ok { - return NewProtocolException(t.TypeID(), prepend+err.Error()) - } - if t, ok := err.(*ApplicationException); ok { - return NewApplicationException(t.TypeID(), prepend+t.Error()) - } - if t, ok := err.(tException); ok { // apache thrift exception? - return NewApplicationException(t.TypeId(), prepend+t.Error()) - } - return errors.New(prepend + err.Error()) -} diff --git a/pkg/protocol/bthrift/exception_test.go b/pkg/protocol/bthrift/exception_test.go deleted file mode 100644 index 574653d2cf..0000000000 --- a/pkg/protocol/bthrift/exception_test.go +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package bthrift - -import ( - "bytes" - "errors" - "testing" - - "github.com/cloudwego/kitex/internal/test" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -func TestApplicationException(t *testing.T) { - ex1 := NewApplicationException(1, "t1") - b := make([]byte, ex1.BLength()) - n := ex1.FastWrite(b) - test.Assert(t, n == len(b)) - - ex2 := NewApplicationException(0, "") - n, err := ex2.FastRead(b) - test.Assert(t, err == nil, err) - test.Assert(t, n == len(b), n) - test.Assert(t, ex2.TypeID() == 1) - test.Assert(t, ex2.Msg() == "t1") - - // ================= - // the code below, it's for compatibility test only. - // it can be removed in the future along with Read/Write method - - trans := thrift.NewTMemoryBufferLen(100) - proto := thrift.NewTBinaryProtocol(trans, true, true) - ex9 := thrift.NewTApplicationException(1, "t1") - err = ex9.Write(proto) - test.Assert(t, err == nil, err) - test.Assert(t, bytes.Equal(b, trans.Bytes())) - - trans = thrift.NewTMemoryBufferLen(100) - proto = thrift.NewTBinaryProtocol(trans, true, true) - ex3 := NewApplicationException(1, "t1") - err = ex3.Write(proto) - test.Assert(t, err == nil, err) - test.Assert(t, bytes.Equal(b, trans.Bytes())) - - ex4 := NewApplicationException(0, "") - err = ex4.Read(proto) - test.Assert(t, err == nil, err) - test.Assert(t, ex4.TypeID() == 1) - test.Assert(t, ex4.Msg() == "t1") -} - -func TestPrependError(t *testing.T) { - var ok bool - ex0 := NewTransportException(1, "world") - err0 := PrependError("hello ", ex0) - ex0, ok = err0.(*TransportException) - test.Assert(t, ok) - test.Assert(t, ex0.TypeID() == 1) - test.Assert(t, ex0.Error() == "hello world") - - ex1 := NewProtocolException(2, "world") - err1 := PrependError("hello ", ex1) - ex1, ok = err1.(*ProtocolException) - test.Assert(t, ok) - test.Assert(t, ex1.TypeID() == 2) - test.Assert(t, ex1.Error() == "hello world") - - ex2 := NewApplicationException(3, "world") - err2 := PrependError("hello ", ex2) - ex2, ok = err2.(*ApplicationException) - test.Assert(t, ok) - test.Assert(t, ex2.TypeID() == 3) - test.Assert(t, ex2.Error() == "hello world") - - err3 := PrependError("hello ", errors.New("world")) - _, ok = err3.(tException) - test.Assert(t, !ok) - test.Assert(t, err3.Error() == "hello world") - - // the code below, it's for compatibility test only. - // it can be removed in the future along with Read/Write method - ex9 := thrift.NewTApplicationException(9, "world") - err9 := PrependError("hello ", ex9) - ex, ok := err9.(tException) - test.Assert(t, ok) - test.Assert(t, ex.TypeId() == 9) - test.Assert(t, ex.Error() == "hello world") -} diff --git a/pkg/protocol/bthrift/interface.go b/pkg/protocol/bthrift/interface.go index 06965813bc..75fa0ce951 100644 --- a/pkg/protocol/bthrift/interface.go +++ b/pkg/protocol/bthrift/interface.go @@ -18,20 +18,14 @@ package bthrift import ( + gopkgthrift "github.com/cloudwego/gopkg/protocol/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // BinaryWriter . -type BinaryWriter interface { - WriteDirect(b []byte, remainCap int) error -} - -// ThriftFastCodec represents the interface of thrift fastcodec generated structs -type ThriftFastCodec interface { - BLength() int - FastWriteNocopy(buf []byte, binaryWriter BinaryWriter) int - FastRead(buf []byte) (int, error) -} +// Deprecated: use `github.com/cloudwego/gopkg/protocol/thrift.NocopyWriter` +type BinaryWriter = gopkgthrift.NocopyWriter // BTProtocol . type BTProtocol interface { diff --git a/pkg/protocol/bthrift/test/gen.sh b/pkg/protocol/bthrift/test/gen.sh deleted file mode 100755 index e064be9d98..0000000000 --- a/pkg/protocol/bthrift/test/gen.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -kitex -thrift no_default_serdes -module github.com/cloudwego/kitex -thrift keep_unknown_fields test.thrift - diff --git a/pkg/protocol/bthrift/test/kitex_gen/test/k-consts.go b/pkg/protocol/bthrift/test/kitex_gen/test/k-consts.go deleted file mode 100644 index 2f0ccc68e6..0000000000 --- a/pkg/protocol/bthrift/test/kitex_gen/test/k-consts.go +++ /dev/null @@ -1,4 +0,0 @@ -package test - -// KitexUnusedProtection is used to prevent 'imported and not used' error. -var KitexUnusedProtection = struct{}{} diff --git a/pkg/protocol/bthrift/test/kitex_gen/test/k-test.go b/pkg/protocol/bthrift/test/kitex_gen/test/k-test.go deleted file mode 100644 index fd5929e322..0000000000 --- a/pkg/protocol/bthrift/test/kitex_gen/test/k-test.go +++ /dev/null @@ -1,3955 +0,0 @@ -// Code generated by Kitex v0.11.0. DO NOT EDIT. - -package test - -import ( - "bytes" - "fmt" - "reflect" - "strings" - - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -// unused protection -var ( - _ = fmt.Formatter(nil) - _ = (*bytes.Buffer)(nil) - _ = (*strings.Builder)(nil) - _ = reflect.Type(nil) - _ = thrift.TProtocol(nil) - _ = bthrift.BinaryWriter(nil) -) - -func (p *Inner) FastRead(buf []byte) (int, error) { - var err error - var offset int - var l int - var fieldTypeId thrift.TType - var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - - for { - var isUnknownField bool - var beginOff int = offset - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - l, err = p.FastReadField1(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 2: - if fieldTypeId == thrift.STRING { - l, err = p.FastReadField2(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField3(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 4: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField4(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 5: - if fieldTypeId == thrift.BYTE { - l, err = p.FastReadField5(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 6: - if fieldTypeId == thrift.DOUBLE { - l, err = p.FastReadField6(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - isUnknownField = true - } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - if isUnknownField { - p._unknownFields = append(p._unknownFields, buf[beginOff:offset]...) - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError - } - - return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Inner[fieldId]), err) -SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *Inner) FastReadField1(buf []byte) (int, error) { - offset := 0 - - var _field int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Num = _field - return offset, nil -} - -func (p *Inner) FastReadField2(buf []byte) (int, error) { - offset := 0 - - var _field *string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - _field = &v - - } - p.Desc = _field - return offset, nil -} - -func (p *Inner) FastReadField3(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[int64][]int64, size) - for i := 0; i < size; i++ { - var _key int64 - if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = v - - } - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _val := make([]int64, 0, size) - for i := 0; i < size; i++ { - var _elem int64 - if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = v - - } - - _val = append(_val, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.MapOfList = _field - return offset, nil -} - -func (p *Inner) FastReadField4(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[AEnum]int64, size) - for i := 0; i < size; i++ { - var _key AEnum - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = AEnum(v) - - } - - var _val int64 - if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _val = v - - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.MapOfEnumKey = _field - return offset, nil -} - -func (p *Inner) FastReadField5(buf []byte) (int, error) { - offset := 0 - - var _field *int8 - if v, l, err := bthrift.Binary.ReadByte(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - _field = &v - - } - p.Byte1 = _field - return offset, nil -} - -func (p *Inner) FastReadField6(buf []byte) (int, error) { - offset := 0 - - var _field *float64 - if v, l, err := bthrift.Binary.ReadDouble(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - _field = &v - - } - p.Double1 = _field - return offset, nil -} - -// for compatibility -func (p *Inner) FastWrite(buf []byte) int { - return 0 -} - -func (p *Inner) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Inner") - if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) - offset += p.fastWriteField5(buf[offset:], binaryWriter) - offset += p.fastWriteField6(buf[offset:], binaryWriter) - offset += p.fastWriteField2(buf[offset:], binaryWriter) - offset += p.fastWriteField3(buf[offset:], binaryWriter) - offset += p.fastWriteField4(buf[offset:], binaryWriter) - offset += copy(buf[offset:], p._unknownFields) - } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) - return offset -} - -func (p *Inner) BLength() int { - l := 0 - l += bthrift.Binary.StructBeginLength("Inner") - if p != nil { - l += p.field1Length() - l += p.field2Length() - l += p.field3Length() - l += p.field4Length() - l += p.field5Length() - l += p.field6Length() - l += len(p._unknownFields) - } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() - return l -} - -func (p *Inner) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetNum() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Num", thrift.I32, 1) - offset += bthrift.Binary.WriteI32(buf[offset:], p.Num) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *Inner) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetDesc() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "desc", thrift.STRING, 2) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, *p.Desc) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *Inner) fastWriteField3(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetMapOfList() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "MapOfList", thrift.MAP, 3) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.I64, thrift.LIST, 0) - var length int - for k, v := range p.MapOfList { - length++ - offset += bthrift.Binary.WriteI64(buf[offset:], k) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.I64, 0) - var length int - for _, v := range v { - length++ - offset += bthrift.Binary.WriteI64(buf[offset:], v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I64, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I64, thrift.LIST, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *Inner) fastWriteField4(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetMapOfEnumKey() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "MapOfEnumKey", thrift.MAP, 4) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.I32, thrift.I64, 0) - var length int - for k, v := range p.MapOfEnumKey { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(k)) - offset += bthrift.Binary.WriteI64(buf[offset:], v) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.I64, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *Inner) fastWriteField5(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetByte1() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Byte1", thrift.BYTE, 5) - offset += bthrift.Binary.WriteByte(buf[offset:], *p.Byte1) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *Inner) fastWriteField6(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetDouble1() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Double1", thrift.DOUBLE, 6) - offset += bthrift.Binary.WriteDouble(buf[offset:], *p.Double1) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *Inner) field1Length() int { - l := 0 - if p.IsSetNum() { - l += bthrift.Binary.FieldBeginLength("Num", thrift.I32, 1) - l += bthrift.Binary.I32Length(p.Num) - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *Inner) field2Length() int { - l := 0 - if p.IsSetDesc() { - l += bthrift.Binary.FieldBeginLength("desc", thrift.STRING, 2) - l += bthrift.Binary.StringLengthNocopy(*p.Desc) - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *Inner) field3Length() int { - l := 0 - if p.IsSetMapOfList() { - l += bthrift.Binary.FieldBeginLength("MapOfList", thrift.MAP, 3) - l += bthrift.Binary.MapBeginLength(thrift.I64, thrift.LIST, len(p.MapOfList)) - for k, v := range p.MapOfList { - - l += bthrift.Binary.I64Length(k) - l += bthrift.Binary.ListBeginLength(thrift.I64, len(v)) - var tmpV int64 - l += bthrift.Binary.I64Length(int64(tmpV)) * len(v) - l += bthrift.Binary.ListEndLength() - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *Inner) field4Length() int { - l := 0 - if p.IsSetMapOfEnumKey() { - l += bthrift.Binary.FieldBeginLength("MapOfEnumKey", thrift.MAP, 4) - l += bthrift.Binary.MapBeginLength(thrift.I32, thrift.I64, len(p.MapOfEnumKey)) - for k, v := range p.MapOfEnumKey { - - l += bthrift.Binary.I32Length(int32(k)) - l += bthrift.Binary.I64Length(v) - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *Inner) field5Length() int { - l := 0 - if p.IsSetByte1() { - l += bthrift.Binary.FieldBeginLength("Byte1", thrift.BYTE, 5) - l += bthrift.Binary.ByteLength(*p.Byte1) - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *Inner) field6Length() int { - l := 0 - if p.IsSetDouble1() { - l += bthrift.Binary.FieldBeginLength("Double1", thrift.DOUBLE, 6) - l += bthrift.Binary.DoubleLength(*p.Double1) - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *Local) FastRead(buf []byte) (int, error) { - var err error - var offset int - var l int - var fieldTypeId thrift.TType - var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - - for { - var isUnknownField bool - var beginOff int = offset - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - l, err = p.FastReadField1(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - isUnknownField = true - } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - if isUnknownField { - p._unknownFields = append(p._unknownFields, buf[beginOff:offset]...) - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError - } - - return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Local[fieldId]), err) -SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *Local) FastReadField1(buf []byte) (int, error) { - offset := 0 - - var _field int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.L = _field - return offset, nil -} - -// for compatibility -func (p *Local) FastWrite(buf []byte) int { - return 0 -} - -func (p *Local) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Local") - if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) - offset += copy(buf[offset:], p._unknownFields) - } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) - return offset -} - -func (p *Local) BLength() int { - l := 0 - l += bthrift.Binary.StructBeginLength("Local") - if p != nil { - l += p.field1Length() - l += len(p._unknownFields) - } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() - return l -} - -func (p *Local) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "l", thrift.I32, 1) - offset += bthrift.Binary.WriteI32(buf[offset:], p.L) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *Local) field1Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("l", thrift.I32, 1) - l += bthrift.Binary.I32Length(p.L) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) FastRead(buf []byte) (int, error) { - var err error - var offset int - var l int - var fieldTypeId thrift.TType - var fieldId int16 - var issetLeft bool = false - var issetRequiredIns bool = false - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - - for { - var isUnknownField bool - var beginOff int = offset - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - l, err = p.FastReadField1(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - issetLeft = true - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 2: - if fieldTypeId == thrift.I32 { - l, err = p.FastReadField2(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.STRING { - l, err = p.FastReadField3(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 4: - if fieldTypeId == thrift.STRUCT { - l, err = p.FastReadField4(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 5: - if fieldTypeId == thrift.I32 { - l, err = p.FastReadField5(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 6: - if fieldTypeId == thrift.STRING { - l, err = p.FastReadField6(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 7: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField7(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 8: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField8(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 9: - if fieldTypeId == thrift.I64 { - l, err = p.FastReadField9(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 10: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField10(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 11: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField11(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 12: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField12(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 13: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField13(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 14: - if fieldTypeId == thrift.STRUCT { - l, err = p.FastReadField14(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - issetRequiredIns = true - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 16: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField16(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 17: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField17(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 18: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField18(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 19: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField19(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 20: - if fieldTypeId == thrift.STRUCT { - l, err = p.FastReadField20(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 21: - if fieldTypeId == thrift.STRUCT { - l, err = p.FastReadField21(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 22: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField22(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 23: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField23(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 24: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField24(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 25: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField25(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 26: - if fieldTypeId == thrift.I32 { - l, err = p.FastReadField26(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 27: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField27(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 28: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField28(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 29: - if fieldTypeId == thrift.SET { - l, err = p.FastReadField29(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 30: - if fieldTypeId == thrift.I16 { - l, err = p.FastReadField30(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 31: - if fieldTypeId == thrift.BOOL { - l, err = p.FastReadField31(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - isUnknownField = true - } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - if isUnknownField { - p._unknownFields = append(p._unknownFields, buf[beginOff:offset]...) - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError - } - - if !issetLeft { - fieldId = 1 - goto RequiredFieldNotSetError - } - - if !issetRequiredIns { - fieldId = 14 - goto RequiredFieldNotSetError - } - return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_FullStruct[fieldId]), err) -SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -RequiredFieldNotSetError: - return offset, thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_FullStruct[fieldId])) -} - -func (p *FullStruct) FastReadField1(buf []byte) (int, error) { - offset := 0 - - var _field int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Left = _field - return offset, nil -} - -func (p *FullStruct) FastReadField2(buf []byte) (int, error) { - offset := 0 - - var _field int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Right = _field - return offset, nil -} - -func (p *FullStruct) FastReadField3(buf []byte) (int, error) { - offset := 0 - - var _field []byte - if v, l, err := bthrift.Binary.ReadBinary(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = []byte(v) - - } - p.Dummy = _field - return offset, nil -} - -func (p *FullStruct) FastReadField4(buf []byte) (int, error) { - offset := 0 - _field := NewInner() - if l, err := _field.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.InnerReq = _field - return offset, nil -} - -func (p *FullStruct) FastReadField5(buf []byte) (int, error) { - offset := 0 - - var _field HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = HTTPStatus(v) - - } - p.Status = _field - return offset, nil -} - -func (p *FullStruct) FastReadField6(buf []byte) (int, error) { - offset := 0 - - var _field string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Str = _field - return offset, nil -} - -func (p *FullStruct) FastReadField7(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]HTTPStatus, 0, size) - for i := 0; i < size; i++ { - var _elem HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = HTTPStatus(v) - - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.EnumList = _field - return offset, nil -} - -func (p *FullStruct) FastReadField8(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[int32]string, size) - for i := 0; i < size; i++ { - var _key int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = v - - } - - var _val string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _val = v - - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.Strmap = _field - return offset, nil -} - -func (p *FullStruct) FastReadField9(buf []byte) (int, error) { - offset := 0 - - var _field int64 - if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Int64 = _field - return offset, nil -} - -func (p *FullStruct) FastReadField10(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]int32, 0, size) - for i := 0; i < size; i++ { - var _elem int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = v - - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.IntList = _field - return offset, nil -} - -func (p *FullStruct) FastReadField11(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]*Local, 0, size) - values := make([]Local, size) - for i := 0; i < size; i++ { - _elem := &values[i] - _elem.InitDefault() - if l, err := _elem.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.LocalList = _field - return offset, nil -} - -func (p *FullStruct) FastReadField12(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[string]*Local, size) - values := make([]Local, size) - for i := 0; i < size; i++ { - var _key string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = v - - } - - _val := &values[i] - _val.InitDefault() - if l, err := _val.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.StrLocalMap = _field - return offset, nil -} - -func (p *FullStruct) FastReadField13(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([][]int32, 0, size) - for i := 0; i < size; i++ { - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _elem := make([]int32, 0, size) - for i := 0; i < size; i++ { - var _elem1 int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem1 = v - - } - - _elem = append(_elem, _elem1) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.NestList = _field - return offset, nil -} - -func (p *FullStruct) FastReadField14(buf []byte) (int, error) { - offset := 0 - _field := NewLocal() - if l, err := _field.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.RequiredIns = _field - return offset, nil -} - -func (p *FullStruct) FastReadField16(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[string][]string, size) - for i := 0; i < size; i++ { - var _key string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = v - - } - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _val := make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = v - - } - - _val = append(_val, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.NestMap = _field - return offset, nil -} - -func (p *FullStruct) FastReadField17(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]map[string]HTTPStatus, 0, size) - for i := 0; i < size; i++ { - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _elem := make(map[string]HTTPStatus, size) - for i := 0; i < size; i++ { - var _key string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = v - - } - - var _val HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _val = HTTPStatus(v) - - } - - _elem[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.NestMap2 = _field - return offset, nil -} - -func (p *FullStruct) FastReadField18(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[int32]HTTPStatus, size) - for i := 0; i < size; i++ { - var _key int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = v - - } - - var _val HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _val = HTTPStatus(v) - - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.EnumMap = _field - return offset, nil -} - -func (p *FullStruct) FastReadField19(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = v - - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.Strlist = _field - return offset, nil -} - -func (p *FullStruct) FastReadField20(buf []byte) (int, error) { - offset := 0 - _field := NewLocal() - if l, err := _field.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.OptionalIns = _field - return offset, nil -} - -func (p *FullStruct) FastReadField21(buf []byte) (int, error) { - offset := 0 - _field := NewInner() - if l, err := _field.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.AnotherInner = _field - return offset, nil -} - -func (p *FullStruct) FastReadField22(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = v - - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.OptNilList = _field - return offset, nil -} - -func (p *FullStruct) FastReadField23(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]string, 0, size) - for i := 0; i < size; i++ { - var _elem string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = v - - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.NilList = _field - return offset, nil -} - -func (p *FullStruct) FastReadField24(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]*Inner, 0, size) - values := make([]Inner, size) - for i := 0; i < size; i++ { - _elem := &values[i] - _elem.InitDefault() - if l, err := _elem.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.OptNilInsList = _field - return offset, nil -} - -func (p *FullStruct) FastReadField25(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]*Inner, 0, size) - values := make([]Inner, size) - for i := 0; i < size; i++ { - _elem := &values[i] - _elem.InitDefault() - if l, err := _elem.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.NilInsList = _field - return offset, nil -} - -func (p *FullStruct) FastReadField26(buf []byte) (int, error) { - offset := 0 - - var _field *HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - tmp := HTTPStatus(v) - _field = &tmp - - } - p.OptStatus = _field - return offset, nil -} - -func (p *FullStruct) FastReadField27(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[HTTPStatus]*Local, size) - values := make([]Local, size) - for i := 0; i < size; i++ { - var _key HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = HTTPStatus(v) - - } - - _val := &values[i] - _val.InitDefault() - if l, err := _val.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.EnumKeyMap = _field - return offset, nil -} - -func (p *FullStruct) FastReadField28(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[HTTPStatus][]map[string]*Local, size) - for i := 0; i < size; i++ { - var _key HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = HTTPStatus(v) - - } - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _val := make([]map[string]*Local, 0, size) - for i := 0; i < size; i++ { - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _elem := make(map[string]*Local, size) - values := make([]Local, size) - for i := 0; i < size; i++ { - var _key1 string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key1 = v - - } - - _val1 := &values[i] - _val1.InitDefault() - if l, err := _val1.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _elem[_key1] = _val1 - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _val = append(_val, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.Complex = _field - return offset, nil -} - -func (p *FullStruct) FastReadField29(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadSetBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]int64, 0, size) - for i := 0; i < size; i++ { - var _elem int64 - if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = v - - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadSetEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.I64Set = _field - return offset, nil -} - -func (p *FullStruct) FastReadField30(buf []byte) (int, error) { - offset := 0 - - var _field int16 - if v, l, err := bthrift.Binary.ReadI16(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Int16 = _field - return offset, nil -} - -func (p *FullStruct) FastReadField31(buf []byte) (int, error) { - offset := 0 - - var _field bool - if v, l, err := bthrift.Binary.ReadBool(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.IsSet = _field - return offset, nil -} - -// for compatibility -func (p *FullStruct) FastWrite(buf []byte) int { - return 0 -} - -func (p *FullStruct) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "FullStruct") - if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) - offset += p.fastWriteField2(buf[offset:], binaryWriter) - offset += p.fastWriteField9(buf[offset:], binaryWriter) - offset += p.fastWriteField30(buf[offset:], binaryWriter) - offset += p.fastWriteField31(buf[offset:], binaryWriter) - offset += p.fastWriteField3(buf[offset:], binaryWriter) - offset += p.fastWriteField4(buf[offset:], binaryWriter) - offset += p.fastWriteField5(buf[offset:], binaryWriter) - offset += p.fastWriteField6(buf[offset:], binaryWriter) - offset += p.fastWriteField7(buf[offset:], binaryWriter) - offset += p.fastWriteField8(buf[offset:], binaryWriter) - offset += p.fastWriteField10(buf[offset:], binaryWriter) - offset += p.fastWriteField11(buf[offset:], binaryWriter) - offset += p.fastWriteField12(buf[offset:], binaryWriter) - offset += p.fastWriteField13(buf[offset:], binaryWriter) - offset += p.fastWriteField14(buf[offset:], binaryWriter) - offset += p.fastWriteField16(buf[offset:], binaryWriter) - offset += p.fastWriteField17(buf[offset:], binaryWriter) - offset += p.fastWriteField18(buf[offset:], binaryWriter) - offset += p.fastWriteField19(buf[offset:], binaryWriter) - offset += p.fastWriteField20(buf[offset:], binaryWriter) - offset += p.fastWriteField21(buf[offset:], binaryWriter) - offset += p.fastWriteField22(buf[offset:], binaryWriter) - offset += p.fastWriteField23(buf[offset:], binaryWriter) - offset += p.fastWriteField24(buf[offset:], binaryWriter) - offset += p.fastWriteField25(buf[offset:], binaryWriter) - offset += p.fastWriteField26(buf[offset:], binaryWriter) - offset += p.fastWriteField27(buf[offset:], binaryWriter) - offset += p.fastWriteField28(buf[offset:], binaryWriter) - offset += p.fastWriteField29(buf[offset:], binaryWriter) - offset += copy(buf[offset:], p._unknownFields) - } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) BLength() int { - l := 0 - l += bthrift.Binary.StructBeginLength("FullStruct") - if p != nil { - l += p.field1Length() - l += p.field2Length() - l += p.field3Length() - l += p.field4Length() - l += p.field5Length() - l += p.field6Length() - l += p.field7Length() - l += p.field8Length() - l += p.field9Length() - l += p.field10Length() - l += p.field11Length() - l += p.field12Length() - l += p.field13Length() - l += p.field14Length() - l += p.field16Length() - l += p.field17Length() - l += p.field18Length() - l += p.field19Length() - l += p.field20Length() - l += p.field21Length() - l += p.field22Length() - l += p.field23Length() - l += p.field24Length() - l += p.field25Length() - l += p.field26Length() - l += p.field27Length() - l += p.field28Length() - l += p.field29Length() - l += p.field30Length() - l += p.field31Length() - l += len(p._unknownFields) - } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() - return l -} - -func (p *FullStruct) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Left", thrift.I32, 1) - offset += bthrift.Binary.WriteI32(buf[offset:], p.Left) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetRight() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Right", thrift.I32, 2) - offset += bthrift.Binary.WriteI32(buf[offset:], p.Right) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *FullStruct) fastWriteField3(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Dummy", thrift.STRING, 3) - offset += bthrift.Binary.WriteBinaryNocopy(buf[offset:], binaryWriter, []byte(p.Dummy)) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField4(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "InnerReq", thrift.STRUCT, 4) - offset += p.InnerReq.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField5(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "status", thrift.I32, 5) - offset += bthrift.Binary.WriteI32(buf[offset:], int32(p.Status)) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField6(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Str", thrift.STRING, 6) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Str) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField7(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "enum_list", thrift.LIST, 7) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.I32, 0) - var length int - for _, v := range p.EnumList { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(v)) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField8(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetStrmap() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Strmap", thrift.MAP, 8) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.I32, thrift.STRING, 0) - var length int - for k, v := range p.Strmap { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], k) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.STRING, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *FullStruct) fastWriteField9(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Int64", thrift.I64, 9) - offset += bthrift.Binary.WriteI64(buf[offset:], p.Int64) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField10(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetIntList() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "IntList", thrift.LIST, 10) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.I32, 0) - var length int - for _, v := range p.IntList { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *FullStruct) fastWriteField11(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "localList", thrift.LIST, 11) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRUCT, 0) - var length int - for _, v := range p.LocalList { - length++ - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField12(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "StrLocalMap", thrift.MAP, 12) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRUCT, 0) - var length int - for k, v := range p.StrLocalMap { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRUCT, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField13(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "nestList", thrift.LIST, 13) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.LIST, 0) - var length int - for _, v := range p.NestList { - length++ - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.I32, 0) - var length int - for _, v := range v { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.LIST, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField14(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "required_ins", thrift.STRUCT, 14) - offset += p.RequiredIns.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField16(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "nestMap", thrift.MAP, 16) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.LIST, 0) - var length int - for k, v := range p.NestMap { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRING, 0) - var length int - for _, v := range v { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.LIST, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField17(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "nestMap2", thrift.LIST, 17) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.MAP, 0) - var length int - for _, v := range p.NestMap2 { - length++ - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.I32, 0) - var length int - for k, v := range v { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += bthrift.Binary.WriteI32(buf[offset:], int32(v)) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.I32, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.MAP, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField18(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "enum_map", thrift.MAP, 18) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.I32, thrift.I32, 0) - var length int - for k, v := range p.EnumMap { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], k) - offset += bthrift.Binary.WriteI32(buf[offset:], int32(v)) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.I32, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField19(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Strlist", thrift.LIST, 19) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRING, 0) - var length int - for _, v := range p.Strlist { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField20(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetOptionalIns() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "optional_ins", thrift.STRUCT, 20) - offset += p.OptionalIns.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *FullStruct) fastWriteField21(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "AnotherInner", thrift.STRUCT, 21) - offset += p.AnotherInner.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField22(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetOptNilList() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "opt_nil_list", thrift.LIST, 22) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRING, 0) - var length int - for _, v := range p.OptNilList { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *FullStruct) fastWriteField23(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "nil_list", thrift.LIST, 23) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRING, 0) - var length int - for _, v := range p.NilList { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField24(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetOptNilInsList() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "opt_nil_ins_list", thrift.LIST, 24) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRUCT, 0) - var length int - for _, v := range p.OptNilInsList { - length++ - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *FullStruct) fastWriteField25(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "nil_ins_list", thrift.LIST, 25) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRUCT, 0) - var length int - for _, v := range p.NilInsList { - length++ - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField26(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetOptStatus() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "opt_status", thrift.I32, 26) - offset += bthrift.Binary.WriteI32(buf[offset:], int32(*p.OptStatus)) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *FullStruct) fastWriteField27(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "enum_key_map", thrift.MAP, 27) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.I32, thrift.STRUCT, 0) - var length int - for k, v := range p.EnumKeyMap { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(k)) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.STRUCT, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField28(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "complex", thrift.MAP, 28) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.I32, thrift.LIST, 0) - var length int - for k, v := range p.Complex { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(k)) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.MAP, 0) - var length int - for _, v := range v { - length++ - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRUCT, 0) - var length int - for k, v := range v { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRUCT, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.MAP, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.LIST, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField29(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "i64Set", thrift.SET, 29) - setBeginOffset := offset - offset += bthrift.Binary.SetBeginLength(thrift.I64, 0) - - for i := 0; i < len(p.I64Set); i++ { - for j := i + 1; j < len(p.I64Set); j++ { - if func(tgt, src int64) bool { - if tgt != src { - return false - } - return true - }(p.I64Set[i], p.I64Set[j]) { - panic(fmt.Errorf("%T error writing set field: slice is not unique", p.I64Set[i])) - } - } - } - var length int - for _, v := range p.I64Set { - length++ - offset += bthrift.Binary.WriteI64(buf[offset:], v) - } - bthrift.Binary.WriteSetBegin(buf[setBeginOffset:], thrift.I64, length) - offset += bthrift.Binary.WriteSetEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField30(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Int16", thrift.I16, 30) - offset += bthrift.Binary.WriteI16(buf[offset:], p.Int16) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) fastWriteField31(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "isSet", thrift.BOOL, 31) - offset += bthrift.Binary.WriteBool(buf[offset:], p.IsSet) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *FullStruct) field1Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Left", thrift.I32, 1) - l += bthrift.Binary.I32Length(p.Left) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field2Length() int { - l := 0 - if p.IsSetRight() { - l += bthrift.Binary.FieldBeginLength("Right", thrift.I32, 2) - l += bthrift.Binary.I32Length(p.Right) - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *FullStruct) field3Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Dummy", thrift.STRING, 3) - l += bthrift.Binary.BinaryLengthNocopy([]byte(p.Dummy)) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field4Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("InnerReq", thrift.STRUCT, 4) - l += p.InnerReq.BLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field5Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("status", thrift.I32, 5) - l += bthrift.Binary.I32Length(int32(p.Status)) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field6Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Str", thrift.STRING, 6) - l += bthrift.Binary.StringLengthNocopy(p.Str) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field7Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("enum_list", thrift.LIST, 7) - l += bthrift.Binary.ListBeginLength(thrift.I32, len(p.EnumList)) - for _, v := range p.EnumList { - l += bthrift.Binary.I32Length(int32(v)) - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field8Length() int { - l := 0 - if p.IsSetStrmap() { - l += bthrift.Binary.FieldBeginLength("Strmap", thrift.MAP, 8) - l += bthrift.Binary.MapBeginLength(thrift.I32, thrift.STRING, len(p.Strmap)) - for k, v := range p.Strmap { - - l += bthrift.Binary.I32Length(k) - l += bthrift.Binary.StringLengthNocopy(v) - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *FullStruct) field9Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Int64", thrift.I64, 9) - l += bthrift.Binary.I64Length(p.Int64) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field10Length() int { - l := 0 - if p.IsSetIntList() { - l += bthrift.Binary.FieldBeginLength("IntList", thrift.LIST, 10) - l += bthrift.Binary.ListBeginLength(thrift.I32, len(p.IntList)) - var tmpV int32 - l += bthrift.Binary.I32Length(int32(tmpV)) * len(p.IntList) - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *FullStruct) field11Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("localList", thrift.LIST, 11) - l += bthrift.Binary.ListBeginLength(thrift.STRUCT, len(p.LocalList)) - for _, v := range p.LocalList { - l += v.BLength() - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field12Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("StrLocalMap", thrift.MAP, 12) - l += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRUCT, len(p.StrLocalMap)) - for k, v := range p.StrLocalMap { - - l += bthrift.Binary.StringLengthNocopy(k) - l += v.BLength() - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field13Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("nestList", thrift.LIST, 13) - l += bthrift.Binary.ListBeginLength(thrift.LIST, len(p.NestList)) - for _, v := range p.NestList { - l += bthrift.Binary.ListBeginLength(thrift.I32, len(v)) - var tmpV int32 - l += bthrift.Binary.I32Length(int32(tmpV)) * len(v) - l += bthrift.Binary.ListEndLength() - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field14Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("required_ins", thrift.STRUCT, 14) - l += p.RequiredIns.BLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field16Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("nestMap", thrift.MAP, 16) - l += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.LIST, len(p.NestMap)) - for k, v := range p.NestMap { - - l += bthrift.Binary.StringLengthNocopy(k) - l += bthrift.Binary.ListBeginLength(thrift.STRING, len(v)) - for _, v := range v { - l += bthrift.Binary.StringLengthNocopy(v) - } - l += bthrift.Binary.ListEndLength() - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field17Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("nestMap2", thrift.LIST, 17) - l += bthrift.Binary.ListBeginLength(thrift.MAP, len(p.NestMap2)) - for _, v := range p.NestMap2 { - l += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.I32, len(v)) - for k, v := range v { - - l += bthrift.Binary.StringLengthNocopy(k) - l += bthrift.Binary.I32Length(int32(v)) - } - l += bthrift.Binary.MapEndLength() - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field18Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("enum_map", thrift.MAP, 18) - l += bthrift.Binary.MapBeginLength(thrift.I32, thrift.I32, len(p.EnumMap)) - for k, v := range p.EnumMap { - - l += bthrift.Binary.I32Length(k) - l += bthrift.Binary.I32Length(int32(v)) - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field19Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Strlist", thrift.LIST, 19) - l += bthrift.Binary.ListBeginLength(thrift.STRING, len(p.Strlist)) - for _, v := range p.Strlist { - l += bthrift.Binary.StringLengthNocopy(v) - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field20Length() int { - l := 0 - if p.IsSetOptionalIns() { - l += bthrift.Binary.FieldBeginLength("optional_ins", thrift.STRUCT, 20) - l += p.OptionalIns.BLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *FullStruct) field21Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("AnotherInner", thrift.STRUCT, 21) - l += p.AnotherInner.BLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field22Length() int { - l := 0 - if p.IsSetOptNilList() { - l += bthrift.Binary.FieldBeginLength("opt_nil_list", thrift.LIST, 22) - l += bthrift.Binary.ListBeginLength(thrift.STRING, len(p.OptNilList)) - for _, v := range p.OptNilList { - l += bthrift.Binary.StringLengthNocopy(v) - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *FullStruct) field23Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("nil_list", thrift.LIST, 23) - l += bthrift.Binary.ListBeginLength(thrift.STRING, len(p.NilList)) - for _, v := range p.NilList { - l += bthrift.Binary.StringLengthNocopy(v) - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field24Length() int { - l := 0 - if p.IsSetOptNilInsList() { - l += bthrift.Binary.FieldBeginLength("opt_nil_ins_list", thrift.LIST, 24) - l += bthrift.Binary.ListBeginLength(thrift.STRUCT, len(p.OptNilInsList)) - for _, v := range p.OptNilInsList { - l += v.BLength() - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *FullStruct) field25Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("nil_ins_list", thrift.LIST, 25) - l += bthrift.Binary.ListBeginLength(thrift.STRUCT, len(p.NilInsList)) - for _, v := range p.NilInsList { - l += v.BLength() - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field26Length() int { - l := 0 - if p.IsSetOptStatus() { - l += bthrift.Binary.FieldBeginLength("opt_status", thrift.I32, 26) - l += bthrift.Binary.I32Length(int32(*p.OptStatus)) - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *FullStruct) field27Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("enum_key_map", thrift.MAP, 27) - l += bthrift.Binary.MapBeginLength(thrift.I32, thrift.STRUCT, len(p.EnumKeyMap)) - for k, v := range p.EnumKeyMap { - - l += bthrift.Binary.I32Length(int32(k)) - l += v.BLength() - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field28Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("complex", thrift.MAP, 28) - l += bthrift.Binary.MapBeginLength(thrift.I32, thrift.LIST, len(p.Complex)) - for k, v := range p.Complex { - - l += bthrift.Binary.I32Length(int32(k)) - l += bthrift.Binary.ListBeginLength(thrift.MAP, len(v)) - for _, v := range v { - l += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRUCT, len(v)) - for k, v := range v { - - l += bthrift.Binary.StringLengthNocopy(k) - l += v.BLength() - } - l += bthrift.Binary.MapEndLength() - } - l += bthrift.Binary.ListEndLength() - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field29Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("i64Set", thrift.SET, 29) - l += bthrift.Binary.SetBeginLength(thrift.I64, len(p.I64Set)) - - for i := 0; i < len(p.I64Set); i++ { - for j := i + 1; j < len(p.I64Set); j++ { - if func(tgt, src int64) bool { - if tgt != src { - return false - } - return true - }(p.I64Set[i], p.I64Set[j]) { - panic(fmt.Errorf("%T error writing set field: slice is not unique", p.I64Set[i])) - } - } - } - var tmpV int64 - l += bthrift.Binary.I64Length(int64(tmpV)) * len(p.I64Set) - l += bthrift.Binary.SetEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field30Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Int16", thrift.I16, 30) - l += bthrift.Binary.I16Length(p.Int16) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *FullStruct) field31Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("isSet", thrift.BOOL, 31) - l += bthrift.Binary.BoolLength(p.IsSet) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) FastRead(buf []byte) (int, error) { - var err error - var offset int - var l int - var fieldTypeId thrift.TType - var fieldId int16 - var issetLeft bool = false - var issetRequiredIns bool = false - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - - for { - var isUnknownField bool - var beginOff int = offset - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - switch fieldId { - case 1: - if fieldTypeId == thrift.I32 { - l, err = p.FastReadField1(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - issetLeft = true - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.STRING { - l, err = p.FastReadField3(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 6: - if fieldTypeId == thrift.STRING { - l, err = p.FastReadField6(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 7: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField7(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 9: - if fieldTypeId == thrift.I64 { - l, err = p.FastReadField9(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 10: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField10(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 11: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField11(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 12: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField12(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 13: - if fieldTypeId == thrift.LIST { - l, err = p.FastReadField13(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 14: - if fieldTypeId == thrift.STRUCT { - l, err = p.FastReadField14(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - issetRequiredIns = true - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 20: - if fieldTypeId == thrift.STRUCT { - l, err = p.FastReadField20(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 21: - if fieldTypeId == thrift.STRUCT { - l, err = p.FastReadField21(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - case 27: - if fieldTypeId == thrift.MAP { - l, err = p.FastReadField27(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldError - } - } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - } - default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - isUnknownField = true - } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - if isUnknownField { - p._unknownFields = append(p._unknownFields, buf[beginOff:offset]...) - } - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError - } - - if !issetLeft { - fieldId = 1 - goto RequiredFieldNotSetError - } - - if !issetRequiredIns { - fieldId = 14 - goto RequiredFieldNotSetError - } - return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MixedStruct[fieldId]), err) -SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -RequiredFieldNotSetError: - return offset, thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_MixedStruct[fieldId])) -} - -func (p *MixedStruct) FastReadField1(buf []byte) (int, error) { - offset := 0 - - var _field int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Left = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField3(buf []byte) (int, error) { - offset := 0 - - var _field []byte - if v, l, err := bthrift.Binary.ReadBinary(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = []byte(v) - - } - p.Dummy = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField6(buf []byte) (int, error) { - offset := 0 - - var _field string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Str = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField7(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]HTTPStatus, 0, size) - for i := 0; i < size; i++ { - var _elem HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = HTTPStatus(v) - - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.EnumList = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField9(buf []byte) (int, error) { - offset := 0 - - var _field int64 - if v, l, err := bthrift.Binary.ReadI64(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _field = v - - } - p.Int64 = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField10(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]int32, 0, size) - for i := 0; i < size; i++ { - var _elem int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem = v - - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.IntList = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField11(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([]*Local, 0, size) - values := make([]Local, size) - for i := 0; i < size; i++ { - _elem := &values[i] - _elem.InitDefault() - if l, err := _elem.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.LocalList = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField12(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[string]*Local, size) - values := make([]Local, size) - for i := 0; i < size; i++ { - var _key string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = v - - } - - _val := &values[i] - _val.InitDefault() - if l, err := _val.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.StrLocalMap = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField13(buf []byte) (int, error) { - offset := 0 - - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make([][]int32, 0, size) - for i := 0; i < size; i++ { - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _elem := make([]int32, 0, size) - for i := 0; i < size; i++ { - var _elem1 int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _elem1 = v - - } - - _elem = append(_elem, _elem1) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field = append(_field, _elem) - } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.NestList = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField14(buf []byte) (int, error) { - offset := 0 - _field := NewLocal() - if l, err := _field.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.RequiredIns = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField20(buf []byte) (int, error) { - offset := 0 - _field := NewLocal() - if l, err := _field.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.OptionalIns = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField21(buf []byte) (int, error) { - offset := 0 - _field := NewInner() - if l, err := _field.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.AnotherInner = _field - return offset, nil -} - -func (p *MixedStruct) FastReadField27(buf []byte) (int, error) { - offset := 0 - - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) - offset += l - if err != nil { - return offset, err - } - _field := make(map[HTTPStatus]*Local, size) - values := make([]Local, size) - for i := 0; i < size; i++ { - var _key HTTPStatus - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - - _key = HTTPStatus(v) - - } - - _val := &values[i] - _val.InitDefault() - if l, err := _val.FastRead(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - - _field[_key] = _val - } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } - p.EnumKeyMap = _field - return offset, nil -} - -// for compatibility -func (p *MixedStruct) FastWrite(buf []byte) int { - return 0 -} - -func (p *MixedStruct) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "MixedStruct") - if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) - offset += p.fastWriteField9(buf[offset:], binaryWriter) - offset += p.fastWriteField3(buf[offset:], binaryWriter) - offset += p.fastWriteField6(buf[offset:], binaryWriter) - offset += p.fastWriteField7(buf[offset:], binaryWriter) - offset += p.fastWriteField10(buf[offset:], binaryWriter) - offset += p.fastWriteField11(buf[offset:], binaryWriter) - offset += p.fastWriteField12(buf[offset:], binaryWriter) - offset += p.fastWriteField13(buf[offset:], binaryWriter) - offset += p.fastWriteField14(buf[offset:], binaryWriter) - offset += p.fastWriteField20(buf[offset:], binaryWriter) - offset += p.fastWriteField21(buf[offset:], binaryWriter) - offset += p.fastWriteField27(buf[offset:], binaryWriter) - offset += copy(buf[offset:], p._unknownFields) - } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) BLength() int { - l := 0 - l += bthrift.Binary.StructBeginLength("MixedStruct") - if p != nil { - l += p.field1Length() - l += p.field3Length() - l += p.field6Length() - l += p.field7Length() - l += p.field9Length() - l += p.field10Length() - l += p.field11Length() - l += p.field12Length() - l += p.field13Length() - l += p.field14Length() - l += p.field20Length() - l += p.field21Length() - l += p.field27Length() - l += len(p._unknownFields) - } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() - return l -} - -func (p *MixedStruct) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Left", thrift.I32, 1) - offset += bthrift.Binary.WriteI32(buf[offset:], p.Left) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField3(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Dummy", thrift.STRING, 3) - offset += bthrift.Binary.WriteBinaryNocopy(buf[offset:], binaryWriter, []byte(p.Dummy)) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField6(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Str", thrift.STRING, 6) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Str) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField7(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "enum_list", thrift.LIST, 7) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.I32, 0) - var length int - for _, v := range p.EnumList { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(v)) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField9(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Int64", thrift.I64, 9) - offset += bthrift.Binary.WriteI64(buf[offset:], p.Int64) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField10(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetIntList() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "IntList", thrift.LIST, 10) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.I32, 0) - var length int - for _, v := range p.IntList { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *MixedStruct) fastWriteField11(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "localList", thrift.LIST, 11) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.STRUCT, 0) - var length int - for _, v := range p.LocalList { - length++ - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField12(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "StrLocalMap", thrift.MAP, 12) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRUCT, 0) - var length int - for k, v := range p.StrLocalMap { - length++ - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, k) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.STRING, thrift.STRUCT, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField13(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "nestList", thrift.LIST, 13) - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.LIST, 0) - var length int - for _, v := range p.NestList { - length++ - listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift.I32, 0) - var length int - for _, v := range v { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], v) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I32, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.LIST, length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField14(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "required_ins", thrift.STRUCT, 14) - offset += p.RequiredIns.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField20(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - if p.IsSetOptionalIns() { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "optional_ins", thrift.STRUCT, 20) - offset += p.OptionalIns.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - } - return offset -} - -func (p *MixedStruct) fastWriteField21(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "AnotherInner", thrift.STRUCT, 21) - offset += p.AnotherInner.FastWriteNocopy(buf[offset:], binaryWriter) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) fastWriteField27(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "enum_key_map", thrift.MAP, 27) - mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift.I32, thrift.STRUCT, 0) - var length int - for k, v := range p.EnumKeyMap { - length++ - offset += bthrift.Binary.WriteI32(buf[offset:], int32(k)) - offset += v.FastWriteNocopy(buf[offset:], binaryWriter) - } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift.I32, thrift.STRUCT, length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) - return offset -} - -func (p *MixedStruct) field1Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Left", thrift.I32, 1) - l += bthrift.Binary.I32Length(p.Left) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field3Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Dummy", thrift.STRING, 3) - l += bthrift.Binary.BinaryLengthNocopy([]byte(p.Dummy)) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field6Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Str", thrift.STRING, 6) - l += bthrift.Binary.StringLengthNocopy(p.Str) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field7Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("enum_list", thrift.LIST, 7) - l += bthrift.Binary.ListBeginLength(thrift.I32, len(p.EnumList)) - for _, v := range p.EnumList { - l += bthrift.Binary.I32Length(int32(v)) - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field9Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("Int64", thrift.I64, 9) - l += bthrift.Binary.I64Length(p.Int64) - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field10Length() int { - l := 0 - if p.IsSetIntList() { - l += bthrift.Binary.FieldBeginLength("IntList", thrift.LIST, 10) - l += bthrift.Binary.ListBeginLength(thrift.I32, len(p.IntList)) - var tmpV int32 - l += bthrift.Binary.I32Length(int32(tmpV)) * len(p.IntList) - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *MixedStruct) field11Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("localList", thrift.LIST, 11) - l += bthrift.Binary.ListBeginLength(thrift.STRUCT, len(p.LocalList)) - for _, v := range p.LocalList { - l += v.BLength() - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field12Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("StrLocalMap", thrift.MAP, 12) - l += bthrift.Binary.MapBeginLength(thrift.STRING, thrift.STRUCT, len(p.StrLocalMap)) - for k, v := range p.StrLocalMap { - - l += bthrift.Binary.StringLengthNocopy(k) - l += v.BLength() - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field13Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("nestList", thrift.LIST, 13) - l += bthrift.Binary.ListBeginLength(thrift.LIST, len(p.NestList)) - for _, v := range p.NestList { - l += bthrift.Binary.ListBeginLength(thrift.I32, len(v)) - var tmpV int32 - l += bthrift.Binary.I32Length(int32(tmpV)) * len(v) - l += bthrift.Binary.ListEndLength() - } - l += bthrift.Binary.ListEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field14Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("required_ins", thrift.STRUCT, 14) - l += p.RequiredIns.BLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field20Length() int { - l := 0 - if p.IsSetOptionalIns() { - l += bthrift.Binary.FieldBeginLength("optional_ins", thrift.STRUCT, 20) - l += p.OptionalIns.BLength() - l += bthrift.Binary.FieldEndLength() - } - return l -} - -func (p *MixedStruct) field21Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("AnotherInner", thrift.STRUCT, 21) - l += p.AnotherInner.BLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *MixedStruct) field27Length() int { - l := 0 - l += bthrift.Binary.FieldBeginLength("enum_key_map", thrift.MAP, 27) - l += bthrift.Binary.MapBeginLength(thrift.I32, thrift.STRUCT, len(p.EnumKeyMap)) - for k, v := range p.EnumKeyMap { - - l += bthrift.Binary.I32Length(int32(k)) - l += v.BLength() - } - l += bthrift.Binary.MapEndLength() - l += bthrift.Binary.FieldEndLength() - return l -} - -func (p *EmptyStruct) FastRead(buf []byte) (int, error) { - var err error - var offset int - var l int - var fieldTypeId thrift.TType - var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - - for { - var beginOff int = offset - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) - offset += l - if err != nil { - goto SkipFieldError - } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - p._unknownFields = append(p._unknownFields, buf[beginOff:offset]...) - } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError - } - - return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -// for compatibility -func (p *EmptyStruct) FastWrite(buf []byte) int { - return 0 -} - -func (p *EmptyStruct) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { - offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "EmptyStruct") - if p != nil { - offset += copy(buf[offset:], p._unknownFields) - } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) - return offset -} - -func (p *EmptyStruct) BLength() int { - l := 0 - l += bthrift.Binary.StructBeginLength("EmptyStruct") - if p != nil { - l += len(p._unknownFields) - } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() - return l -} diff --git a/pkg/protocol/bthrift/test/kitex_gen/test/test.go b/pkg/protocol/bthrift/test/kitex_gen/test/test.go deleted file mode 100644 index 25355a9b17..0000000000 --- a/pkg/protocol/bthrift/test/kitex_gen/test/test.go +++ /dev/null @@ -1,1619 +0,0 @@ -// Code generated by thriftgo (0.3.13). DO NOT EDIT. - -package test - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "fmt" - "github.com/cloudwego/thriftgo/generator/golang/extension/unknown" - "strings" -) - -type AEnum int64 - -const ( - AEnum_A AEnum = 1 - AEnum_B AEnum = 2 -) - -func (p AEnum) String() string { - switch p { - case AEnum_A: - return "A" - case AEnum_B: - return "B" - } - return "" -} - -func AEnumFromString(s string) (AEnum, error) { - switch s { - case "A": - return AEnum_A, nil - case "B": - return AEnum_B, nil - } - return AEnum(0), fmt.Errorf("not a valid AEnum string") -} - -func AEnumPtr(v AEnum) *AEnum { return &v } -func (p *AEnum) Scan(value interface{}) (err error) { - var result sql.NullInt64 - err = result.Scan(value) - *p = AEnum(result.Int64) - return -} - -func (p *AEnum) Value() (driver.Value, error) { - if p == nil { - return nil, nil - } - return int64(*p), nil -} - -type HTTPStatus int64 - -const ( - HTTPStatus_OK HTTPStatus = 200 - HTTPStatus_NOT_FOUND HTTPStatus = 404 -) - -func (p HTTPStatus) String() string { - switch p { - case HTTPStatus_OK: - return "OK" - case HTTPStatus_NOT_FOUND: - return "NOT_FOUND" - } - return "" -} - -func HTTPStatusFromString(s string) (HTTPStatus, error) { - switch s { - case "OK": - return HTTPStatus_OK, nil - case "NOT_FOUND": - return HTTPStatus_NOT_FOUND, nil - } - return HTTPStatus(0), fmt.Errorf("not a valid HTTPStatus string") -} - -func HTTPStatusPtr(v HTTPStatus) *HTTPStatus { return &v } -func (p *HTTPStatus) Scan(value interface{}) (err error) { - var result sql.NullInt64 - err = result.Scan(value) - *p = HTTPStatus(result.Int64) - return -} - -func (p *HTTPStatus) Value() (driver.Value, error) { - if p == nil { - return nil, nil - } - return int64(*p), nil -} - -type Inner struct { - Num int32 `thrift:"Num,1,optional" frugal:"1,optional,i32" json:"Num,omitempty"` - Desc *string `thrift:"desc,2,optional" frugal:"2,optional,string" json:"desc,omitempty"` - MapOfList map[int64][]int64 `thrift:"MapOfList,3,optional" frugal:"3,optional,map>" json:"MapOfList,omitempty"` - MapOfEnumKey map[AEnum]int64 `thrift:"MapOfEnumKey,4,optional" frugal:"4,optional,map" json:"MapOfEnumKey,omitempty"` - Byte1 *int8 `thrift:"Byte1,5,optional" frugal:"5,optional,byte" json:"Byte1,omitempty"` - Double1 *float64 `thrift:"Double1,6,optional" frugal:"6,optional,double" json:"Double1,omitempty"` - _unknownFields unknown.Fields -} - -func NewInner() *Inner { - return &Inner{ - - Num: 5, - } -} - -func (p *Inner) InitDefault() { - p.Num = 5 -} - -var Inner_Num_DEFAULT int32 = 5 - -func (p *Inner) GetNum() (v int32) { - if !p.IsSetNum() { - return Inner_Num_DEFAULT - } - return p.Num -} - -var Inner_Desc_DEFAULT string - -func (p *Inner) GetDesc() (v string) { - if !p.IsSetDesc() { - return Inner_Desc_DEFAULT - } - return *p.Desc -} - -var Inner_MapOfList_DEFAULT map[int64][]int64 - -func (p *Inner) GetMapOfList() (v map[int64][]int64) { - if !p.IsSetMapOfList() { - return Inner_MapOfList_DEFAULT - } - return p.MapOfList -} - -var Inner_MapOfEnumKey_DEFAULT map[AEnum]int64 - -func (p *Inner) GetMapOfEnumKey() (v map[AEnum]int64) { - if !p.IsSetMapOfEnumKey() { - return Inner_MapOfEnumKey_DEFAULT - } - return p.MapOfEnumKey -} - -var Inner_Byte1_DEFAULT int8 - -func (p *Inner) GetByte1() (v int8) { - if !p.IsSetByte1() { - return Inner_Byte1_DEFAULT - } - return *p.Byte1 -} - -var Inner_Double1_DEFAULT float64 - -func (p *Inner) GetDouble1() (v float64) { - if !p.IsSetDouble1() { - return Inner_Double1_DEFAULT - } - return *p.Double1 -} -func (p *Inner) SetNum(val int32) { - p.Num = val -} -func (p *Inner) SetDesc(val *string) { - p.Desc = val -} -func (p *Inner) SetMapOfList(val map[int64][]int64) { - p.MapOfList = val -} -func (p *Inner) SetMapOfEnumKey(val map[AEnum]int64) { - p.MapOfEnumKey = val -} -func (p *Inner) SetByte1(val *int8) { - p.Byte1 = val -} -func (p *Inner) SetDouble1(val *float64) { - p.Double1 = val -} - -func (p *Inner) CarryingUnknownFields() bool { - return len(p._unknownFields) > 0 -} - -func (p *Inner) IsSetNum() bool { - return p.Num != Inner_Num_DEFAULT -} - -func (p *Inner) IsSetDesc() bool { - return p.Desc != nil -} - -func (p *Inner) IsSetMapOfList() bool { - return p.MapOfList != nil -} - -func (p *Inner) IsSetMapOfEnumKey() bool { - return p.MapOfEnumKey != nil -} - -func (p *Inner) IsSetByte1() bool { - return p.Byte1 != nil -} - -func (p *Inner) IsSetDouble1() bool { - return p.Double1 != nil -} - -func (p *Inner) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("Inner(%+v)", *p) -} - -func (p *Inner) DeepEqual(ano *Inner) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field1DeepEqual(ano.Num) { - return false - } - if !p.Field2DeepEqual(ano.Desc) { - return false - } - if !p.Field3DeepEqual(ano.MapOfList) { - return false - } - if !p.Field4DeepEqual(ano.MapOfEnumKey) { - return false - } - if !p.Field5DeepEqual(ano.Byte1) { - return false - } - if !p.Field6DeepEqual(ano.Double1) { - return false - } - return true -} - -func (p *Inner) Field1DeepEqual(src int32) bool { - - if p.Num != src { - return false - } - return true -} -func (p *Inner) Field2DeepEqual(src *string) bool { - - if p.Desc == src { - return true - } else if p.Desc == nil || src == nil { - return false - } - if strings.Compare(*p.Desc, *src) != 0 { - return false - } - return true -} -func (p *Inner) Field3DeepEqual(src map[int64][]int64) bool { - - if len(p.MapOfList) != len(src) { - return false - } - for k, v := range p.MapOfList { - _src := src[k] - if len(v) != len(_src) { - return false - } - for i, v := range v { - _src1 := _src[i] - if v != _src1 { - return false - } - } - } - return true -} -func (p *Inner) Field4DeepEqual(src map[AEnum]int64) bool { - - if len(p.MapOfEnumKey) != len(src) { - return false - } - for k, v := range p.MapOfEnumKey { - _src := src[k] - if v != _src { - return false - } - } - return true -} -func (p *Inner) Field5DeepEqual(src *int8) bool { - - if p.Byte1 == src { - return true - } else if p.Byte1 == nil || src == nil { - return false - } - if *p.Byte1 != *src { - return false - } - return true -} -func (p *Inner) Field6DeepEqual(src *float64) bool { - - if p.Double1 == src { - return true - } else if p.Double1 == nil || src == nil { - return false - } - if *p.Double1 != *src { - return false - } - return true -} - -var fieldIDToName_Inner = map[int16]string{ - 1: "Num", - 2: "desc", - 3: "MapOfList", - 4: "MapOfEnumKey", - 5: "Byte1", - 6: "Double1", -} - -type Local struct { - L int32 `thrift:"l,1" frugal:"1,default,i32" json:"l"` - _unknownFields unknown.Fields -} - -func NewLocal() *Local { - return &Local{} -} - -func (p *Local) InitDefault() { -} - -func (p *Local) GetL() (v int32) { - return p.L -} -func (p *Local) SetL(val int32) { - p.L = val -} - -func (p *Local) CarryingUnknownFields() bool { - return len(p._unknownFields) > 0 -} - -func (p *Local) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("Local(%+v)", *p) -} - -func (p *Local) DeepEqual(ano *Local) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field1DeepEqual(ano.L) { - return false - } - return true -} - -func (p *Local) Field1DeepEqual(src int32) bool { - - if p.L != src { - return false - } - return true -} - -var fieldIDToName_Local = map[int16]string{ - 1: "l", -} - -type FullStruct struct { - Left int32 `thrift:"Left,1,required" frugal:"1,required,i32" json:"Left"` - Right int32 `thrift:"Right,2,optional" frugal:"2,optional,i32" json:"Right,omitempty"` - Dummy []byte `thrift:"Dummy,3" frugal:"3,default,binary" json:"Dummy"` - InnerReq *Inner `thrift:"InnerReq,4" frugal:"4,default,Inner" json:"InnerReq"` - Status HTTPStatus `thrift:"status,5" frugal:"5,default,HTTPStatus" json:"status"` - Str string `thrift:"Str,6" frugal:"6,default,string" json:"Str"` - EnumList []HTTPStatus `thrift:"enum_list,7" frugal:"7,default,list" json:"enum_list"` - Strmap map[int32]string `thrift:"Strmap,8,optional" frugal:"8,optional,map" json:"Strmap,omitempty"` - Int64 int64 `thrift:"Int64,9" frugal:"9,default,i64" json:"Int64"` - IntList []int32 `thrift:"IntList,10,optional" frugal:"10,optional,list" json:"IntList,omitempty"` - LocalList []*Local `thrift:"localList,11" frugal:"11,default,list" json:"localList"` - StrLocalMap map[string]*Local `thrift:"StrLocalMap,12" frugal:"12,default,map" json:"StrLocalMap"` - NestList [][]int32 `thrift:"nestList,13" frugal:"13,default,list>" json:"nestList"` - RequiredIns *Local `thrift:"required_ins,14,required" frugal:"14,required,Local" json:"required_ins"` - NestMap map[string][]string `thrift:"nestMap,16" frugal:"16,default,map>" json:"nestMap"` - NestMap2 []map[string]HTTPStatus `thrift:"nestMap2,17" frugal:"17,default,list>" json:"nestMap2"` - EnumMap map[int32]HTTPStatus `thrift:"enum_map,18" frugal:"18,default,map" json:"enum_map"` - Strlist []string `thrift:"Strlist,19" frugal:"19,default,list" json:"Strlist"` - OptionalIns *Local `thrift:"optional_ins,20,optional" frugal:"20,optional,Local" json:"optional_ins,omitempty"` - AnotherInner *Inner `thrift:"AnotherInner,21" frugal:"21,default,Inner" json:"AnotherInner"` - OptNilList []string `thrift:"opt_nil_list,22,optional" frugal:"22,optional,list" json:"opt_nil_list,omitempty"` - NilList []string `thrift:"nil_list,23" frugal:"23,default,list" json:"nil_list"` - OptNilInsList []*Inner `thrift:"opt_nil_ins_list,24,optional" frugal:"24,optional,list" json:"opt_nil_ins_list,omitempty"` - NilInsList []*Inner `thrift:"nil_ins_list,25" frugal:"25,default,list" json:"nil_ins_list"` - OptStatus *HTTPStatus `thrift:"opt_status,26,optional" frugal:"26,optional,HTTPStatus" json:"opt_status,omitempty"` - EnumKeyMap map[HTTPStatus]*Local `thrift:"enum_key_map,27" frugal:"27,default,map" json:"enum_key_map"` - Complex map[HTTPStatus][]map[string]*Local `thrift:"complex,28" frugal:"28,default,map>>" json:"complex"` - I64Set []int64 `thrift:"i64Set,29" frugal:"29,default,set" json:"i64Set"` - Int16 int16 `thrift:"Int16,30" frugal:"30,default,i16" json:"Int16"` - IsSet bool `thrift:"isSet,31" frugal:"31,default,bool" json:"isSet"` - _unknownFields unknown.Fields -} - -func NewFullStruct() *FullStruct { - return &FullStruct{ - - Right: 3, - } -} - -func (p *FullStruct) InitDefault() { - p.Right = 3 -} - -func (p *FullStruct) GetLeft() (v int32) { - return p.Left -} - -var FullStruct_Right_DEFAULT int32 = 3 - -func (p *FullStruct) GetRight() (v int32) { - if !p.IsSetRight() { - return FullStruct_Right_DEFAULT - } - return p.Right -} - -func (p *FullStruct) GetDummy() (v []byte) { - return p.Dummy -} - -var FullStruct_InnerReq_DEFAULT *Inner - -func (p *FullStruct) GetInnerReq() (v *Inner) { - if !p.IsSetInnerReq() { - return FullStruct_InnerReq_DEFAULT - } - return p.InnerReq -} - -func (p *FullStruct) GetStatus() (v HTTPStatus) { - return p.Status -} - -func (p *FullStruct) GetStr() (v string) { - return p.Str -} - -func (p *FullStruct) GetEnumList() (v []HTTPStatus) { - return p.EnumList -} - -var FullStruct_Strmap_DEFAULT map[int32]string - -func (p *FullStruct) GetStrmap() (v map[int32]string) { - if !p.IsSetStrmap() { - return FullStruct_Strmap_DEFAULT - } - return p.Strmap -} - -func (p *FullStruct) GetInt64() (v int64) { - return p.Int64 -} - -var FullStruct_IntList_DEFAULT []int32 - -func (p *FullStruct) GetIntList() (v []int32) { - if !p.IsSetIntList() { - return FullStruct_IntList_DEFAULT - } - return p.IntList -} - -func (p *FullStruct) GetLocalList() (v []*Local) { - return p.LocalList -} - -func (p *FullStruct) GetStrLocalMap() (v map[string]*Local) { - return p.StrLocalMap -} - -func (p *FullStruct) GetNestList() (v [][]int32) { - return p.NestList -} - -var FullStruct_RequiredIns_DEFAULT *Local - -func (p *FullStruct) GetRequiredIns() (v *Local) { - if !p.IsSetRequiredIns() { - return FullStruct_RequiredIns_DEFAULT - } - return p.RequiredIns -} - -func (p *FullStruct) GetNestMap() (v map[string][]string) { - return p.NestMap -} - -func (p *FullStruct) GetNestMap2() (v []map[string]HTTPStatus) { - return p.NestMap2 -} - -func (p *FullStruct) GetEnumMap() (v map[int32]HTTPStatus) { - return p.EnumMap -} - -func (p *FullStruct) GetStrlist() (v []string) { - return p.Strlist -} - -var FullStruct_OptionalIns_DEFAULT *Local - -func (p *FullStruct) GetOptionalIns() (v *Local) { - if !p.IsSetOptionalIns() { - return FullStruct_OptionalIns_DEFAULT - } - return p.OptionalIns -} - -var FullStruct_AnotherInner_DEFAULT *Inner - -func (p *FullStruct) GetAnotherInner() (v *Inner) { - if !p.IsSetAnotherInner() { - return FullStruct_AnotherInner_DEFAULT - } - return p.AnotherInner -} - -var FullStruct_OptNilList_DEFAULT []string - -func (p *FullStruct) GetOptNilList() (v []string) { - if !p.IsSetOptNilList() { - return FullStruct_OptNilList_DEFAULT - } - return p.OptNilList -} - -func (p *FullStruct) GetNilList() (v []string) { - return p.NilList -} - -var FullStruct_OptNilInsList_DEFAULT []*Inner - -func (p *FullStruct) GetOptNilInsList() (v []*Inner) { - if !p.IsSetOptNilInsList() { - return FullStruct_OptNilInsList_DEFAULT - } - return p.OptNilInsList -} - -func (p *FullStruct) GetNilInsList() (v []*Inner) { - return p.NilInsList -} - -var FullStruct_OptStatus_DEFAULT HTTPStatus - -func (p *FullStruct) GetOptStatus() (v HTTPStatus) { - if !p.IsSetOptStatus() { - return FullStruct_OptStatus_DEFAULT - } - return *p.OptStatus -} - -func (p *FullStruct) GetEnumKeyMap() (v map[HTTPStatus]*Local) { - return p.EnumKeyMap -} - -func (p *FullStruct) GetComplex() (v map[HTTPStatus][]map[string]*Local) { - return p.Complex -} - -func (p *FullStruct) GetI64Set() (v []int64) { - return p.I64Set -} - -func (p *FullStruct) GetInt16() (v int16) { - return p.Int16 -} - -func (p *FullStruct) GetIsSet() (v bool) { - return p.IsSet -} -func (p *FullStruct) SetLeft(val int32) { - p.Left = val -} -func (p *FullStruct) SetRight(val int32) { - p.Right = val -} -func (p *FullStruct) SetDummy(val []byte) { - p.Dummy = val -} -func (p *FullStruct) SetInnerReq(val *Inner) { - p.InnerReq = val -} -func (p *FullStruct) SetStatus(val HTTPStatus) { - p.Status = val -} -func (p *FullStruct) SetStr(val string) { - p.Str = val -} -func (p *FullStruct) SetEnumList(val []HTTPStatus) { - p.EnumList = val -} -func (p *FullStruct) SetStrmap(val map[int32]string) { - p.Strmap = val -} -func (p *FullStruct) SetInt64(val int64) { - p.Int64 = val -} -func (p *FullStruct) SetIntList(val []int32) { - p.IntList = val -} -func (p *FullStruct) SetLocalList(val []*Local) { - p.LocalList = val -} -func (p *FullStruct) SetStrLocalMap(val map[string]*Local) { - p.StrLocalMap = val -} -func (p *FullStruct) SetNestList(val [][]int32) { - p.NestList = val -} -func (p *FullStruct) SetRequiredIns(val *Local) { - p.RequiredIns = val -} -func (p *FullStruct) SetNestMap(val map[string][]string) { - p.NestMap = val -} -func (p *FullStruct) SetNestMap2(val []map[string]HTTPStatus) { - p.NestMap2 = val -} -func (p *FullStruct) SetEnumMap(val map[int32]HTTPStatus) { - p.EnumMap = val -} -func (p *FullStruct) SetStrlist(val []string) { - p.Strlist = val -} -func (p *FullStruct) SetOptionalIns(val *Local) { - p.OptionalIns = val -} -func (p *FullStruct) SetAnotherInner(val *Inner) { - p.AnotherInner = val -} -func (p *FullStruct) SetOptNilList(val []string) { - p.OptNilList = val -} -func (p *FullStruct) SetNilList(val []string) { - p.NilList = val -} -func (p *FullStruct) SetOptNilInsList(val []*Inner) { - p.OptNilInsList = val -} -func (p *FullStruct) SetNilInsList(val []*Inner) { - p.NilInsList = val -} -func (p *FullStruct) SetOptStatus(val *HTTPStatus) { - p.OptStatus = val -} -func (p *FullStruct) SetEnumKeyMap(val map[HTTPStatus]*Local) { - p.EnumKeyMap = val -} -func (p *FullStruct) SetComplex(val map[HTTPStatus][]map[string]*Local) { - p.Complex = val -} -func (p *FullStruct) SetI64Set(val []int64) { - p.I64Set = val -} -func (p *FullStruct) SetInt16(val int16) { - p.Int16 = val -} -func (p *FullStruct) SetIsSet(val bool) { - p.IsSet = val -} - -func (p *FullStruct) CarryingUnknownFields() bool { - return len(p._unknownFields) > 0 -} - -func (p *FullStruct) IsSetRight() bool { - return p.Right != FullStruct_Right_DEFAULT -} - -func (p *FullStruct) IsSetInnerReq() bool { - return p.InnerReq != nil -} - -func (p *FullStruct) IsSetStrmap() bool { - return p.Strmap != nil -} - -func (p *FullStruct) IsSetIntList() bool { - return p.IntList != nil -} - -func (p *FullStruct) IsSetRequiredIns() bool { - return p.RequiredIns != nil -} - -func (p *FullStruct) IsSetOptionalIns() bool { - return p.OptionalIns != nil -} - -func (p *FullStruct) IsSetAnotherInner() bool { - return p.AnotherInner != nil -} - -func (p *FullStruct) IsSetOptNilList() bool { - return p.OptNilList != nil -} - -func (p *FullStruct) IsSetOptNilInsList() bool { - return p.OptNilInsList != nil -} - -func (p *FullStruct) IsSetOptStatus() bool { - return p.OptStatus != nil -} - -func (p *FullStruct) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("FullStruct(%+v)", *p) -} - -func (p *FullStruct) DeepEqual(ano *FullStruct) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field1DeepEqual(ano.Left) { - return false - } - if !p.Field2DeepEqual(ano.Right) { - return false - } - if !p.Field3DeepEqual(ano.Dummy) { - return false - } - if !p.Field4DeepEqual(ano.InnerReq) { - return false - } - if !p.Field5DeepEqual(ano.Status) { - return false - } - if !p.Field6DeepEqual(ano.Str) { - return false - } - if !p.Field7DeepEqual(ano.EnumList) { - return false - } - if !p.Field8DeepEqual(ano.Strmap) { - return false - } - if !p.Field9DeepEqual(ano.Int64) { - return false - } - if !p.Field10DeepEqual(ano.IntList) { - return false - } - if !p.Field11DeepEqual(ano.LocalList) { - return false - } - if !p.Field12DeepEqual(ano.StrLocalMap) { - return false - } - if !p.Field13DeepEqual(ano.NestList) { - return false - } - if !p.Field14DeepEqual(ano.RequiredIns) { - return false - } - if !p.Field16DeepEqual(ano.NestMap) { - return false - } - if !p.Field17DeepEqual(ano.NestMap2) { - return false - } - if !p.Field18DeepEqual(ano.EnumMap) { - return false - } - if !p.Field19DeepEqual(ano.Strlist) { - return false - } - if !p.Field20DeepEqual(ano.OptionalIns) { - return false - } - if !p.Field21DeepEqual(ano.AnotherInner) { - return false - } - if !p.Field22DeepEqual(ano.OptNilList) { - return false - } - if !p.Field23DeepEqual(ano.NilList) { - return false - } - if !p.Field24DeepEqual(ano.OptNilInsList) { - return false - } - if !p.Field25DeepEqual(ano.NilInsList) { - return false - } - if !p.Field26DeepEqual(ano.OptStatus) { - return false - } - if !p.Field27DeepEqual(ano.EnumKeyMap) { - return false - } - if !p.Field28DeepEqual(ano.Complex) { - return false - } - if !p.Field29DeepEqual(ano.I64Set) { - return false - } - if !p.Field30DeepEqual(ano.Int16) { - return false - } - if !p.Field31DeepEqual(ano.IsSet) { - return false - } - return true -} - -func (p *FullStruct) Field1DeepEqual(src int32) bool { - - if p.Left != src { - return false - } - return true -} -func (p *FullStruct) Field2DeepEqual(src int32) bool { - - if p.Right != src { - return false - } - return true -} -func (p *FullStruct) Field3DeepEqual(src []byte) bool { - - if bytes.Compare(p.Dummy, src) != 0 { - return false - } - return true -} -func (p *FullStruct) Field4DeepEqual(src *Inner) bool { - - if !p.InnerReq.DeepEqual(src) { - return false - } - return true -} -func (p *FullStruct) Field5DeepEqual(src HTTPStatus) bool { - - if p.Status != src { - return false - } - return true -} -func (p *FullStruct) Field6DeepEqual(src string) bool { - - if strings.Compare(p.Str, src) != 0 { - return false - } - return true -} -func (p *FullStruct) Field7DeepEqual(src []HTTPStatus) bool { - - if len(p.EnumList) != len(src) { - return false - } - for i, v := range p.EnumList { - _src := src[i] - if v != _src { - return false - } - } - return true -} -func (p *FullStruct) Field8DeepEqual(src map[int32]string) bool { - - if len(p.Strmap) != len(src) { - return false - } - for k, v := range p.Strmap { - _src := src[k] - if strings.Compare(v, _src) != 0 { - return false - } - } - return true -} -func (p *FullStruct) Field9DeepEqual(src int64) bool { - - if p.Int64 != src { - return false - } - return true -} -func (p *FullStruct) Field10DeepEqual(src []int32) bool { - - if len(p.IntList) != len(src) { - return false - } - for i, v := range p.IntList { - _src := src[i] - if v != _src { - return false - } - } - return true -} -func (p *FullStruct) Field11DeepEqual(src []*Local) bool { - - if len(p.LocalList) != len(src) { - return false - } - for i, v := range p.LocalList { - _src := src[i] - if !v.DeepEqual(_src) { - return false - } - } - return true -} -func (p *FullStruct) Field12DeepEqual(src map[string]*Local) bool { - - if len(p.StrLocalMap) != len(src) { - return false - } - for k, v := range p.StrLocalMap { - _src := src[k] - if !v.DeepEqual(_src) { - return false - } - } - return true -} -func (p *FullStruct) Field13DeepEqual(src [][]int32) bool { - - if len(p.NestList) != len(src) { - return false - } - for i, v := range p.NestList { - _src := src[i] - if len(v) != len(_src) { - return false - } - for i, v := range v { - _src1 := _src[i] - if v != _src1 { - return false - } - } - } - return true -} -func (p *FullStruct) Field14DeepEqual(src *Local) bool { - - if !p.RequiredIns.DeepEqual(src) { - return false - } - return true -} -func (p *FullStruct) Field16DeepEqual(src map[string][]string) bool { - - if len(p.NestMap) != len(src) { - return false - } - for k, v := range p.NestMap { - _src := src[k] - if len(v) != len(_src) { - return false - } - for i, v := range v { - _src1 := _src[i] - if strings.Compare(v, _src1) != 0 { - return false - } - } - } - return true -} -func (p *FullStruct) Field17DeepEqual(src []map[string]HTTPStatus) bool { - - if len(p.NestMap2) != len(src) { - return false - } - for i, v := range p.NestMap2 { - _src := src[i] - if len(v) != len(_src) { - return false - } - for k, v := range v { - _src1 := _src[k] - if v != _src1 { - return false - } - } - } - return true -} -func (p *FullStruct) Field18DeepEqual(src map[int32]HTTPStatus) bool { - - if len(p.EnumMap) != len(src) { - return false - } - for k, v := range p.EnumMap { - _src := src[k] - if v != _src { - return false - } - } - return true -} -func (p *FullStruct) Field19DeepEqual(src []string) bool { - - if len(p.Strlist) != len(src) { - return false - } - for i, v := range p.Strlist { - _src := src[i] - if strings.Compare(v, _src) != 0 { - return false - } - } - return true -} -func (p *FullStruct) Field20DeepEqual(src *Local) bool { - - if !p.OptionalIns.DeepEqual(src) { - return false - } - return true -} -func (p *FullStruct) Field21DeepEqual(src *Inner) bool { - - if !p.AnotherInner.DeepEqual(src) { - return false - } - return true -} -func (p *FullStruct) Field22DeepEqual(src []string) bool { - - if len(p.OptNilList) != len(src) { - return false - } - for i, v := range p.OptNilList { - _src := src[i] - if strings.Compare(v, _src) != 0 { - return false - } - } - return true -} -func (p *FullStruct) Field23DeepEqual(src []string) bool { - - if len(p.NilList) != len(src) { - return false - } - for i, v := range p.NilList { - _src := src[i] - if strings.Compare(v, _src) != 0 { - return false - } - } - return true -} -func (p *FullStruct) Field24DeepEqual(src []*Inner) bool { - - if len(p.OptNilInsList) != len(src) { - return false - } - for i, v := range p.OptNilInsList { - _src := src[i] - if !v.DeepEqual(_src) { - return false - } - } - return true -} -func (p *FullStruct) Field25DeepEqual(src []*Inner) bool { - - if len(p.NilInsList) != len(src) { - return false - } - for i, v := range p.NilInsList { - _src := src[i] - if !v.DeepEqual(_src) { - return false - } - } - return true -} -func (p *FullStruct) Field26DeepEqual(src *HTTPStatus) bool { - - if p.OptStatus == src { - return true - } else if p.OptStatus == nil || src == nil { - return false - } - if *p.OptStatus != *src { - return false - } - return true -} -func (p *FullStruct) Field27DeepEqual(src map[HTTPStatus]*Local) bool { - - if len(p.EnumKeyMap) != len(src) { - return false - } - for k, v := range p.EnumKeyMap { - _src := src[k] - if !v.DeepEqual(_src) { - return false - } - } - return true -} -func (p *FullStruct) Field28DeepEqual(src map[HTTPStatus][]map[string]*Local) bool { - - if len(p.Complex) != len(src) { - return false - } - for k, v := range p.Complex { - _src := src[k] - if len(v) != len(_src) { - return false - } - for i, v := range v { - _src1 := _src[i] - if len(v) != len(_src1) { - return false - } - for k, v := range v { - _src2 := _src1[k] - if !v.DeepEqual(_src2) { - return false - } - } - } - } - return true -} -func (p *FullStruct) Field29DeepEqual(src []int64) bool { - - if len(p.I64Set) != len(src) { - return false - } - for i, v := range p.I64Set { - _src := src[i] - if v != _src { - return false - } - } - return true -} -func (p *FullStruct) Field30DeepEqual(src int16) bool { - - if p.Int16 != src { - return false - } - return true -} -func (p *FullStruct) Field31DeepEqual(src bool) bool { - - if p.IsSet != src { - return false - } - return true -} - -var fieldIDToName_FullStruct = map[int16]string{ - 1: "Left", - 2: "Right", - 3: "Dummy", - 4: "InnerReq", - 5: "status", - 6: "Str", - 7: "enum_list", - 8: "Strmap", - 9: "Int64", - 10: "IntList", - 11: "localList", - 12: "StrLocalMap", - 13: "nestList", - 14: "required_ins", - 16: "nestMap", - 17: "nestMap2", - 18: "enum_map", - 19: "Strlist", - 20: "optional_ins", - 21: "AnotherInner", - 22: "opt_nil_list", - 23: "nil_list", - 24: "opt_nil_ins_list", - 25: "nil_ins_list", - 26: "opt_status", - 27: "enum_key_map", - 28: "complex", - 29: "i64Set", - 30: "Int16", - 31: "isSet", -} - -type MixedStruct struct { - Left int32 `thrift:"Left,1,required" frugal:"1,required,i32" json:"Left"` - Dummy []byte `thrift:"Dummy,3" frugal:"3,default,binary" json:"Dummy"` - Str string `thrift:"Str,6" frugal:"6,default,string" json:"Str"` - EnumList []HTTPStatus `thrift:"enum_list,7" frugal:"7,default,list" json:"enum_list"` - Int64 int64 `thrift:"Int64,9" frugal:"9,default,i64" json:"Int64"` - IntList []int32 `thrift:"IntList,10,optional" frugal:"10,optional,list" json:"IntList,omitempty"` - LocalList []*Local `thrift:"localList,11" frugal:"11,default,list" json:"localList"` - StrLocalMap map[string]*Local `thrift:"StrLocalMap,12" frugal:"12,default,map" json:"StrLocalMap"` - NestList [][]int32 `thrift:"nestList,13" frugal:"13,default,list>" json:"nestList"` - RequiredIns *Local `thrift:"required_ins,14,required" frugal:"14,required,Local" json:"required_ins"` - OptionalIns *Local `thrift:"optional_ins,20,optional" frugal:"20,optional,Local" json:"optional_ins,omitempty"` - AnotherInner *Inner `thrift:"AnotherInner,21" frugal:"21,default,Inner" json:"AnotherInner"` - EnumKeyMap map[HTTPStatus]*Local `thrift:"enum_key_map,27" frugal:"27,default,map" json:"enum_key_map"` - _unknownFields unknown.Fields -} - -func NewMixedStruct() *MixedStruct { - return &MixedStruct{} -} - -func (p *MixedStruct) InitDefault() { -} - -func (p *MixedStruct) GetLeft() (v int32) { - return p.Left -} - -func (p *MixedStruct) GetDummy() (v []byte) { - return p.Dummy -} - -func (p *MixedStruct) GetStr() (v string) { - return p.Str -} - -func (p *MixedStruct) GetEnumList() (v []HTTPStatus) { - return p.EnumList -} - -func (p *MixedStruct) GetInt64() (v int64) { - return p.Int64 -} - -var MixedStruct_IntList_DEFAULT []int32 - -func (p *MixedStruct) GetIntList() (v []int32) { - if !p.IsSetIntList() { - return MixedStruct_IntList_DEFAULT - } - return p.IntList -} - -func (p *MixedStruct) GetLocalList() (v []*Local) { - return p.LocalList -} - -func (p *MixedStruct) GetStrLocalMap() (v map[string]*Local) { - return p.StrLocalMap -} - -func (p *MixedStruct) GetNestList() (v [][]int32) { - return p.NestList -} - -var MixedStruct_RequiredIns_DEFAULT *Local - -func (p *MixedStruct) GetRequiredIns() (v *Local) { - if !p.IsSetRequiredIns() { - return MixedStruct_RequiredIns_DEFAULT - } - return p.RequiredIns -} - -var MixedStruct_OptionalIns_DEFAULT *Local - -func (p *MixedStruct) GetOptionalIns() (v *Local) { - if !p.IsSetOptionalIns() { - return MixedStruct_OptionalIns_DEFAULT - } - return p.OptionalIns -} - -var MixedStruct_AnotherInner_DEFAULT *Inner - -func (p *MixedStruct) GetAnotherInner() (v *Inner) { - if !p.IsSetAnotherInner() { - return MixedStruct_AnotherInner_DEFAULT - } - return p.AnotherInner -} - -func (p *MixedStruct) GetEnumKeyMap() (v map[HTTPStatus]*Local) { - return p.EnumKeyMap -} -func (p *MixedStruct) SetLeft(val int32) { - p.Left = val -} -func (p *MixedStruct) SetDummy(val []byte) { - p.Dummy = val -} -func (p *MixedStruct) SetStr(val string) { - p.Str = val -} -func (p *MixedStruct) SetEnumList(val []HTTPStatus) { - p.EnumList = val -} -func (p *MixedStruct) SetInt64(val int64) { - p.Int64 = val -} -func (p *MixedStruct) SetIntList(val []int32) { - p.IntList = val -} -func (p *MixedStruct) SetLocalList(val []*Local) { - p.LocalList = val -} -func (p *MixedStruct) SetStrLocalMap(val map[string]*Local) { - p.StrLocalMap = val -} -func (p *MixedStruct) SetNestList(val [][]int32) { - p.NestList = val -} -func (p *MixedStruct) SetRequiredIns(val *Local) { - p.RequiredIns = val -} -func (p *MixedStruct) SetOptionalIns(val *Local) { - p.OptionalIns = val -} -func (p *MixedStruct) SetAnotherInner(val *Inner) { - p.AnotherInner = val -} -func (p *MixedStruct) SetEnumKeyMap(val map[HTTPStatus]*Local) { - p.EnumKeyMap = val -} - -func (p *MixedStruct) CarryingUnknownFields() bool { - return len(p._unknownFields) > 0 -} - -func (p *MixedStruct) IsSetIntList() bool { - return p.IntList != nil -} - -func (p *MixedStruct) IsSetRequiredIns() bool { - return p.RequiredIns != nil -} - -func (p *MixedStruct) IsSetOptionalIns() bool { - return p.OptionalIns != nil -} - -func (p *MixedStruct) IsSetAnotherInner() bool { - return p.AnotherInner != nil -} - -func (p *MixedStruct) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("MixedStruct(%+v)", *p) -} - -func (p *MixedStruct) DeepEqual(ano *MixedStruct) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - if !p.Field1DeepEqual(ano.Left) { - return false - } - if !p.Field3DeepEqual(ano.Dummy) { - return false - } - if !p.Field6DeepEqual(ano.Str) { - return false - } - if !p.Field7DeepEqual(ano.EnumList) { - return false - } - if !p.Field9DeepEqual(ano.Int64) { - return false - } - if !p.Field10DeepEqual(ano.IntList) { - return false - } - if !p.Field11DeepEqual(ano.LocalList) { - return false - } - if !p.Field12DeepEqual(ano.StrLocalMap) { - return false - } - if !p.Field13DeepEqual(ano.NestList) { - return false - } - if !p.Field14DeepEqual(ano.RequiredIns) { - return false - } - if !p.Field20DeepEqual(ano.OptionalIns) { - return false - } - if !p.Field21DeepEqual(ano.AnotherInner) { - return false - } - if !p.Field27DeepEqual(ano.EnumKeyMap) { - return false - } - return true -} - -func (p *MixedStruct) Field1DeepEqual(src int32) bool { - - if p.Left != src { - return false - } - return true -} -func (p *MixedStruct) Field3DeepEqual(src []byte) bool { - - if bytes.Compare(p.Dummy, src) != 0 { - return false - } - return true -} -func (p *MixedStruct) Field6DeepEqual(src string) bool { - - if strings.Compare(p.Str, src) != 0 { - return false - } - return true -} -func (p *MixedStruct) Field7DeepEqual(src []HTTPStatus) bool { - - if len(p.EnumList) != len(src) { - return false - } - for i, v := range p.EnumList { - _src := src[i] - if v != _src { - return false - } - } - return true -} -func (p *MixedStruct) Field9DeepEqual(src int64) bool { - - if p.Int64 != src { - return false - } - return true -} -func (p *MixedStruct) Field10DeepEqual(src []int32) bool { - - if len(p.IntList) != len(src) { - return false - } - for i, v := range p.IntList { - _src := src[i] - if v != _src { - return false - } - } - return true -} -func (p *MixedStruct) Field11DeepEqual(src []*Local) bool { - - if len(p.LocalList) != len(src) { - return false - } - for i, v := range p.LocalList { - _src := src[i] - if !v.DeepEqual(_src) { - return false - } - } - return true -} -func (p *MixedStruct) Field12DeepEqual(src map[string]*Local) bool { - - if len(p.StrLocalMap) != len(src) { - return false - } - for k, v := range p.StrLocalMap { - _src := src[k] - if !v.DeepEqual(_src) { - return false - } - } - return true -} -func (p *MixedStruct) Field13DeepEqual(src [][]int32) bool { - - if len(p.NestList) != len(src) { - return false - } - for i, v := range p.NestList { - _src := src[i] - if len(v) != len(_src) { - return false - } - for i, v := range v { - _src1 := _src[i] - if v != _src1 { - return false - } - } - } - return true -} -func (p *MixedStruct) Field14DeepEqual(src *Local) bool { - - if !p.RequiredIns.DeepEqual(src) { - return false - } - return true -} -func (p *MixedStruct) Field20DeepEqual(src *Local) bool { - - if !p.OptionalIns.DeepEqual(src) { - return false - } - return true -} -func (p *MixedStruct) Field21DeepEqual(src *Inner) bool { - - if !p.AnotherInner.DeepEqual(src) { - return false - } - return true -} -func (p *MixedStruct) Field27DeepEqual(src map[HTTPStatus]*Local) bool { - - if len(p.EnumKeyMap) != len(src) { - return false - } - for k, v := range p.EnumKeyMap { - _src := src[k] - if !v.DeepEqual(_src) { - return false - } - } - return true -} - -var fieldIDToName_MixedStruct = map[int16]string{ - 1: "Left", - 3: "Dummy", - 6: "Str", - 7: "enum_list", - 9: "Int64", - 10: "IntList", - 11: "localList", - 12: "StrLocalMap", - 13: "nestList", - 14: "required_ins", - 20: "optional_ins", - 21: "AnotherInner", - 27: "enum_key_map", -} - -type EmptyStruct struct { - _unknownFields unknown.Fields -} - -func NewEmptyStruct() *EmptyStruct { - return &EmptyStruct{} -} - -func (p *EmptyStruct) InitDefault() { -} - -func (p *EmptyStruct) CarryingUnknownFields() bool { - return len(p._unknownFields) > 0 -} - -func (p *EmptyStruct) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("EmptyStruct(%+v)", *p) -} - -func (p *EmptyStruct) DeepEqual(ano *EmptyStruct) bool { - if p == ano { - return true - } else if p == nil || ano == nil { - return false - } - return true -} - -var fieldIDToName_EmptyStruct = map[int16]string{} diff --git a/pkg/protocol/bthrift/test/test.thrift b/pkg/protocol/bthrift/test/test.thrift deleted file mode 100644 index 355c1a9b16..0000000000 --- a/pkg/protocol/bthrift/test/test.thrift +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright 2023 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -namespace go test - -enum AEnum { - A = 1, - B = 2, -} - -struct Inner { - 1: optional i32 Num = 5, - 2: optional string desc, - 3: optional map> MapOfList, - 4: optional map MapOfEnumKey, - 5: optional byte Byte1, - 6: optional double Double1, -} - -enum HTTPStatus { - OK = 200, - NOT_FOUND = 404 -} - -struct Local { - 1: i32 l, -} - -struct FullStruct { - 1: required i32 Left, - 2: optional i32 Right = 3, - 3: binary Dummy, - 4: Inner InnerReq, - 5: HTTPStatus status, - 6: string Str, - 7: list enum_list, - 8: optional map Strmap, - 9: i64 Int64, - 10: optional list IntList, - 11: list localList, - 12: map StrLocalMap, - 13: list> nestList, - 14: required Local required_ins, - 16: map> nestMap, - 17: list> nestMap2, - 18: map enum_map, - 19: list Strlist, - 20: optional Local optional_ins, - 21: Inner AnotherInner, - 22: optional list opt_nil_list, - 23: list nil_list, - 24: optional list opt_nil_ins_list, - 25: list nil_ins_list, - 26: optional HTTPStatus opt_status, - 27: map enum_key_map, - 28: map>> complex, - 29: set i64Set, - 30: i16 Int16, - 31: bool isSet, -} - -struct MixedStruct { - 1: required i32 Left, - 3: binary Dummy, - 6: string Str, - 7: list enum_list, - 9: i64 Int64, - 10: optional list IntList, - 11: list localList, - 12: map StrLocalMap, - 13: list> nestList, - 14: required Local required_ins, - 20: optional Local optional_ins, - 21: Inner AnotherInner, - 27: map enum_key_map, -} - -struct EmptyStruct { -} - diff --git a/pkg/protocol/bthrift/test/unknown_test.go b/pkg/protocol/bthrift/test/unknown_test.go deleted file mode 100644 index beeae56d84..0000000000 --- a/pkg/protocol/bthrift/test/unknown_test.go +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Copyright 2023 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test - -import ( - "bytes" - "reflect" - "testing" - - tt "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - "github.com/cloudwego/kitex/pkg/protocol/bthrift/test/kitex_gen/test" -) - -var fullReq *test.FullStruct - -func init() { - desc := "aa" - status := test.HTTPStatus_NOT_FOUND - byte1 := int8(1) - double1 := 1.3 - fullReq = &test.FullStruct{ - Left: 32, - Right: 45, - Dummy: []byte("test"), - InnerReq: &test.Inner{ - Num: 6, - Desc: &desc, - MapOfList: map[int64][]int64{42: {1, 2}}, - MapOfEnumKey: map[test.AEnum]int64{test.AEnum_A: 1, test.AEnum_B: 2}, - Byte1: &byte1, - Double1: &double1, - }, - Status: test.HTTPStatus_OK, - Str: "str", - EnumList: []test.HTTPStatus{test.HTTPStatus_NOT_FOUND, test.HTTPStatus_OK}, - Strmap: map[int32]string{ - 10: "aa", - 11: "bb", - }, - Int64: 5, - IntList: []int32{11, 22, 33}, - LocalList: []*test.Local{{L: 33}, nil}, - StrLocalMap: map[string]*test.Local{ - "bbb": { - L: 22, - }, - "ccc": { - L: 11, - }, - "ddd": nil, - }, - NestList: [][]int32{{3, 4}, {5, 6}}, - RequiredIns: &test.Local{ - L: 55, - }, - NestMap: map[string][]string{"aa": {"cc", "bb"}, "bb": {"xx", "yy"}}, - NestMap2: []map[string]test.HTTPStatus{{"ok": test.HTTPStatus_OK}}, - EnumMap: map[int32]test.HTTPStatus{ - 0: test.HTTPStatus_NOT_FOUND, - 1: test.HTTPStatus_OK, - }, - Strlist: []string{"mm", "nn"}, - OptStatus: &status, - Complex: map[test.HTTPStatus][]map[string]*test.Local{ - test.HTTPStatus_OK: { - {"": &test.Local{L: 3}}, - {"c": nil, "d": &test.Local{L: 42}}, - nil, - }, - test.HTTPStatus_NOT_FOUND: nil, - }, - I64Set: []int64{1, 2, 3}, - Int16: 98, - IsSet: true, - } -} - -func TestOnlyUnknownField(t *testing.T) { - l := fullReq.BLength() - buf := make([]byte, l) - ll := fullReq.FastWriteNocopy(buf, nil) - tt.Assert(t, ll == l) - - unknown := &test.EmptyStruct{} - ll, err := unknown.FastRead(buf) - tt.Assert(t, err == nil) - tt.Assert(t, ll == l) - unknownL := unknown.BLength() - tt.Assert(t, unknownL == l) - unknownBuf := make([]byte, unknownL) - writeL := unknown.FastWriteNocopy(unknownBuf, nil) - tt.Assert(t, writeL == l) - tt.Assert(t, bytes.Equal(buf, unknownBuf)) - - // test get unknown fields - fields, err := bthrift.GetUnknownFields(unknown) - tt.Assert(t, err == nil) - l, err = bthrift.UnknownFieldsLength(fields) - tt.Assert(t, err == nil) - buf = make([]byte, l) - _, err = bthrift.WriteUnknownFields(buf, fields) - tt.Assert(t, err == nil) - tt.Assert(t, bytes.Equal(buf, reflect.ValueOf(unknown).Elem().FieldByName("_unknownFields").Bytes())) -} - -func TestPartialUnknownField(t *testing.T) { - l := fullReq.BLength() - buf := make([]byte, l) - ll := fullReq.FastWriteNocopy(buf, nil) - tt.Assert(t, ll == l) - compare := &test.FullStruct{} - ll, err := compare.FastRead(buf) - tt.Assert(t, err == nil) - tt.Assert(t, ll == l) - - unknown := &test.MixedStruct{} - ll, err = unknown.FastRead(buf) - tt.Assert(t, err == nil) - tt.Assert(t, ll == l) - unknownL := unknown.BLength() - unknownBuf := make([]byte, unknownL) - writeL := unknown.FastWriteNocopy(unknownBuf, nil) - tt.Assert(t, writeL == unknownL) - compare1 := &test.FullStruct{} - ll, err = compare1.FastRead(unknownBuf) - tt.Assert(t, err == nil) - tt.Assert(t, ll == unknownL) - tt.Assert(t, compare1.DeepEqual(compare)) -} - -func TestNoUnknownField(t *testing.T) { - l := fullReq.BLength() - buf := make([]byte, l) - ll := fullReq.FastWriteNocopy(buf, nil) - tt.Assert(t, ll == l) - - ori := &test.FullStruct{} - ll, err := ori.FastRead(buf) - tt.Assert(t, err == nil) - tt.Assert(t, ll == l) - - // required fields - tt.Assert(t, ori.Field11DeepEqual([]*test.Local{{L: 33}, test.NewLocal()})) - tt.Assert(t, ori.Field12DeepEqual(map[string]*test.Local{ - "bbb": {L: 22}, "ccc": {L: 11}, "ddd": {}, - })) - tt.Assert(t, ori.Field21DeepEqual(test.NewInner())) - tt.Assert(t, ori.Field28DeepEqual(map[test.HTTPStatus][]map[string]*test.Local{ - test.HTTPStatus_OK: { - {"": &test.Local{L: 3}}, - {"c": {}, "d": &test.Local{L: 42}}, - nil, - }, - test.HTTPStatus_NOT_FOUND: nil, - })) - ori.LocalList[1] = nil - ori.StrLocalMap["ddd"] = nil - ori.AnotherInner = nil - ori.Complex[test.HTTPStatus_OK][1]["c"] = nil - - tt.Assert(t, ori.Field1DeepEqual(fullReq.Left)) - tt.Assert(t, ori.Field2DeepEqual(fullReq.Right)) - tt.Assert(t, ori.Field3DeepEqual(fullReq.Dummy)) - tt.Assert(t, ori.Field4DeepEqual(fullReq.InnerReq)) - tt.Assert(t, ori.Field5DeepEqual(fullReq.Status)) - tt.Assert(t, ori.Field6DeepEqual(fullReq.Str)) - tt.Assert(t, ori.Field7DeepEqual(fullReq.EnumList)) - tt.Assert(t, ori.Field8DeepEqual(fullReq.Strmap)) - tt.Assert(t, ori.Field9DeepEqual(fullReq.Int64)) - tt.Assert(t, ori.Field10DeepEqual(fullReq.IntList)) - tt.Assert(t, ori.Field11DeepEqual(fullReq.LocalList)) - tt.Assert(t, ori.Field12DeepEqual(fullReq.StrLocalMap)) - tt.Assert(t, ori.Field13DeepEqual(fullReq.NestList)) - tt.Assert(t, ori.Field14DeepEqual(fullReq.RequiredIns)) - tt.Assert(t, ori.Field16DeepEqual(fullReq.NestMap)) - tt.Assert(t, ori.Field17DeepEqual(fullReq.NestMap2)) - tt.Assert(t, ori.Field18DeepEqual(fullReq.EnumMap)) - tt.Assert(t, ori.Field19DeepEqual(fullReq.Strlist)) - tt.Assert(t, ori.Field20DeepEqual(fullReq.OptionalIns)) - tt.Assert(t, ori.Field21DeepEqual(fullReq.AnotherInner)) - tt.Assert(t, ori.Field22DeepEqual(fullReq.OptNilList)) - tt.Assert(t, ori.Field23DeepEqual(fullReq.NilList)) - tt.Assert(t, ori.Field24DeepEqual(fullReq.OptNilInsList)) - tt.Assert(t, ori.Field25DeepEqual(fullReq.NilInsList)) - tt.Assert(t, ori.Field26DeepEqual(fullReq.OptStatus)) - tt.Assert(t, ori.Field27DeepEqual(fullReq.EnumKeyMap)) - tt.Assert(t, ori.Field28DeepEqual(fullReq.Complex)) -} - -func BenchmarkOnlyUnknownField(b *testing.B) { - l := fullReq.BLength() - buf := make([]byte, l) - ll := fullReq.FastWriteNocopy(buf, nil) - tt.Assert(b, ll == l) - - unknownBuf := make([]byte, l) - for i := 0; i < b.N; i++ { - unknown := &test.EmptyStruct{} - _, _ = unknown.FastRead(buf) - unknown.FastWriteNocopy(unknownBuf, nil) - } -} - -//func TestCorruptWrite(t *testing.T) { -// local := &test.Local{L: 3} -// ufs := unknown.Fields{&unknown.Field{Type: 1000}} -// local.SetUnknown(ufs) -// -// defer func() { -// e := recover() -// if strings.Contains(e.(error).Error(), "unknown data type 1000") { -// return -// } -// tt.Assert(t, false, e) -// }() -// _ = local.BLength() -// tt.Assert(t, false) -//} -// -//func TestCorruptRead(t *testing.T) { -// local := &test.Local{L: 3} -// ufs := unknown.Fields{&unknown.Field{Name: "test", Type: unknown.TString, Value: "str"}} -// local.SetUnknown(ufs) -// l := local.BLength() -// buf := make([]byte, l) -// ll := local.FastWriteNocopy(buf, nil) -// tt.Assert(t, ll == l) -// buf[7] = 200 -// -// var local2 test.Local -// _, err := local2.FastRead(buf) -// tt.Assert(t, err != nil) -// tt.Assert(t, strings.Contains(err.Error(), "unknown data type 200")) -//} diff --git a/pkg/protocol/bthrift/unknown.go b/pkg/protocol/bthrift/unknown.go index ccc133bb6b..fa815e9820 100644 --- a/pkg/protocol/bthrift/unknown.go +++ b/pkg/protocol/bthrift/unknown.go @@ -17,13 +17,9 @@ package bthrift import ( - "errors" - "fmt" - "reflect" - + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/unknownfields" "github.com/cloudwego/thriftgo/generator/golang/extension/unknown" - - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // UnknownField is used to describe an unknown field. @@ -36,363 +32,61 @@ type UnknownField struct { Value interface{} } -// GetUnknownFields deserialize unknownFields stored in v to a list of *UnknownFields. -func GetUnknownFields(v interface{}) (fields []UnknownField, err error) { - var buf []byte - rv := reflect.ValueOf(v) - if rv.Kind() == reflect.Ptr && !rv.IsNil() { - rv = rv.Elem() +func fromGopkgUnknownFields(ff []unknownfields.UnknownField) []UnknownField { + if ff == nil { + return nil } - if rv.Kind() != reflect.Struct { - return nil, fmt.Errorf("%T is not a struct type", v) + ret := make([]UnknownField, len(ff)) + for i := range ff { + f := &ff[i] + ret[i].Name = "" // this field is useless + ret[i].ID = f.ID + ret[i].Type = int(f.Type) + ret[i].KeyType = int(f.KeyType) + ret[i].ValType = int(f.ValType) + ret[i].Value = f.Value } - if unknownField := rv.FieldByName("_unknownFields"); !unknownField.IsValid() { - return nil, fmt.Errorf("%T has no field named '_unknownFields'", v) - } else { - buf = unknownField.Bytes() - } - return ConvertUnknownFields(buf) + return ret } -// ConvertUnknownFields converts buf to deserialized unknown fields. -func ConvertUnknownFields(buf unknown.Fields) (fields []UnknownField, err error) { - if len(buf) == 0 { - return nil, errors.New("_unknownFields is empty") +func toGopkgUnknownFields(ff []UnknownField) []unknownfields.UnknownField { + if ff == nil { + return nil } - var offset int - var l int - var name string - var fieldTypeId thrift.TType - var fieldId int16 - var f UnknownField - for { - if offset == len(buf) { - return - } - name, fieldTypeId, fieldId, l, err = Binary.ReadFieldBegin(buf[offset:]) - offset += l - if err != nil { - return nil, fmt.Errorf("read field %d begin error: %v", fieldId, err) - } - l, err = readUnknownField(&f, buf[offset:], name, fieldTypeId, fieldId) - offset += l - if err != nil { - return nil, fmt.Errorf("read unknown field %d error: %v", fieldId, err) - } - fields = append(fields, f) + ret := make([]unknownfields.UnknownField, len(ff)) + for i := range ff { + f := &ff[i] + ret[i].ID = f.ID + ret[i].Type = thrift.TType(f.Type) + ret[i].KeyType = thrift.TType(f.KeyType) + ret[i].ValType = thrift.TType(f.ValType) + ret[i].Value = f.Value } + return ret } -func readUnknownField(f *UnknownField, buf []byte, name string, fieldType thrift.TType, id int16) (length int, err error) { - var size int - var l int - f.Name = name - f.ID = id - f.Type = int(fieldType) - switch fieldType { - case thrift.BOOL: - f.Value, l, err = Binary.ReadBool(buf[length:]) - length += l - case thrift.BYTE: - f.Value, l, err = Binary.ReadByte(buf[length:]) - length += l - case thrift.I16: - f.Value, l, err = Binary.ReadI16(buf[length:]) - length += l - case thrift.I32: - f.Value, l, err = Binary.ReadI32(buf[length:]) - length += l - case thrift.I64: - f.Value, l, err = Binary.ReadI64(buf[length:]) - length += l - case thrift.DOUBLE: - f.Value, l, err = Binary.ReadDouble(buf[length:]) - length += l - case thrift.STRING: - f.Value, l, err = Binary.ReadString(buf[length:]) - length += l - case thrift.SET: - var ttype thrift.TType - ttype, size, l, err = Binary.ReadSetBegin(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read set begin error: %w", err) - } - f.ValType = int(ttype) - set := make([]UnknownField, size) - for i := 0; i < size; i++ { - l, err2 := readUnknownField(&set[i], buf[length:], "", thrift.TType(f.ValType), int16(i)) - length += l - if err2 != nil { - return length, fmt.Errorf("read set elem error: %w", err2) - } - } - l, err = Binary.ReadSetEnd(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read set end error: %w", err) - } - f.Value = set - case thrift.LIST: - var ttype thrift.TType - ttype, size, l, err = Binary.ReadListBegin(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read list begin error: %w", err) - } - f.ValType = int(ttype) - list := make([]UnknownField, size) - for i := 0; i < size; i++ { - l, err2 := readUnknownField(&list[i], buf[length:], "", thrift.TType(f.ValType), int16(i)) - length += l - if err2 != nil { - return length, fmt.Errorf("read list elem error: %w", err2) - } - } - l, err = Binary.ReadListEnd(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read list end error: %w", err) - } - f.Value = list - case thrift.MAP: - var kttype, vttype thrift.TType - kttype, vttype, size, l, err = Binary.ReadMapBegin(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read map begin error: %w", err) - } - f.KeyType = int(kttype) - f.ValType = int(vttype) - flatMap := make([]UnknownField, size*2) - for i := 0; i < size; i++ { - l, err2 := readUnknownField(&flatMap[2*i], buf[length:], "", thrift.TType(f.KeyType), int16(i)) - length += l - if err2 != nil { - return length, fmt.Errorf("read map key error: %w", err2) - } - l, err2 = readUnknownField(&flatMap[2*i+1], buf[length:], "", thrift.TType(f.ValType), int16(i)) - length += l - if err2 != nil { - return length, fmt.Errorf("read map value error: %w", err2) - } - } - l, err = Binary.ReadMapEnd(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read map end error: %w", err) - } - f.Value = flatMap - case thrift.STRUCT: - _, l, err = Binary.ReadStructBegin(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read struct begin error: %w", err) - } - var field UnknownField - var fields []UnknownField - for { - name, fieldTypeID, fieldID, l, err := Binary.ReadFieldBegin(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read field begin error: %w", err) - } - if fieldTypeID == thrift.STOP { - break - } - l, err = readUnknownField(&field, buf[length:], name, fieldTypeID, fieldID) - length += l - if err != nil { - return length, fmt.Errorf("read struct field error: %w", err) - } - l, err = Binary.ReadFieldEnd(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read field end error: %w", err) - } - fields = append(fields, field) - } - l, err = Binary.ReadStructEnd(buf[length:]) - length += l - if err != nil { - return length, fmt.Errorf("read struct end error: %w", err) - } - f.Value = fields - default: - return length, fmt.Errorf("unknown data type %d", f.Type) - } - if err != nil { - return length, err - } - return +// GetUnknownFields deserialize unknownfields stored in v to a list of *UnknownFields. +// Deprecated: use the method under github.com/cloudwego/gopkg/protocol/thrift/unknownfields +func GetUnknownFields(v interface{}) (fields []UnknownField, err error) { + ff, err := unknownfields.GetUnknownFields(v) + return fromGopkgUnknownFields(ff), err } -// UnknownFieldsLength returns the length of fs. -func UnknownFieldsLength(fs []UnknownField) (int, error) { - l := 0 - for _, f := range fs { - l += Binary.FieldBeginLength(f.Name, thrift.TType(f.Type), f.ID) - ll, err := unknownFieldLength(&f) - l += ll - if err != nil { - return l, err - } - l += Binary.FieldEndLength() - } - return l, nil +// ConvertUnknownFields converts buf to deserialized unknown fields. +// Deprecated: use the method under github.com/cloudwego/gopkg/protocol/thrift/unknownfields +func ConvertUnknownFields(buf unknown.Fields) (fields []UnknownField, err error) { + ff, err := unknownfields.ConvertUnknownFields(buf) + return fromGopkgUnknownFields(ff), err } -func unknownFieldLength(f *UnknownField) (length int, err error) { - // use constants to avoid some type assert - switch f.Type { - case unknown.TBool: - length += Binary.BoolLength(false) - case unknown.TByte: - length += Binary.ByteLength(0) - case unknown.TDouble: - length += Binary.DoubleLength(0) - case unknown.TI16: - length += Binary.I16Length(0) - case unknown.TI32: - length += Binary.I32Length(0) - case unknown.TI64: - length += Binary.I64Length(0) - case unknown.TString: - length += Binary.StringLength(f.Value.(string)) - case unknown.TSet: - vs := f.Value.([]UnknownField) - length += Binary.SetBeginLength(thrift.TType(f.ValType), len(vs)) - for _, v := range vs { - l, err := unknownFieldLength(&v) - length += l - if err != nil { - return length, err - } - } - length += Binary.SetEndLength() - case unknown.TList: - vs := f.Value.([]UnknownField) - length += Binary.ListBeginLength(thrift.TType(f.ValType), len(vs)) - for _, v := range vs { - l, err := unknownFieldLength(&v) - length += l - if err != nil { - return length, err - } - } - length += Binary.ListEndLength() - case unknown.TMap: - kvs := f.Value.([]UnknownField) - length += Binary.MapBeginLength(thrift.TType(f.KeyType), thrift.TType(f.ValType), len(kvs)/2) - for i := 0; i < len(kvs); i += 2 { - l, err := unknownFieldLength(&kvs[i]) - length += l - if err != nil { - return length, err - } - l, err = unknownFieldLength(&kvs[i+1]) - length += l - if err != nil { - return length, err - } - } - length += Binary.MapEndLength() - case unknown.TStruct: - fs := f.Value.([]UnknownField) - length += Binary.StructBeginLength(f.Name) - l, err := UnknownFieldsLength(fs) - length += l - if err != nil { - return length, err - } - length += Binary.FieldStopLength() - length += Binary.StructEndLength() - default: - return length, fmt.Errorf("unknown data type %d", f.Type) - } - return +// UnknownFieldsLength returns the length of fs. +// Deprecated: use the method under github.com/cloudwego/gopkg/protocol/thrift/unknownfields +func UnknownFieldsLength(fs []UnknownField) (int, error) { + return unknownfields.UnknownFieldsLength(toGopkgUnknownFields(fs)) } // WriteUnknownFields writes fs into buf, and return written offset of the buf. +// Deprecated: use the method under github.com/cloudwego/gopkg/protocol/thrift/unknownfields func WriteUnknownFields(buf []byte, fs []UnknownField) (offset int, err error) { - for _, f := range fs { - offset += Binary.WriteFieldBegin(buf[offset:], f.Name, thrift.TType(f.Type), f.ID) - l, err := writeUnknownField(buf[offset:], &f) - offset += l - if err != nil { - return offset, err - } - offset += Binary.WriteFieldEnd(buf[offset:]) - } - return offset, nil -} - -func writeUnknownField(buf []byte, f *UnknownField) (offset int, err error) { - switch f.Type { - case unknown.TBool: - offset += Binary.WriteBool(buf, f.Value.(bool)) - case unknown.TByte: - offset += Binary.WriteByte(buf, f.Value.(int8)) - case unknown.TDouble: - offset += Binary.WriteDouble(buf, f.Value.(float64)) - case unknown.TI16: - offset += Binary.WriteI16(buf, f.Value.(int16)) - case unknown.TI32: - offset += Binary.WriteI32(buf, f.Value.(int32)) - case unknown.TI64: - offset += Binary.WriteI64(buf, f.Value.(int64)) - case unknown.TString: - offset += Binary.WriteString(buf, f.Value.(string)) - case unknown.TSet: - vs := f.Value.([]UnknownField) - offset += Binary.WriteSetBegin(buf, thrift.TType(f.ValType), len(vs)) - for _, v := range vs { - l, err := writeUnknownField(buf[offset:], &v) - offset += l - if err != nil { - return offset, err - } - } - offset += Binary.WriteSetEnd(buf[offset:]) - case unknown.TList: - vs := f.Value.([]UnknownField) - offset += Binary.WriteListBegin(buf, thrift.TType(f.ValType), len(vs)) - for _, v := range vs { - l, err := writeUnknownField(buf[offset:], &v) - offset += l - if err != nil { - return offset, err - } - } - offset += Binary.WriteListEnd(buf[offset:]) - case unknown.TMap: - kvs := f.Value.([]UnknownField) - offset += Binary.WriteMapBegin(buf, thrift.TType(f.KeyType), thrift.TType(f.ValType), len(kvs)/2) - for i := 0; i < len(kvs); i += 2 { - l, err := writeUnknownField(buf[offset:], &kvs[i]) - offset += l - if err != nil { - return offset, err - } - l, err = writeUnknownField(buf[offset:], &kvs[i+1]) - offset += l - if err != nil { - return offset, err - } - } - offset += Binary.WriteMapEnd(buf[offset:]) - case unknown.TStruct: - fs := f.Value.([]UnknownField) - offset += Binary.WriteStructBegin(buf, f.Name) - l, err := WriteUnknownFields(buf[offset:], fs) - offset += l - if err != nil { - return offset, err - } - offset += Binary.WriteFieldStop(buf[offset:]) - offset += Binary.WriteStructEnd(buf[offset:]) - default: - return offset, fmt.Errorf("unknown data type %d", f.Type) - } - return + return unknownfields.WriteUnknownFields(buf, toGopkgUnknownFields(fs)) } diff --git a/pkg/protocol/bthrift/unknown_test.go b/pkg/protocol/bthrift/unknown_test.go new file mode 100644 index 0000000000..2325748a0c --- /dev/null +++ b/pkg/protocol/bthrift/unknown_test.go @@ -0,0 +1,31 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bthrift + +import ( + "reflect" + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestUnknownFieldTypeConvert(t *testing.T) { + ff := []UnknownField{{ID: 1, Type: 2, KeyType: 3, ValType: 4, Value: 5}} + ff1 := fromGopkgUnknownFields(toGopkgUnknownFields(ff)) + test.Assert(t, len(ff) == len(ff1)) + test.Assert(t, reflect.DeepEqual(ff, ff1)) +} diff --git a/pkg/remote/codec/thrift/skip_decoder.go b/pkg/remote/codec/thrift/skip_decoder.go deleted file mode 100644 index 13bb045c2b..0000000000 --- a/pkg/remote/codec/thrift/skip_decoder.go +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "encoding/binary" - "errors" - "fmt" - - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/remote/codec/perrors" -) - -// skipDecoder is used to parse the input byte-by-byte and skip the thrift payload -// for making use of Frugal and FastCodec in standard Thrift Binary Protocol scenario. -type skipDecoder struct { - remote.ByteBuffer - - r int - buf []byte -} - -func (p *skipDecoder) init() (err error) { - p.r = 0 - p.buf, err = p.Peek(p.ReadableLen()) - return -} - -func (p *skipDecoder) loadmore(n int) error { - // trigger underlying conn to read more - _, err := p.Peek(n) - if err == nil { - // read as much as possible, luckly, we will have a full buffer - // then we no need to call p.Peek many times - p.buf, err = p.Peek(p.ReadableLen()) - } - return err -} - -func (p *skipDecoder) next(n int) ([]byte, error) { - if len(p.buf)-p.r < n { - if err := p.loadmore(p.r + n); err != nil { - return nil, err - } - // after calling p.loadmore, p.buf MUST be at least (p.r + n) len - } - ret := p.buf[p.r : p.r+n] - p.r += n - return ret, nil -} - -func (p *skipDecoder) NextStruct() (buf []byte, err error) { - // should be ok with init, just one less p.Peek call - if err := p.init(); err != nil { - return nil, err - } - if err := p.skip(thrift.STRUCT, thrift.DEFAULT_RECURSION_DEPTH); err != nil { - return nil, err - } - return p.Next(p.r) -} - -// skip skips bytes for a specific type. -// After calling skip, p.r contains the len of the type. -// Since BinaryProtocol calls Next of remote.ByteBuffer many times when reading a struct, -// we don't use it for performance concerns -func (p *skipDecoder) skip(typeID thrift.TType, maxDepth int) error { - if maxDepth <= 0 { - return thrift.NewTProtocolExceptionWithType(thrift.DEPTH_LIMIT, errors.New("depth limit exceeded")) - } - switch typeID { - case thrift.BOOL, thrift.BYTE: - if _, err := p.next(1); err != nil { - return err - } - case thrift.I16: - if _, err := p.next(2); err != nil { - return err - } - case thrift.I32: - if _, err := p.next(4); err != nil { - return err - } - case thrift.I64, thrift.DOUBLE: - if _, err := p.next(8); err != nil { - return err - } - case thrift.STRING: - b, err := p.next(4) - if err != nil { - return err - } - sz := int(binary.BigEndian.Uint32(b)) - if sz < 0 { - return perrors.InvalidDataLength - } - if _, err := p.next(sz); err != nil { - return err - } - case thrift.STRUCT: - for { - b, err := p.next(1) // TType - if err != nil { - return err - } - tp := thrift.TType(b[0]) - if tp == thrift.STOP { - break - } - if _, err := p.next(2); err != nil { // Field ID - return err - } - if err := p.skip(tp, maxDepth-1); err != nil { - return err - } - } - case thrift.MAP: - b, err := p.next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len - if err != nil { - return err - } - kt, vt, sz := thrift.TType(b[0]), thrift.TType(b[1]), int32(binary.BigEndian.Uint32(b[2:])) - if sz < 0 { - return perrors.InvalidDataLength - } - for i := int32(0); i < sz; i++ { - if err := p.skip(kt, maxDepth-1); err != nil { - return err - } - if err := p.skip(vt, maxDepth-1); err != nil { - return err - } - } - case thrift.SET, thrift.LIST: - b, err := p.next(5) // 1 byte value type, 4 bytes Len - if err != nil { - return err - } - vt, sz := thrift.TType(b[0]), int32(binary.BigEndian.Uint32(b[1:])) - if sz < 0 { - return perrors.InvalidDataLength - } - for i := int32(0); i < sz; i++ { - if err := p.skip(vt, maxDepth-1); err != nil { - return err - } - } - default: - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("unknown data type %d", typeID)) - } - return nil -} diff --git a/pkg/remote/codec/thrift/skip_decoder_test.go b/pkg/remote/codec/thrift/skip_decoder_test.go deleted file mode 100644 index 0b8b8dc2ca..0000000000 --- a/pkg/remote/codec/thrift/skip_decoder_test.go +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "testing" - - "github.com/cloudwego/kitex/internal/test" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/remote" -) - -func makeByteBufferForSkipDecoderTest() remote.ByteBuffer { - tProt := NewBinaryProtocol(remote.NewReaderWriterBuffer(1024)) - defer tProt.Recycle() - tProt.WriteStructBegin("testStruct") - // 1. Byte - tProt.WriteFieldBegin("Byte", thrift.BYTE, 1) - tProt.WriteByte('1') - tProt.WriteFieldEnd() - // 2. Bool - tProt.WriteFieldBegin("Bool", thrift.BOOL, 2) - tProt.WriteBool(true) - tProt.WriteFieldEnd() - // 3. I16 - tProt.WriteFieldBegin("I16", thrift.I16, 3) - tProt.WriteI16(2) - tProt.WriteFieldEnd() - // 4. I32 - tProt.WriteFieldBegin("I32", thrift.I32, 4) - tProt.WriteI32(3) - tProt.WriteFieldEnd() - // 5. I64 - tProt.WriteFieldBegin("I64", thrift.I64, 5) - tProt.WriteI64(4) - tProt.WriteFieldEnd() - // 6. Double - tProt.WriteFieldBegin("Double", thrift.DOUBLE, 6) - tProt.WriteDouble(5) - tProt.WriteFieldEnd() - // 7. String - tProt.WriteFieldBegin("String", thrift.STRING, 7) - tProt.WriteString("6") - tProt.WriteFieldEnd() - // 8. Map - tProt.WriteFieldBegin("Map", thrift.MAP, 8) - tProt.WriteMapBegin(thrift.I32, thrift.I32, 1) - tProt.WriteI32(7) - tProt.WriteI32(8) - tProt.WriteMapEnd() - tProt.WriteFieldEnd() - // 9. Set - tProt.WriteFieldBegin("Set", thrift.SET, 9) - tProt.WriteSetBegin(thrift.I32, 1) - tProt.WriteI32(9) - tProt.WriteSetEnd() - tProt.WriteFieldEnd() - // 10. List - tProt.WriteFieldBegin("List", thrift.LIST, 10) - tProt.WriteListBegin(thrift.I32, 1) - tProt.WriteI32(9) - tProt.WriteListEnd() - tProt.WriteFieldEnd() - - tProt.WriteFieldStop() - tProt.WriteStructEnd() - return tProt.ByteBuffer() -} - -func TestSkipDecoder_NextStruct(t *testing.T) { - buf := makeByteBufferForSkipDecoderTest() - defer buf.Release(nil) - length := buf.ReadableLen() - sd := skipDecoder{ByteBuffer: buf} - b, err := sd.NextStruct() - test.Assert(t, err == nil) - test.Assert(t, len(b) == length) -} - -func BenchmarkSkipDecoder(b *testing.B) { - buf := makeByteBufferForSkipDecoderTest() - for i := 0; i < b.N; i++ { - b, _ := buf.Peek(buf.ReadableLen()) - bb := remote.NewReaderBuffer(b) - - sd := skipDecoder{ByteBuffer: bb} - sd.NextStruct() - - bb.Release(nil) - } - buf.Release(nil) -} diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index 47fb17efef..beafd77638 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -22,8 +22,9 @@ import ( "fmt" "io" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift" + + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" @@ -41,7 +42,13 @@ const ( FastWrite CodecType = 0b0001 FastRead CodecType = 0b0010 - FastReadWrite = FastRead | FastWrite + FastReadWrite = FastRead | FastWrite + + FrugalWrite CodecType = 0b0100 + FrugalRead CodecType = 0b1000 + + FrugalReadWrite = FrugalWrite | FrugalRead + EnableSkipDecoder CodecType = 0b10000 ) @@ -89,8 +96,15 @@ type thriftCodec struct { CodecType } +// IsSet returns true if t is set +func (c thriftCodec) IsSet(t CodecType) bool { + return c.CodecType&t != 0 +} + // Marshal implements the remote.PayloadCodec interface. func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out remote.ByteBuffer) error { + // TODO(xiaost): Refactor the code after v0.11.0 is released. Unifying checking and fallback logic. + // prepare info methodName := message.RPCInfo().Invocation().MethodName() if methodName == "" { @@ -99,54 +113,68 @@ func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out re msgType := message.MessageType() seqID := message.RPCInfo().Invocation().SeqID() - data, err := getValidData(methodName, message) - if err != nil { + // ???? for fixing resp==nil, err==nil? don't know + if err := codec.NewDataIfNeeded(methodName, message); err != nil { return err } + data := message.Data() + if message.MessageType() == remote.Exception { + // if remote.Exception, we always use fastcodec + if transErr, ok := data.(*remote.TransError); ok { + ex := thrift.NewApplicationException(transErr.TypeID(), transErr.Error()) + return encodeFastThrift(out, methodName, msgType, seqID, ex) + } else if err, ok := data.(error); ok { + ex := thrift.NewApplicationException(remote.InternalError, err.Error()) + return encodeFastThrift(out, methodName, msgType, seqID, ex) + } else { + return fmt.Errorf("got %T for remote.Exception", data) + } + } // encode with hyper codec - // NOTE: to ensure hyperMarshalEnabled is inlined so split the check logic, or it may cause performance loss - if c.hyperMarshalEnabled() && hyperMarshalAvailable(data) { + if c.IsSet(FrugalWrite) && hyperMarshalAvailable(data) { return c.hyperMarshal(out, methodName, msgType, seqID, data) } // encode with FastWrite - if c.CodecType&FastWrite != 0 { - if msg, ok := data.(bthrift.ThriftFastCodec); ok { + if c.IsSet(FastWrite) { + if msg, ok := data.(thrift.FastCodec); ok { return encodeFastThrift(out, methodName, msgType, seqID, msg) } } // fallback to old thrift way (slow) - if err = encodeBasicThrift(out, ctx, methodName, msgType, seqID, data, message.RPCRole()); err == nil || err != errEncodeMismatchMsgType { + if err := encodeBasicThrift(out, ctx, methodName, msgType, seqID, data, message.RPCRole()); err == nil || err != errEncodeMismatchMsgType { return err } - // Basic can be used for disabling frugal, we need to check it - if c.CodecType != Basic && hyperMarshalAvailable(data) { - // fallback to frugal when the generated code is using slim template - return c.hyperMarshal(out, methodName, msgType, seqID, data) + // if user only wants to use Basic we never try fallback to frugal or fastcodec + if c.CodecType != Basic { + // try FrugalWrite < - > FastWrite fallback + if msg, ok := data.(thrift.FastCodec); ok { + return encodeFastThrift(out, methodName, msgType, seqID, msg) + } + if hyperMarshalAvailable(data) { // slim template? + return c.hyperMarshal(out, methodName, msgType, seqID, data) + } } - return errEncodeMismatchMsgType } // encodeFastThrift encode with the FastCodec way -func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, msg bthrift.ThriftFastCodec) error { +func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, msg thrift.FastCodec) error { nw, _ := out.(remote.NocopyWrite) // nocopy write is a special implementation of linked buffer, only bytebuffer implement NocopyWrite do FastWrite - msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) - msgEndLen := bthrift.Binary.MessageEndLength() - buf, err := out.Malloc(msgBeginLen + msg.BLength() + msgEndLen) + msgBeginLen := thrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) + buf, err := out.Malloc(msgBeginLen + msg.BLength()) if err != nil { return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error())) } // If fast write enabled, the underlying buffer maybe large than the correct buffer, // so we need to save the mallocLen before fast write and correct the real mallocLen after codec mallocLen := out.MallocLen() - offset := bthrift.Binary.WriteMessageBegin(buf, methodName, thrift.TMessageType(msgType), seqID) - offset += msg.FastWriteNocopy(buf[offset:], nw) - bthrift.Binary.WriteMessageEnd(buf[offset:]) + offset := thrift.Binary.WriteMessageBegin(buf, methodName, thrift.TMessageType(msgType), seqID) + _ = msg.FastWriteNocopy(buf[offset:], nw) if nw == nil { // if nw is nil, FastWrite will act in Copy mode. return nil @@ -160,7 +188,7 @@ func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string return err } tProt := NewBinaryProtocol(out) - if err := tProt.WriteMessageBegin(method, thrift.TMessageType(msgType), seqID); err != nil { + if err := tProt.WriteMessageBegin(method, athrift.TMessageType(msgType), seqID); err != nil { return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageBegin failed: %s", err.Error())) } if err := marshalBasicThriftData(ctx, tProt, data, method, rpcRole); err != nil { @@ -175,6 +203,8 @@ func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string // Unmarshal implements the remote.PayloadCodec interface. func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { + // TODO(xiaost): Refactor the code after v0.11.0 is released. Unifying checking and fallback logic. + tProt := NewBinaryProtocol(in) methodName, msgType, seqID, err := tProt.ReadMessageBegin() if err != nil { @@ -195,8 +225,8 @@ func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in r // decode thrift data data := message.Data() - msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, msgType, seqID) - dataLen := message.PayloadLen() - msgBeginLen - bthrift.Binary.MessageEndLength() + msgBeginLen := thrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) + dataLen := message.PayloadLen() - msgBeginLen // For Buffer Protocol, dataLen would be negative. Set it to zero so as not to confuse if dataLen < 0 { dataLen = 0 @@ -209,7 +239,6 @@ func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in r if err != nil { return err } - if err = tProt.ReadMessageEnd(); err != nil { return remote.NewTransError(remote.ProtocolError, err) } @@ -242,12 +271,12 @@ func (c thriftCodec) Name() string { // MessageWriter write to thrift.TProtocol type MessageWriter interface { - Write(oprot thrift.TProtocol) error + Write(oprot athrift.TProtocol) error } // MessageReader read from thrift.TProtocol type MessageReader interface { - Read(oprot thrift.TProtocol) error + Read(oprot athrift.TProtocol) error } type genericWriter interface { // used by pkg/generic @@ -261,35 +290,15 @@ type genericReader interface { // used by pkg/generic // MessageWriterWithMethodWithContext write to thrift.TProtocol // TODO(marina.sakai): remove it after we use the new genericWriter interface type MessageWriterWithMethodWithContext interface { - Write(ctx context.Context, method string, oprot thrift.TProtocol) error + Write(ctx context.Context, method string, oprot athrift.TProtocol) error } // MessageReaderWithMethodWithContext read from thrift.TProtocol with method // TODO(marina.sakai): remove it after we use the new genericReader interface type MessageReaderWithMethodWithContext interface { - Read(ctx context.Context, method string, dataLen int, iprot thrift.TProtocol) error + Read(ctx context.Context, method string, dataLen int, iprot athrift.TProtocol) error } // ThriftMsgFastCodec ... -// Deprecated: use `bthrift.ThriftFastCodec` -type ThriftMsgFastCodec = bthrift.ThriftFastCodec - -func getValidData(methodName string, message remote.Message) (interface{}, error) { - if err := codec.NewDataIfNeeded(methodName, message); err != nil { - return nil, err - } - data := message.Data() - if message.MessageType() != remote.Exception { - return data, nil - } - transErr, isTransErr := data.(*remote.TransError) - if !isTransErr { - if err, isError := data.(error); isError { - encodeErr := bthrift.NewApplicationException(remote.InternalError, err.Error()) - return encodeErr, nil - } - return nil, errors.New("exception relay need error type data") - } - encodeErr := bthrift.NewApplicationException(transErr.TypeID(), transErr.Error()) - return encodeErr, nil -} +// Deprecated: use `github.com/cloudwego/gopkg/protocol/thrift.FastCodec` +type ThriftMsgFastCodec = thrift.FastCodec diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index b15028b3dc..673141e482 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -21,9 +21,9 @@ import ( "fmt" "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/gopkg/protocol/thrift" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) @@ -43,15 +43,15 @@ func MarshalThriftData(ctx context.Context, codec remote.PayloadCodec, data inte // marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId) // It will allocate a new buffer and encode to it func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([]byte, error) { + // TODO(xiaost): Refactor the code after v0.11.0 is released. Unifying checking and fallback logic. + // encode with hyper codec - // NOTE: to ensure hyperMarshalEnabled is inlined so split the check logic, or it may cause performance loss - if c.hyperMarshalEnabled() && hyperMarshalAvailable(data) { + if c.IsSet(FrugalWrite) && hyperMarshalAvailable(data) { return c.hyperMarshalBody(data) } - // encode with FastWrite - if c.CodecType&FastWrite != 0 { - if msg, ok := data.(bthrift.ThriftFastCodec); ok { + if c.IsSet(FastWrite) { + if msg, ok := data.(thrift.FastCodec); ok { payloadSize := msg.BLength() payload := mcache.Malloc(payloadSize) msg.FastWriteNocopy(payload, nil) @@ -68,10 +68,10 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ return nil, err } - // TODO: Remove the fallback code after skip decoder is stable + // TODO(xiaost): Deprecate the code by using cloudwebgo/gopkg in v0.12.0 // fallback to old thrift way (slow) - transport := thrift.NewTMemoryBufferLen(marshalThriftBufferSize) - tProt := thrift.NewTBinaryProtocol(transport, true, true) + transport := athrift.NewTMemoryBufferLen(marshalThriftBufferSize) + tProt := athrift.NewTBinaryProtocol(transport, true, true) if err := marshalBasicThriftData(ctx, tProt, data, "", -1); err != nil { return nil, err } @@ -92,7 +92,7 @@ func verifyMarshalBasicThriftDataType(data interface{}) error { // marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId) // It uses the old thrift way which is much slower than FastCodec and Frugal -func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data interface{}, method string, rpcRole remote.RPCRole) error { +func marshalBasicThriftData(ctx context.Context, tProt athrift.TProtocol, data interface{}, method string, rpcRole remote.RPCRole) error { var err error switch msg := data.(type) { case MessageWriter: @@ -111,23 +111,25 @@ func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data in } // UnmarshalThriftException decode thrift exception from tProt -// If your input is []byte, you can wrap it with `NewBinaryProtocol(remote.NewReaderBuffer(buf))` -func UnmarshalThriftException(tProt thrift.TProtocol) error { - exception := bthrift.NewApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") - if err := exception.Read(tProt); err != nil { - return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal Exception failed: %s", err.Error())) +// TODO: this func should be removed in the future. it's exposed accidentally. +// Deprecated: Use `SkipDecoder` + `ApplicationException` of `cloudwego/gopkg/protocol/thrift` instead. +func UnmarshalThriftException(tProt athrift.TProtocol) error { + d := thrift.NewSkipDecoder(tProt.Transport()) + defer d.Release() + b, err := d.Next(thrift.STRUCT) + if err != nil { + return err } - if err := tProt.ReadMessageEnd(); err != nil { - return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal, ReadMessageEnd failed: %s", err.Error())) + ex := thrift.NewApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") + if _, err := ex.FastRead(b); err != nil { + return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal Exception failed: %s", err.Error())) } - return remote.NewTransError(exception.TypeId(), exception) + return remote.NewTransError(ex.TypeId(), ex) } // UnmarshalThriftData only decodes the data (after methodName, msgType and seqId) // It will decode from the given buffer. -// Note: -// 1. `method` is only used for generic calls -// 2. if the buf contains an exception, you should call UnmarshalThriftException instead. +// NOTE: `method` is required for generic calls func UnmarshalThriftData(ctx context.Context, codec remote.PayloadCodec, method string, buf []byte, data interface{}) error { c, ok := codec.(*thriftCodec) if !ok { @@ -141,20 +143,16 @@ func UnmarshalThriftData(ctx context.Context, codec remote.PayloadCodec, method return err } -func (c thriftCodec) fastMessageUnmarshalEnabled() bool { - return c.CodecType&FastRead != 0 -} - func (c thriftCodec) fastMessageUnmarshalAvailable(data interface{}, payloadLen int) bool { if payloadLen == 0 && c.CodecType&EnableSkipDecoder == 0 { return false } - _, ok := data.(bthrift.ThriftFastCodec) + _, ok := data.(thrift.FastCodec) return ok } func (c thriftCodec) fastUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error { - msg := data.(bthrift.ThriftFastCodec) + msg := data.(thrift.FastCodec) if dataLen > 0 { buf, err := tProt.next(dataLen) if err != nil { @@ -181,20 +179,25 @@ func (c thriftCodec) fastUnmarshal(tProt *BinaryProtocol, data interface{}, data // method is only used for generic calls func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProtocol, method string, data interface{}, rpcRole remote.RPCRole, dataLen int) error { // decode with hyper unmarshal - if c.hyperMessageUnmarshalEnabled() && c.hyperMessageUnmarshalAvailable(data, dataLen) { + if c.IsSet(FrugalRead) && c.hyperMessageUnmarshalAvailable(data, dataLen) { return c.hyperUnmarshal(tProt, data, dataLen) } // decode with FastRead - if c.fastMessageUnmarshalEnabled() && c.fastMessageUnmarshalAvailable(data, dataLen) { + if c.IsSet(FastRead) && c.fastMessageUnmarshalAvailable(data, dataLen) { return c.fastUnmarshal(tProt, data, dataLen) } if err := verifyUnmarshalBasicThriftDataType(data); err != nil { - // Basic can be used for disabling frugal, we need to check it - if c.CodecType != Basic && c.hyperMessageUnmarshalAvailable(data, dataLen) { - // fallback to frugal when the generated code is using slim template - return c.hyperUnmarshal(tProt, data, dataLen) + // if user only wants to use Basic we never try fallback to frugal or fastcodec + if c.CodecType != Basic { + // try FrugalRead < - > FastRead fallback + if c.fastMessageUnmarshalAvailable(data, dataLen) { + return c.fastUnmarshal(tProt, data, dataLen) + } + if c.hyperMessageUnmarshalAvailable(data, dataLen) { // slim template? + return c.hyperUnmarshal(tProt, data, dataLen) + } } return err } @@ -205,7 +208,7 @@ func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProto func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error { if dataLen > 0 { - buf, err := tProt.next(dataLen - bthrift.Binary.MessageEndLength()) + buf, err := tProt.next(dataLen) if err != nil { return remote.NewTransError(remote.ProtocolError, err) } @@ -238,7 +241,7 @@ func verifyUnmarshalBasicThriftDataType(data interface{}) error { } // decodeBasicThriftData decode thrift body the old way (slow) -func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method string, rpcRole remote.RPCRole, dataLen int, data interface{}) error { +func decodeBasicThriftData(ctx context.Context, tProt athrift.TProtocol, method string, rpcRole remote.RPCRole, dataLen int, data interface{}) error { var err error switch t := data.(type) { case MessageReader: @@ -257,10 +260,10 @@ func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method s } func getSkippedStructBuffer(tProt *BinaryProtocol) ([]byte, error) { - sd := skipDecoder{ByteBuffer: tProt.trans} - buf, err := sd.NextStruct() + sd := thrift.NewSkipDecoder(tProt.trans) + buf, err := sd.Next(thrift.STRUCT) if err != nil { - return nil, remote.NewTransError(remote.ProtocolError, err).AppendMessage("caught in SkipDecoder NextStruct phase") + return nil, remote.NewTransError(remote.ProtocolError, err).AppendMessage("caught in SkipDecoder Next phase") } return buf, nil } diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index ec6a14c1cb..a75b88e3c8 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -22,10 +22,11 @@ import ( "strings" "testing" + "github.com/cloudwego/gopkg/protocol/thrift" + mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" ) @@ -47,8 +48,8 @@ func TestMarshalBasicThriftData(t *testing.T) { test.Assert(t, err == errEncodeMismatchMsgType, err) }) t.Run("valid-data", func(t *testing.T) { - transport := thrift.NewTMemoryBufferLen(1024) - tProt := thrift.NewTBinaryProtocol(transport, true, true) + transport := athrift.NewTMemoryBufferLen(1024) + tProt := athrift.NewTBinaryProtocol(transport, true, true) err := marshalBasicThriftData(context.Background(), tProt, mocks.ToApacheCodec(mockReq), "", -1) test.Assert(t, err == nil, err) result := transport.Bytes() @@ -161,19 +162,18 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { func TestUnmarshalThriftException(t *testing.T) { // prepare exception thrift binary - transport := thrift.NewTMemoryBufferLen(marshalThriftBufferSize) - tProt := thrift.NewTBinaryProtocol(transport, true, true) errMessage := "test: invalid protocol" - exc := bthrift.NewApplicationException(thrift.INVALID_PROTOCOL, errMessage) - err := exc.Write(tProt) - test.Assert(t, err == nil, err) + exc := thrift.NewApplicationException(thrift.INVALID_PROTOCOL, errMessage) + b := make([]byte, exc.BLength()) + n := exc.FastWrite(b) + test.Assert(t, n == len(b), n) // unmarshal - tProtRead := NewBinaryProtocol(remote.NewReaderBuffer(transport.Bytes())) - err = UnmarshalThriftException(tProtRead) + tProtRead := NewBinaryProtocol(remote.NewReaderBuffer(b)) + err := UnmarshalThriftException(tProtRead) transErr, ok := err.(*remote.TransError) test.Assert(t, ok, err) - test.Assert(t, transErr.TypeID() == thrift.INVALID_PROTOCOL, transErr) + test.Assert(t, transErr.TypeID() == athrift.INVALID_PROTOCOL, transErr) test.Assert(t, transErr.Error() == errMessage, transErr) } @@ -197,5 +197,5 @@ func Test_getSkippedStructBuffer(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(faultThrift)) _, err := getSkippedStructBuffer(tProt) test.Assert(t, err != nil, err) - test.Assert(t, strings.Contains(err.Error(), "caught in SkipDecoder NextStruct phase")) + test.Assert(t, strings.Contains(err.Error(), "caught in SkipDecoder Next phase")) } diff --git a/pkg/remote/codec/thrift/thrift_frugal.go b/pkg/remote/codec/thrift/thrift_frugal.go index 042cfd2467..6cda88f700 100644 --- a/pkg/remote/codec/thrift/thrift_frugal.go +++ b/pkg/remote/codec/thrift/thrift_frugal.go @@ -29,25 +29,13 @@ import ( "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/frugal" + "github.com/cloudwego/gopkg/protocol/thrift" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) -const ( - // 0b0001 and 0b0010 are used for FastWrite and FastRead, so Frugal starts from 0b0100 - FrugalWrite CodecType = 0b0100 - FrugalRead CodecType = 0b1000 - - FrugalReadWrite = FrugalWrite | FrugalRead -) - -// hyperMarshalEnabled indicates that if there are high priority message codec for current platform. -func (c thriftCodec) hyperMarshalEnabled() bool { - return c.CodecType&FrugalWrite != 0 -} +// TODO(xiaost): rename hyper -> frugal after v0.11.0 is released // hyperMarshalAvailable indicates that if high priority message codec is available. func hyperMarshalAvailable(data interface{}) bool { @@ -58,11 +46,6 @@ func hyperMarshalAvailable(data interface{}) bool { return true } -// hyperMessageUnmarshalEnabled indicates that if there are high priority message codec for current platform. -func (c thriftCodec) hyperMessageUnmarshalEnabled() bool { - return c.CodecType&FrugalRead != 0 -} - // hyperMessageUnmarshalAvailable indicates that if high priority message codec is available. func (c thriftCodec) hyperMessageUnmarshalAvailable(data interface{}, payloadLen int) bool { if payloadLen == 0 && c.CodecType&EnableSkipDecoder == 0 { @@ -79,24 +62,21 @@ func (c thriftCodec) hyperMarshal(out remote.ByteBuffer, methodName string, msgT seqID int32, data interface{}, ) error { // calculate and malloc message buffer - msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) - msgEndLen := bthrift.Binary.MessageEndLength() + msgBeginLen := thrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) objectLen := frugal.EncodedSize(data) - buf, err := out.Malloc(msgBeginLen + objectLen + msgEndLen) + buf, err := out.Malloc(msgBeginLen + objectLen) if err != nil { return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error())) } mallocLen := out.MallocLen() // encode message - offset := bthrift.Binary.WriteMessageBegin(buf, methodName, thrift.TMessageType(msgType), seqID) + offset := thrift.Binary.WriteMessageBegin(buf, methodName, thrift.TMessageType(msgType), seqID) nw, _ := out.(remote.NocopyWrite) - writeLen, err := frugal.EncodeObject(buf[offset:], nw, data) + _, err = frugal.EncodeObject(buf[offset:], nw, data) if err != nil { return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Encode failed: %s", err.Error())) } - offset += writeLen - bthrift.Binary.WriteMessageEnd(buf[offset:]) if nw != nil { return nw.MallocAck(mallocLen) } diff --git a/pkg/remote/codec/thrift/thrift_frugal_test.go b/pkg/remote/codec/thrift/thrift_frugal_test.go index 469bf23c0d..9c4e2dab60 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_test.go @@ -82,23 +82,14 @@ func initFrugalTagRecvMsg() remote.Message { } func TestHyperCodecCheck(t *testing.T) { - msg := initFrugalTagRecvMsg() - msg.SetPayloadLen(0) - codec := &thriftCodec{} - - // test CodecType check - test.Assert(t, codec.hyperMarshalEnabled() == false) - msg.SetPayloadLen(1) - test.Assert(t, codec.hyperMessageUnmarshalEnabled() == false) - msg.SetPayloadLen(0) - // test hyperMarshal check - codec = &thriftCodec{FrugalWrite} test.Assert(t, hyperMarshalAvailable(&MockNoTagArgs{}) == false) test.Assert(t, hyperMarshalAvailable(&MockFrugalTagArgs{}) == true) // test hyperMessageUnmarshal check - codec = &thriftCodec{FrugalRead} + msg := initFrugalTagRecvMsg() + msg.SetPayloadLen(0) + codec := &thriftCodec{FrugalRead} test.Assert(t, codec.hyperMessageUnmarshalAvailable(&MockNoTagArgs{}, msg.PayloadLen()) == false) test.Assert(t, codec.hyperMessageUnmarshalAvailable(&MockFrugalTagArgs{}, msg.PayloadLen()) == false) msg.SetPayloadLen(1) diff --git a/pkg/remote/codec/thrift/thrift_others.go b/pkg/remote/codec/thrift/thrift_others.go index 27a723cc1a..3895cd46f2 100644 --- a/pkg/remote/codec/thrift/thrift_others.go +++ b/pkg/remote/codec/thrift/thrift_others.go @@ -23,21 +23,11 @@ import ( "github.com/cloudwego/kitex/pkg/remote" ) -// hyperMarshalEnabled indicates that if there are high priority message codec for current platform. -func (c thriftCodec) hyperMarshalEnabled() bool { - return false -} - // hyperMarshalAvailable indicates that if high priority message codec is available. func hyperMarshalAvailable(data interface{}) bool { return false } -// hyperMessageUnmarshalEnabled indicates that if there are high priority message codec for current platform. -func (c thriftCodec) hyperMessageUnmarshalEnabled() bool { - return false -} - // hyperMessageUnmarshalAvailable indicates that if high priority message codec is available. func (c thriftCodec) hyperMessageUnmarshalAvailable(data interface{}, payloadLen int) bool { return false diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index d4c9f0eb9b..dd84489d99 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -21,13 +21,13 @@ import ( "errors" "testing" + "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/internal/mocks" mt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -63,18 +63,18 @@ func init() { } type mockWithContext struct { - ReadFunc func(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error - WriteFunc func(ctx context.Context, method string, oprot thrift.TProtocol) error + ReadFunc func(ctx context.Context, method string, dataLen int, oprot athrift.TProtocol) error + WriteFunc func(ctx context.Context, method string, oprot athrift.TProtocol) error } -func (m *mockWithContext) Read(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error { +func (m *mockWithContext) Read(ctx context.Context, method string, dataLen int, oprot athrift.TProtocol) error { if m.ReadFunc != nil { return m.ReadFunc(ctx, method, dataLen, oprot) } return nil } -func (m *mockWithContext) Write(ctx context.Context, method string, oprot thrift.TProtocol) error { +func (m *mockWithContext) Write(ctx context.Context, method string, oprot athrift.TProtocol) error { if m.WriteFunc != nil { return m.WriteFunc(ctx, method, oprot) } @@ -86,7 +86,7 @@ func TestWithContext(t *testing.T) { t.Run(tb.Name, func(t *testing.T) { ctx := context.Background() - req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, oprot thrift.TProtocol) error { + req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, oprot athrift.TProtocol) error { return nil }} ink := rpcinfo.NewInvocation("", "mock") @@ -99,7 +99,7 @@ func TestWithContext(t *testing.T) { buf.Flush() { - resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error { + resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, dataLen int, oprot athrift.TProtocol) error { return nil }} ink := rpcinfo.NewInvocation("", "mock") @@ -115,19 +115,42 @@ func TestWithContext(t *testing.T) { } func TestNormal(t *testing.T) { + // msg only supports FastCodec for _, tb := range transportBuffers { t.Run(tb.Name, func(t *testing.T) { ctx := context.Background() + // encode client side + sendMsg := initSendMsg(transport.TTHeader, false) + buf := tb.NewBuffer() + err := payloadCodec.Marshal(ctx, sendMsg, buf) + test.Assert(t, err == nil, err) + buf.Flush() + + // decode server side + recvMsg := initRecvMsg(false) + recvMsg.SetPayloadLen(buf.ReadableLen()) + test.Assert(t, err == nil, err) + err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + test.Assert(t, err == nil, err) + // compare Req Arg + compare(t, sendMsg, recvMsg) + }) + } + + // msg only supports Basic codec (apache codec) + for _, tb := range transportBuffers { + t.Run(tb.Name+"Basic", func(t *testing.T) { + ctx := context.Background() // encode client side - sendMsg := initSendMsg(transport.TTHeader) + sendMsg := initSendMsg(transport.TTHeader, true) buf := tb.NewBuffer() err := payloadCodec.Marshal(ctx, sendMsg, buf) test.Assert(t, err == nil, err) buf.Flush() // decode server side - recvMsg := initRecvMsg() + recvMsg := initRecvMsg(true) recvMsg.SetPayloadLen(buf.ReadableLen()) test.Assert(t, err == nil, err) err = payloadCodec.Unmarshal(ctx, recvMsg, buf) @@ -137,6 +160,30 @@ func TestNormal(t *testing.T) { compare(t, sendMsg, recvMsg) }) } + + // Exception case + for _, tb := range transportBuffers { + t.Run(tb.Name+"Ex", func(t *testing.T) { + ctx := context.Background() + // encode client side + sendMsg := newMsg(remote.NewTransErrorWithMsg(1, "hello")) + sendMsg.SetMessageType(remote.Exception) + buf := tb.NewBuffer() + err := payloadCodec.Marshal(ctx, sendMsg, buf) + test.Assert(t, err == nil, err) + buf.Flush() + + // decode server side + recvMsg := newMsg(nil) + recvMsg.SetPayloadLen(buf.ReadableLen()) + test.Assert(t, err == nil, err) + err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + test.Assert(t, err != nil) + te, ok := err.(*remote.TransError) + test.Assert(t, ok) + test.Assert(t, te.TypeID() == 1 && te.Error() == "hello", te) + }) + } } func BenchmarkNormalParallel(b *testing.B) { @@ -148,14 +195,14 @@ func BenchmarkNormalParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { // encode // client side - sendMsg := initSendMsg(transport.TTHeader) + sendMsg := initSendMsg(transport.TTHeader, false) buf := tb.NewBuffer() err := payloadCodec.Marshal(ctx, sendMsg, buf) test.Assert(b, err == nil, err) buf.Flush() // decode server side - recvMsg := initRecvMsg() + recvMsg := initRecvMsg(false) recvMsg.SetPayloadLen(buf.ReadableLen()) test.Assert(b, err == nil, err) err = payloadCodec.Unmarshal(ctx, recvMsg, buf) @@ -201,7 +248,7 @@ func TestException(t *testing.T) { err = payloadCodec.Unmarshal(ctx, recvMsg, buf) test.Assert(t, err != nil) transErr, ok := err.(*remote.TransError) - test.Assert(t, ok) + test.Assert(t, ok, err) test.Assert(t, err.Error() == errInfo) test.Assert(t, transErr.TypeID() == remote.UnknownMethod) }) @@ -210,13 +257,13 @@ func TestException(t *testing.T) { func TestTransErrorUnwrap(t *testing.T) { errMsg := "mock err" - transErr := remote.NewTransError(remote.InternalError, bthrift.NewApplicationException(1000, errMsg)) - uwErr, ok := transErr.Unwrap().(*bthrift.ApplicationException) + transErr := remote.NewTransError(remote.InternalError, thrift.NewApplicationException(1000, errMsg)) + uwErr, ok := transErr.Unwrap().(*thrift.ApplicationException) test.Assert(t, ok) test.Assert(t, uwErr.TypeId() == 1000) test.Assert(t, transErr.Error() == errMsg) - uwErr2, ok := errors.Unwrap(transErr).(*bthrift.ApplicationException) + uwErr2, ok := errors.Unwrap(transErr).(*thrift.ApplicationException) test.Assert(t, ok) test.Assert(t, uwErr2.TypeId() == 1000) test.Assert(t, uwErr2.Error() == errMsg) @@ -254,14 +301,14 @@ func TestSkipDecoder(t *testing.T) { for _, tb := range transportBuffers { t.Run(tc.desc+"#"+tb.Name, func(t *testing.T) { // encode client side - sendMsg := initSendMsg(tc.protocol) + sendMsg := initSendMsg(tc.protocol, true) // always use Basic to test skipdecodec buf := tb.NewBuffer() err := tc.codec.Marshal(context.Background(), sendMsg, buf) test.Assert(t, err == nil, err) buf.Flush() // decode server side - recvMsg := initRecvMsg() + recvMsg := initRecvMsg(true) if tc.protocol != transport.PurePayload { recvMsg.SetPayloadLen(buf.ReadableLen()) } @@ -275,22 +322,30 @@ func TestSkipDecoder(t *testing.T) { } } -func initSendMsg(tp transport.Protocol) remote.Message { - var _args mt.MockTestArgs - _args.Req = prepareReq() +func toApacheCodec(v bool, data thrift.FastCodec) interface{} { + if v { + return mt.ToApacheCodec(data) + } + return data +} + +func newMsg(data interface{}) remote.Message { ink := rpcinfo.NewInvocation("", "mock") ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) - msg := remote.NewMessage(mt.ToApacheCodec(&_args), svcInfo, ri, remote.Call, remote.Client) + return remote.NewMessage(data, svcInfo, ri, remote.Call, remote.Client) +} + +func initSendMsg(tp transport.Protocol, basic bool) remote.Message { + var _args mt.MockTestArgs // fastcodec only, if basic is true -> apachecodec + _args.Req = prepareReq() + msg := newMsg(toApacheCodec(basic, &_args)) msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) return msg } -func initRecvMsg() remote.Message { - var _args mt.MockTestArgs - ink := rpcinfo.NewInvocation("", "mock") - ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) - msg := remote.NewMessage(mt.ToApacheCodec(&_args), svcInfo, ri, remote.Call, remote.Server) - return msg +func initRecvMsg(basic bool) remote.Message { + var _args mt.MockTestArgs // fastcodec only, if basic is true -> apachecodec + return newMsg(toApacheCodec(basic, &_args)) } func compare(t *testing.T, sendMsg, recvMsg remote.Message) { diff --git a/pkg/remote/trans/netpollmux/control_frame.go b/pkg/remote/trans/netpollmux/control_frame.go index 4c060a24d7..a913a656c6 100644 --- a/pkg/remote/trans/netpollmux/control_frame.go +++ b/pkg/remote/trans/netpollmux/control_frame.go @@ -25,8 +25,8 @@ package netpollmux import ( "fmt" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift" + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) type ControlFrame struct{} @@ -37,8 +37,8 @@ func NewControlFrame() *ControlFrame { var fieldIDToName_ControlFrame = map[int16]string{} -func (p *ControlFrame) Read(iprot thrift.TProtocol) (err error) { - var fieldTypeId thrift.TType +func (p *ControlFrame) Read(iprot athrift.TProtocol) (err error) { + var fieldTypeId athrift.TType var fieldId int16 if _, err = iprot.ReadStructBegin(); err != nil { @@ -50,7 +50,7 @@ func (p *ControlFrame) Read(iprot thrift.TProtocol) (err error) { if err != nil { goto ReadFieldBeginError } - if fieldTypeId == thrift.STOP { + if fieldTypeId == athrift.STOP { break } if err = iprot.Skip(fieldTypeId); err != nil { @@ -67,19 +67,19 @@ func (p *ControlFrame) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) SkipFieldTypeError: - return bthrift.PrependError(fmt.Sprintf("%T skip field type %d error", p, fieldTypeId), err) + return thrift.PrependError(fmt.Sprintf("%T skip field type %d error", p, fieldTypeId), err) ReadFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } -func (p *ControlFrame) Write(oprot thrift.TProtocol) (err error) { +func (p *ControlFrame) Write(oprot athrift.TProtocol) (err error) { if err = oprot.WriteStructBegin("ControlFrame"); err != nil { goto WriteStructBeginError } @@ -93,11 +93,11 @@ func (p *ControlFrame) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldStopError: - return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *ControlFrame) String() string { diff --git a/pkg/utils/fastthrift/fastthrift.go b/pkg/utils/fastthrift/fastthrift.go index 0d27d849f1..498be4f2f1 100644 --- a/pkg/utils/fastthrift/fastthrift.go +++ b/pkg/utils/fastthrift/fastthrift.go @@ -17,66 +17,17 @@ package fastthrift import ( - "errors" - - "github.com/bytedance/gopkg/lang/dirtmake" - - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift" ) // FastMarshal marshals the msg to buf. The msg should be generated by Kitex tool and implement ThriftFastCodec. -func FastMarshal(msg bthrift.ThriftFastCodec) []byte { - sz := msg.BLength() - buf := dirtmake.Bytes(sz, sz) - msg.FastWriteNocopy(buf, nil) - return buf +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift.FastMarshal +func FastMarshal(msg thrift.FastCodec) []byte { + return thrift.FastMarshal(msg) } // FastUnmarshal unmarshal the buf into msg. The msg should be generated by Kitex tool and implement ThriftFastCodec. -func FastUnmarshal(buf []byte, msg bthrift.ThriftFastCodec) error { - _, err := msg.FastRead(buf) - return err -} - -// for msgType of MarshalMsg -// Please use theses consts instead of relying on apache thrift.TMessageType -const ( - CALL = uint8(1) - REPLY = uint8(2) - EXCEPTION = uint8(3) - ONEWAY = uint8(4) -) - -// MarshalMsg encodes the given msg to buf for generic thrift RPC. -func MarshalMsg(method string, msgType uint8, seq int32, msg bthrift.ThriftFastCodec) ([]byte, error) { - if method == "" { - return nil, errors.New("method not set") - } - sz := bthrift.Binary.MessageBeginLength(method, thrift.TMessageType(msgType), seq) + msg.BLength() - b := dirtmake.Bytes(sz, sz) - i := bthrift.Binary.WriteMessageBegin(b, method, thrift.TMessageType(msgType), seq) - _ = msg.FastWriteNocopy(b[i:], nil) - return b, nil -} - -// UnmarshalMsg parses the given buf and stores the result to msg for generic thrift RPC. -// for EXCEPTION msgType, it will returns `err` with *bthrift.ApplicationException type without storing the result to msg. -func UnmarshalMsg(b []byte, msg bthrift.ThriftFastCodec) (method string, seq int32, err error) { - method, msgType, seq, i, err := bthrift.Binary.ReadMessageBegin(b) - if err != nil { - return "", 0, err - } - b = b[i:] - - if uint8(msgType) == EXCEPTION { - ex := bthrift.NewApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") - _, err = ex.FastRead(b) - if err != nil { - return method, seq, err - } - return method, seq, ex - } - _, err = msg.FastRead(b) - return method, seq, err +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift.FastUnmarshal +func FastUnmarshal(buf []byte, msg thrift.FastCodec) error { + return thrift.FastUnmarshal(buf, msg) } diff --git a/pkg/utils/fastthrift/fastthrift_test.go b/pkg/utils/fastthrift/fastthrift_test.go index 31525da8fa..4b771ca124 100644 --- a/pkg/utils/fastthrift/fastthrift_test.go +++ b/pkg/utils/fastthrift/fastthrift_test.go @@ -21,8 +21,6 @@ import ( mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) var ( @@ -50,50 +48,3 @@ func TestFastThrift(t *testing.T) { test.Assert(t, req1.Msg == req2.Msg) test.Assert(t, len(req1.StrList) == len(req2.StrList)) } - -func TestMarshalMsg(t *testing.T) { - // CALL and REPLY - - req := &mocks.MockReq{} - req.Msg = "Hello" - b, err := MarshalMsg("Echo", CALL, 1, req) - test.Assert(t, err == nil, err) - - resp := &mocks.MockReq{} - method, seq, err := UnmarshalMsg(b, resp) - test.Assert(t, err == nil, err) - test.Assert(t, method == "Echo", method) - test.Assert(t, seq == 1, seq) - test.Assert(t, resp.Msg == req.Msg, resp.Msg) - - // EXCEPTION - - ex := bthrift.NewApplicationException(thrift.WRONG_METHOD_NAME, "Ex!") - b, err = MarshalMsg("ExMethod", EXCEPTION, 2, ex) - test.Assert(t, err == nil, err) - method, seq, err = UnmarshalMsg(b, nil) - test.Assert(t, err != nil) - test.Assert(t, method == "ExMethod") - test.Assert(t, seq == 2) - e, ok := err.(*bthrift.ApplicationException) - test.Assert(t, ok) - test.Assert(t, e.TypeID() == ex.TypeID() && e.Error() == ex.Error()) -} - -func BenchmarkFastUnmarshal(b *testing.B) { - buf := FastMarshal(newRequest()) - b.ResetTimer() - b.ReportAllocs() - for i := 0; i < b.N; i++ { - _ = FastUnmarshal(buf, &mocks.MockReq{}) - } -} - -func BenchmarkFastMarshal(b *testing.B) { - req := newRequest() - b.ResetTimer() - b.ReportAllocs() - for i := 0; i < b.N; i++ { - FastMarshal(req) - } -} diff --git a/pkg/utils/thrift.go b/pkg/utils/thrift.go index f05f1d4612..96a2a645e1 100644 --- a/pkg/utils/thrift.go +++ b/pkg/utils/thrift.go @@ -20,21 +20,22 @@ import ( "errors" "fmt" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift" + + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // ThriftMessageCodec is used to codec thrift messages. type ThriftMessageCodec struct { - tb *thrift.TMemoryBuffer - tProt thrift.TProtocol + tb *athrift.TMemoryBuffer + tProt athrift.TProtocol } // NewThriftMessageCodec creates a new ThriftMessageCodec. func NewThriftMessageCodec() *ThriftMessageCodec { // TODO: use remote.ByteBuffer & remote/codec/thrift.BinaryProtocol - transport := thrift.NewTMemoryBufferLen(1024) - tProt := thrift.NewTBinaryProtocol(transport, true, true) + transport := athrift.NewTMemoryBufferLen(1024) + tProt := athrift.NewTBinaryProtocol(transport, true, true) return &ThriftMessageCodec{ tb: transport, @@ -46,7 +47,7 @@ func NewThriftMessageCodec() *ThriftMessageCodec { // Notice! msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result // Notice! seqID will be reset in kitex if the buffer is used for generic call in client side, set seqID=0 is suggested // when you call this method as client. -func (t *ThriftMessageCodec) Encode(method string, msgType thrift.TMessageType, seqID int32, msg thrift.TStruct) (b []byte, err error) { +func (t *ThriftMessageCodec) Encode(method string, msgType athrift.TMessageType, seqID int32, msg athrift.TStruct) (b []byte, err error) { if method == "" { return nil, errors.New("empty methodName in thrift RPCEncode") } @@ -65,24 +66,25 @@ func (t *ThriftMessageCodec) Encode(method string, msgType thrift.TMessageType, } // Decode do thrift message decode, notice: msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result -func (t *ThriftMessageCodec) Decode(b []byte, msg thrift.TStruct) (method string, seqID int32, err error) { +func (t *ThriftMessageCodec) Decode(b []byte, msg athrift.TStruct) (method string, seqID int32, err error) { t.tb.Reset() if _, err = t.tb.Write(b); err != nil { return } - var msgType thrift.TMessageType + var msgType athrift.TMessageType if method, msgType, seqID, err = t.tProt.ReadMessageBegin(); err != nil { return } - if msgType == thrift.EXCEPTION { - exception := bthrift.NewApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") - if err = exception.Read(t.tProt); err != nil { + if msgType == athrift.EXCEPTION { + b = b[thrift.Binary.MessageBeginLength(method, 0, 0):] // for reusing fast read + ex := thrift.NewApplicationException(athrift.UNKNOWN_APPLICATION_EXCEPTION, "") + if _, err = ex.FastRead(b); err != nil { return } if err = t.tProt.ReadMessageEnd(); err != nil { return } - err = exception + err = ex return } if err = msg.Read(t.tProt); err != nil { @@ -94,7 +96,7 @@ func (t *ThriftMessageCodec) Decode(b []byte, msg thrift.TStruct) (method string // Serialize serialize message into bytes. This is normal thrift serialize func. // Notice: Binary generic use Encode instead of Serialize. -func (t *ThriftMessageCodec) Serialize(msg thrift.TStruct) (b []byte, err error) { +func (t *ThriftMessageCodec) Serialize(msg athrift.TStruct) (b []byte, err error) { t.tb.Reset() if err = msg.Write(t.tProt); err != nil { @@ -106,7 +108,7 @@ func (t *ThriftMessageCodec) Serialize(msg thrift.TStruct) (b []byte, err error) // Deserialize deserialize bytes into message. This is normal thrift deserialize func. // Notice: Binary generic use Decode instead of Deserialize. -func (t *ThriftMessageCodec) Deserialize(msg thrift.TStruct, b []byte) (err error) { +func (t *ThriftMessageCodec) Deserialize(msg athrift.TStruct, b []byte) (err error) { t.tb.Reset() if _, err = t.tb.Write(b); err != nil { return @@ -119,12 +121,12 @@ func (t *ThriftMessageCodec) Deserialize(msg thrift.TStruct, b []byte) (err erro // MarshalError convert go error to thrift exception, and encode exception over buffered binary transport. func MarshalError(method string, err error) []byte { - ex := bthrift.NewApplicationException(thrift.INTERNAL_ERROR, err.Error()) - n := bthrift.Binary.MessageBeginLength(method, thrift.EXCEPTION, 0) + ex := thrift.NewApplicationException(athrift.INTERNAL_ERROR, err.Error()) + n := thrift.Binary.MessageBeginLength(method, 0, 0) n += ex.BLength() b := make([]byte, n) // Write message header - off := bthrift.Binary.WriteMessageBegin(b, method, thrift.EXCEPTION, 0) + off := thrift.Binary.WriteMessageBegin(b, method, thrift.EXCEPTION, 0) // Write Ex body off += ex.FastWrite(b[off:]) return b[:off] @@ -133,7 +135,7 @@ func MarshalError(method string, err error) []byte { // UnmarshalError decode binary and return error message func UnmarshalError(b []byte) error { // Read message header - _, tp, _, l, err := bthrift.Binary.ReadMessageBegin(b) + _, tp, _, l, err := thrift.Binary.ReadMessageBegin(b) if err != nil { return err } @@ -142,7 +144,7 @@ func UnmarshalError(b []byte) error { } // Read Ex body off := l - ex := bthrift.NewApplicationException(thrift.INTERNAL_ERROR, "") + ex := thrift.NewApplicationException(athrift.INTERNAL_ERROR, "") if _, err := ex.FastRead(b[off:]); err != nil { return err } diff --git a/tool/internal_pkg/pluginmode/thriftgo/ast.go b/tool/internal_pkg/pluginmode/thriftgo/ast.go index c6f3779ee5..de610dfcf8 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/ast.go +++ b/tool/internal_pkg/pluginmode/thriftgo/ast.go @@ -43,16 +43,15 @@ func ZeroWriter(t *parser.Type, oprot, buf, offset string) string { return offsetTPL(oprot+".WriteBinary("+buf+", []byte{})", offset) case parser.Category_Map: return offsetTPL(oprot+".WriteMapBegin("+buf+", thrift."+golang.GetTypeIDConstant(t.GetKeyType())+ - ",thrift."+golang.GetTypeIDConstant(t.GetValueType())+",0)", offset) + offsetTPL(oprot+".WriteMapEnd("+buf+")", offset) + ",thrift."+golang.GetTypeIDConstant(t.GetValueType())+",0)", offset) case parser.Category_List: return offsetTPL(oprot+".WriteListBegin("+buf+", thrift."+golang.GetTypeIDConstant(t.GetValueType())+ - ",0)", offset) + offsetTPL(oprot+".WriteListEnd("+buf+")", offset) + ",0)", offset) case parser.Category_Set: return offsetTPL(oprot+".WriteSetBegin("+buf+", thrift."+golang.GetTypeIDConstant(t.GetValueType())+ - ",0)", offset) + offsetTPL(oprot+".WriteSetEnd("+buf+")", offset) + ",0)", offset) case parser.Category_Struct: - return offsetTPL(oprot+".WriteStructBegin("+buf+", \"\")", offset) + offsetTPL(oprot+".WriteFieldStop("+buf+")", offset) + - offsetTPL(oprot+".WriteStructEnd("+buf+")", offset) + return offsetTPL(oprot+".WriteFieldStop("+buf+")", offset) default: panic("unsupported type zero writer for" + t.Name) } @@ -61,33 +60,29 @@ func ZeroWriter(t *parser.Type, oprot, buf, offset string) string { func ZeroBLength(t *parser.Type, oprot, offset string) string { switch t.GetCategory() { case parser.Category_Bool: - return offsetTPL(oprot+".BoolLength(false)", offset) + return offsetTPL(oprot+".BoolLength()", offset) case parser.Category_Byte: - return offsetTPL(oprot+".ByteLength(0)", offset) + return offsetTPL(oprot+".ByteLength()", offset) case parser.Category_I16: - return offsetTPL(oprot+".I16Length(0)", offset) + return offsetTPL(oprot+".I16Length()", offset) case parser.Category_Enum, parser.Category_I32: - return offsetTPL(oprot+".I32Length(0)", offset) + return offsetTPL(oprot+".I32Length()", offset) case parser.Category_I64: - return offsetTPL(oprot+".I64Length(0)", offset) + return offsetTPL(oprot+".I64Length()", offset) case parser.Category_Double: - return offsetTPL(oprot+".DoubleLength(0)", offset) + return offsetTPL(oprot+".DoubleLength()", offset) case parser.Category_String: return offsetTPL(oprot+".StringLength(\"\")", offset) case parser.Category_Binary: - return offsetTPL(oprot+".BinaryLength([]byte{})", offset) + return offsetTPL(oprot+".BinaryLength(nil)", offset) case parser.Category_Map: - return offsetTPL(oprot+".MapBeginLength(thrift."+golang.GetTypeIDConstant(t.GetKeyType())+ - ",thrift."+golang.GetTypeIDConstant(t.GetValueType())+", 0)", offset) + offsetTPL(oprot+".MapEndLength()", offset) + return offsetTPL(oprot+".MapBeginLength()", offset) case parser.Category_List: - return offsetTPL(oprot+".ListBeginLength(thrift."+golang.GetTypeIDConstant(t.GetValueType())+ - ",0)", offset) + offsetTPL(oprot+".ListEndLength()", offset) + return offsetTPL(oprot+".ListBeginLength()", offset) case parser.Category_Set: - return offsetTPL(oprot+".SetBeginLength(thrift."+golang.GetTypeIDConstant(t.GetValueType())+ - ",0)", offset) + offsetTPL(oprot+".SetEndLength()", offset) + return offsetTPL(oprot+".SetBeginLength()", offset) case parser.Category_Struct: - return offsetTPL(oprot+".StructBeginLength(\"\")", offset) + offsetTPL(oprot+".FieldStopLength()", offset) + - offsetTPL(oprot+".StructEndLength()", offset) + return offsetTPL(oprot+".FieldStopLength()", offset) default: panic("unsupported type zero writer for" + t.Name) } diff --git a/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go b/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go index f866c490c9..3de740d9c4 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go +++ b/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go @@ -29,11 +29,7 @@ var ( _ = (*bytes.Buffer)(nil) {{- UseStdLibrary "bytes"}} _ = (*strings.Builder)(nil) {{- UseStdLibrary "strings"}} _ = reflect.Type(nil) {{- UseStdLibrary "reflect"}} - _ = thrift.TProtocol(nil) {{- UseLib (ImportPathTo "pkg/protocol/bthrift/apache") "thrift"}} - {{- if GenerateFastAPIs}} - {{- UseLib (ImportPathTo "pkg/protocol/bthrift") ""}} - _ = bthrift.BinaryWriter(nil) - {{- end}} + _ = thrift.STOP {{- UseLib "github.com/cloudwego/gopkg/protocol/thrift" ""}} {{- if GenerateDeepCopyAPIs}} {{- UseLib "github.com/cloudwego/kitex/pkg/utils" "kutils"}} {{- end}} diff --git a/tool/internal_pkg/pluginmode/thriftgo/struct_tpl.go b/tool/internal_pkg/pluginmode/thriftgo/struct_tpl.go index 7dd9fad917..41350ffc52 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/struct_tpl.go +++ b/tool/internal_pkg/pluginmode/thriftgo/struct_tpl.go @@ -52,12 +52,6 @@ func (p *{{$TypeName}}) FastRead(buf []byte) (int, error) { var isset{{.GoName}} bool = false {{- end}} {{- end}} - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - for { {{- if Features.KeepUnknownFields}} {{- if gt (len .Fields) 0}} @@ -65,7 +59,7 @@ func (p *{{$TypeName}}) FastRead(buf []byte) (int, error) { {{- end}} var beginOff int = offset {{- end}} - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -87,7 +81,7 @@ func (p *{{$TypeName}}) FastRead(buf []byte) (int, error) { isset{{.GoName}} = true {{- end}} } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError @@ -95,7 +89,7 @@ func (p *{{$TypeName}}) FastRead(buf []byte) (int, error) { } {{- end}}{{/* range .Fields */}} default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError @@ -105,19 +99,12 @@ func (p *{{$TypeName}}) FastRead(buf []byte) (int, error) { {{- end}}{{/* if Features.KeepUnknownFields */}} } {{- else -}} - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } {{- end}}{{/* if len(.Fields) > 0 */}} - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } - {{- if Features.KeepUnknownFields}} {{if gt (len .Fields) 0 -}} if isUnknownField { @@ -128,11 +115,6 @@ func (p *{{$TypeName}}) FastRead(buf []byte) (int, error) { {{- end}} {{- end}}{{/* if Features.KeepUnknownFields */}} } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError - } {{ $NeedRequiredFieldNotSetError := false }} {{- range .Fields}} {{- if .Requiredness.IsRequired}} @@ -144,23 +126,17 @@ func (p *{{$TypeName}}) FastRead(buf []byte) (int, error) { {{- end}} {{- end}} return offset, nil -ReadStructBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) {{- if gt (len .Fields) 0}} ReadFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_{{$TypeName}}[fieldId]), err) + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_{{$TypeName}}[fieldId]), err) {{- end}} SkipFieldError: - return offset, bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) {{- if $NeedRequiredFieldNotSetError }} RequiredFieldNotSetError: - return offset, thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("required field %s is not set", fieldIDToName_{{$TypeName}}[fieldId])) + return offset, thrift.NewProtocolException(thrift.INVALID_DATA, fmt.Sprintf("required field %s is not set", fieldIDToName_{{$TypeName}}[fieldId])) {{- end}}{{/* if $NeedRequiredFieldNotSetError */}} } {{- end}}{{/* define "StructLikeFastRead" */}} @@ -185,7 +161,7 @@ func (p *{{$TypeName}}) FastReadField{{Str .ID}}(buf []byte) (int, error) { {{- $target}} = _field {{- if Features.WithFieldMask}} } else { - l, err := bthrift.Binary.Skip(buf[offset:], thrift.{{.Type | GetTypeIDConstant}}) + l, err := thrift.Binary.Skip(buf[offset:], thrift.{{.Type | GetTypeIDConstant}}) offset += l if err != nil { return offset, err @@ -242,7 +218,7 @@ func (p *{{$TypeName}}) FastWrite(buf []byte) int { const StructLikeFastWriteNocopy = ` {{define "StructLikeFastWriteNocopy"}} {{- $TypeName := .GoName}} -func (p *{{$TypeName}}) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *{{$TypeName}}) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 {{- if eq .Category "union"}} var c int @@ -252,18 +228,16 @@ func (p *{{$TypeName}}) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryW } } {{- end}} - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "{{.Name}}") if p != nil { {{- $reorderedFields := ReorderStructFields .Fields}} {{- range $reorderedFields}} - offset += p.fastWriteField{{Str .ID}}(buf[offset:], binaryWriter) + offset += p.fastWriteField{{Str .ID}}(buf[offset:], w) {{- end}} {{- if Features.KeepUnknownFields}} offset += copy(buf[offset:], p._unknownFields) {{- end}}{{/* if Features.KeepUnknownFields */}} } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset {{- if eq .Category "union"}} CountSetFieldsError: @@ -286,7 +260,6 @@ func (p *{{$TypeName}}) BLength() int { } } {{- end}} - l += bthrift.Binary.StructBeginLength("{{.Name}}") if p != nil { {{- range .Fields}} {{- $isBaseVal := .Type | IsBaseType}} @@ -296,8 +269,7 @@ func (p *{{$TypeName}}) BLength() int { l += len(p._unknownFields) {{- end}}{{/* if Features.KeepUnknownFields */}} } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l {{- if eq .Category "union"}} CountSetFieldsError: @@ -314,7 +286,7 @@ const StructLikeFastWriteField = ` {{- $FieldName := .GoName}} {{- $TypeID := .Type | GetTypeIDConstant }} {{- $isBaseVal := .Type | IsBaseType}} -func (p *{{$TypeName}}) fastWriteField{{Str .ID}}(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *{{$TypeName}}) fastWriteField{{Str .ID}}(buf []byte, w thrift.NocopyWriter) int { offset := 0 {{- if .Requiredness.IsOptional}} if p.{{.IsSetter}}() { @@ -328,16 +300,14 @@ func (p *{{$TypeName}}) fastWriteField{{Str .ID}}(buf []byte, binaryWriter bthri if {{if $isBaseVal}}_{{else}}fm{{end}}, ex := p._fieldmask.Field({{.ID}}); ex { {{- end}} {{- end}} - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "{{.Name}}", thrift.{{$TypeID}}, {{.ID}}) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.{{$TypeID}}, {{.ID}}) {{- $ctx := (MkRWCtx .).WithFieldMask "fm"}} {{- template "FieldFastWrite" $ctx}} - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) {{- if Features.WithFieldMask}} {{- if Features.FieldMaskZeroRequired}} } else { - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "{{.Name}}", thrift.{{$TypeID}}, {{.ID}}) - {{ ZeroWriter .Type "bthrift.Binary" "buf[offset:]" "offset" -}} - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.{{$TypeID}}, {{.ID}}) + {{ ZeroWriter .Type "thrift.Binary" "buf[offset:]" "offset" -}} } {{- else if not .Requiredness.IsRequired}} } @@ -373,16 +343,14 @@ func (p *{{$TypeName}}) field{{Str .ID}}Length() int { if {{if $isBaseVal}}_{{else}}fm{{end}}, ex := p._fieldmask.Field({{.ID}}); ex { {{- end}} {{- end}} - l += bthrift.Binary.FieldBeginLength("{{.Name}}", thrift.{{$TypeID}}, {{.ID}}) + l += thrift.Binary.FieldBeginLength() {{- $ctx := (MkRWCtx .).WithFieldMask "fm"}} {{- template "FieldLength" $ctx}} - l += bthrift.Binary.FieldEndLength() {{- if Features.WithFieldMask}} {{- if Features.FieldMaskZeroRequired}} } else { - l += bthrift.Binary.FieldBeginLength("{{.Name}}", thrift.{{$TypeID}}, {{.ID}}) - {{ ZeroBLength .Type "bthrift.Binary" "l" -}} - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + {{ ZeroBLength .Type "thrift.Binary" "l" -}} } {{- else if not .Requiredness.IsRequired}} } @@ -435,7 +403,7 @@ const FieldFastReadBaseType = ` {{- if .NeedDecl}} var {{.Target}} {{.TypeName}} {{- end}} - if v, l, err := bthrift.Binary.Read{{.TypeID}}(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.Read{{.TypeID}}(buf[offset:]); err != nil { return offset, err } else { offset += l @@ -445,13 +413,13 @@ const FieldFastReadBaseType = ` {{.Target}} = &tmp {{- else -}} {{.Target}} = &v - {{- end}} + {{- end -}} {{ else}} {{- if $DiffType}} {{.Target}} = {{.TypeName}}(v) - {{- else}} + {{- else -}} {{.Target}} = v - {{- end}} + {{- end -}} {{ end}} } {{- end}}{{/* define "FieldFastReadBaseType" */}} @@ -476,7 +444,7 @@ const FieldFastReadMap = ` {{- $isStrKey := .KeyCtx.Type | IsStrType -}} {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := "nfm"}} - _, _, size, l, err := bthrift.Binary.ReadMapBegin(buf[offset:]) + _, _, size, l, err := thrift.Binary.ReadMapBegin(buf[offset:]) offset += l if err != nil { return offset, err @@ -492,7 +460,7 @@ const FieldFastReadMap = ` {{- if Features.WithFieldMask}} {{- if $isIntKey}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(int({{$key}})); !ex { - l, err := bthrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) + l, err := thrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) offset += l if err != nil { return offset, err @@ -501,7 +469,7 @@ const FieldFastReadMap = ` } else { {{- else if $isStrKey}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Str(string({{$key}})); !ex { - l, err := bthrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) + l, err := thrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) offset += l if err != nil { return offset, err @@ -510,7 +478,7 @@ const FieldFastReadMap = ` } else { {{- else}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(0); !ex { - l, err := bthrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) + l, err := thrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) offset += l if err != nil { return offset, err @@ -537,11 +505,6 @@ const FieldFastReadMap = ` } {{- end}} } - if l, err := bthrift.Binary.ReadMapEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } {{- end}}{{/* define "FieldFastReadMap" */}} ` @@ -550,7 +513,7 @@ const FieldFastReadSet = ` {{- $isStructVal := .ValCtx.Type.Category.IsStructLike -}} {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := .FieldMask}} - _, size, l, err := bthrift.Binary.ReadSetBegin(buf[offset:]) + _, size, l, err := thrift.Binary.ReadSetBegin(buf[offset:]) offset += l if err != nil { return offset, err @@ -564,7 +527,7 @@ const FieldFastReadSet = ` {{- if Features.WithFieldMask}} {{- $curFieldMask = "nfm"}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(i); !ex { - l, err := bthrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) + l, err := thrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) offset += l if err != nil { return offset, err @@ -588,11 +551,6 @@ const FieldFastReadSet = ` } {{- end}} } - if l, err := bthrift.Binary.ReadSetEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } {{- end}}{{/* define "FieldFastReadSet" */}} ` @@ -601,7 +559,7 @@ const FieldFastReadList = ` {{- $isStructVal := .ValCtx.Type.Category.IsStructLike -}} {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := .FieldMask}} - _, size, l, err := bthrift.Binary.ReadListBegin(buf[offset:]) + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) offset += l if err != nil { return offset, err @@ -615,7 +573,7 @@ const FieldFastReadList = ` {{- if Features.WithFieldMask}} {{- $curFieldMask = "nfm"}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(i); !ex { - l, err := bthrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) + l, err := thrift.Binary.Skip(buf[offset:], thrift.{{.ValCtx.Type | GetTypeIDConstant}}) offset += l if err != nil { return offset, err @@ -639,11 +597,6 @@ const FieldFastReadList = ` } {{- end}} } - if l, err := bthrift.Binary.ReadListEnd(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - } {{- end}}{{/* define "FieldFastReadList" */}} ` @@ -827,7 +780,7 @@ const FieldFastWriteStructLike = ` {{.Target}}.Set_FieldMask({{.FieldMask}}) {{- end}} {{- end}} - offset += {{.Target}}.FastWriteNocopy(buf[offset:], binaryWriter) + offset += {{.Target}}.FastWriteNocopy(buf[offset:], w) {{- end}}{{/* define "FieldFastWriteStructLike" */}} ` @@ -851,9 +804,9 @@ const FieldFastWriteBaseType = ` {{- if .Type.Category.IsEnum}}{{$Value = printf "int32(%s)" $Value}}{{end}} {{- if .Type.Category.IsBinary}}{{$Value = printf "[]byte(%s)" $Value}}{{end}} {{- if IsBinaryOrStringType .Type}} - offset += bthrift.Binary.Write{{.TypeID}}Nocopy(buf[offset:], binaryWriter, {{$Value}}) + offset += thrift.Binary.Write{{.TypeID}}Nocopy(buf[offset:], w, {{$Value}}) {{- else}} - offset += bthrift.Binary.Write{{.TypeID}}(buf[offset:], {{$Value}}) + offset += thrift.Binary.Write{{.TypeID}}(buf[offset:], {{$Value}}) {{- end}} {{- end}}{{/* define "FieldFastWriteBaseType" */}} ` @@ -865,17 +818,16 @@ const FieldBaseTypeLength = ` {{- if .Type.Category.IsEnum}}{{$Value = printf "int32(%s)" $Value}}{{end}} {{- if .Type.Category.IsBinary}}{{$Value = printf "[]byte(%s)" $Value}}{{end}} {{- if IsBinaryOrStringType .Type}} - l += bthrift.Binary.{{.TypeID}}LengthNocopy({{$Value}}) + l += thrift.Binary.{{.TypeID}}LengthNocopy({{$Value}}) {{- else}} - l += bthrift.Binary.{{.TypeID}}Length({{$Value}}) + l += thrift.Binary.{{.TypeID}}Length() {{- end}} {{- end}}{{/* define "FieldBaseTypeLength" */}} ` const FieldFixedLengthTypeLength = ` {{define "FieldFixedLengthTypeLength"}} -{{- $Value := .Target -}} -bthrift.Binary.{{.TypeID}}Length({{TypeIDToGoType .TypeID}}({{$Value}})) +thrift.Binary.{{.TypeID}}Length() {{- end -}}{{/* define "FieldFixedLengthTypeLength" */}} ` @@ -910,9 +862,7 @@ const FieldFastWriteMap = ` {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := "nfm"}} mapBeginOffset := offset - offset += bthrift.Binary.MapBeginLength(thrift. - {{- .KeyCtx.Type | GetTypeIDConstant -}} - , thrift.{{- .ValCtx.Type | GetTypeIDConstant -}}, 0) + offset += thrift.Binary.MapBeginLength() var length int for k, v := range {{.Target}}{ {{- if Features.WithFieldMask}} @@ -939,11 +889,10 @@ const FieldFastWriteMap = ` } {{- end}} } - bthrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift. + thrift.Binary.WriteMapBegin(buf[mapBeginOffset:], thrift. {{- .KeyCtx.Type | GetTypeIDConstant -}} , thrift.{{- .ValCtx.Type | GetTypeIDConstant -}} , length) - offset += bthrift.Binary.WriteMapEnd(buf[offset:]) {{- end}}{{/* define "FieldFastWriteMap" */}} ` @@ -953,19 +902,13 @@ const FieldMapLength = ` {{- $isStrKey := .KeyCtx.Type | IsStrType -}} {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := .FieldMask}} - l += bthrift.Binary.MapBeginLength(thrift. - {{- .KeyCtx.Type | GetTypeIDConstant -}} - , thrift.{{- .ValCtx.Type | GetTypeIDConstant -}} - , len({{.Target}})) + l += thrift.Binary.MapBeginLength() {{- if and (not Features.WithFieldMask) (and (IsFixedLengthType .KeyCtx.Type) (IsFixedLengthType .ValCtx.Type))}} - var tmpK {{.KeyCtx.TypeName}} - var tmpV {{.ValCtx.TypeName}} - l += ({{- $ctx := .KeyCtx.WithTarget "tmpK" -}} - {{- template "FieldFixedLengthTypeLength" $ctx}} + - {{- $ctx := .ValCtx.WithTarget "tmpV" -}} - {{- template "FieldFixedLengthTypeLength" $ctx}}) * len({{.Target}}) + l += ({{- template "FieldFixedLengthTypeLength" .KeyCtx}} + + {{- template "FieldFixedLengthTypeLength" .ValCtx}}) * len({{.Target}}) {{- else}} for k, v := range {{.Target}}{ + _, _ = k, v {{- if Features.WithFieldMask}} {{- $curFieldMask = "nfm"}} {{- if $isIntKey}} @@ -991,7 +934,6 @@ const FieldMapLength = ` {{- end}} } {{- end}}{{/* if */}} - l += bthrift.Binary.MapEndLength() {{- end}}{{/* define "FieldMapLength" */}} ` @@ -1000,8 +942,7 @@ const FieldFastWriteSet = ` {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := .FieldMask}} setBeginOffset := offset - offset += bthrift.Binary.SetBeginLength(thrift. - {{- .ValCtx.Type | GetTypeIDConstant -}}, 0) + offset += thrift.Binary.SetBeginLength() {{template "ValidateSet" .}} var length int for {{if Features.WithFieldMask}}i{{else}}_{{end}}, v := range {{.Target}} { @@ -1018,10 +959,9 @@ const FieldFastWriteSet = ` } {{- end}} } - bthrift.Binary.WriteSetBegin(buf[setBeginOffset:], thrift. + thrift.Binary.WriteSetBegin(buf[setBeginOffset:], thrift. {{- .ValCtx.Type | GetTypeIDConstant -}} , length) - offset += bthrift.Binary.WriteSetEnd(buf[offset:]) {{- end}}{{/* define "FieldFastWriteSet" */}} ` @@ -1029,16 +969,13 @@ const FieldSetLength = ` {{define "FieldSetLength"}} {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := .FieldMask}} - l += bthrift.Binary.SetBeginLength(thrift. - {{- .ValCtx.Type | GetTypeIDConstant -}} - , len({{.Target}})) + l += thrift.Binary.SetBeginLength() {{template "ValidateSet" .}} {{- if and (not Features.WithFieldMask) (IsFixedLengthType .ValCtx.Type)}} - var tmpV {{.ValCtx.TypeName}} - l += {{- $ctx := .ValCtx.WithTarget "tmpV" -}} - {{- template "FieldFixedLengthTypeLength" $ctx -}} * len({{.Target}}) + l += {{- template "FieldFixedLengthTypeLength" .ValCtx -}} * len({{.Target}}) {{- else}} for {{if Features.WithFieldMask}}i{{else}}_{{end}}, v := range {{.Target}} { + _ = v {{- if Features.WithFieldMask}} {{- $curFieldMask = "nfm"}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(i); !ex { @@ -1052,7 +989,6 @@ const FieldSetLength = ` {{- end}} } {{- end}}{{/* if */}} - l += bthrift.Binary.SetEndLength() {{- end}}{{/* define "FieldSetLength" */}} ` @@ -1061,8 +997,7 @@ const FieldFastWriteList = ` {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := .FieldMask}} listBeginOffset := offset - offset += bthrift.Binary.ListBeginLength(thrift. - {{- .ValCtx.Type | GetTypeIDConstant -}}, 0) + offset += thrift.Binary.ListBeginLength() var length int for {{if Features.WithFieldMask}}i{{else}}_{{end}}, v := range {{.Target}} { {{- if Features.WithFieldMask}} @@ -1078,10 +1013,9 @@ const FieldFastWriteList = ` } {{- end}} } - bthrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift. + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift. {{- .ValCtx.Type | GetTypeIDConstant -}} , length) - offset += bthrift.Binary.WriteListEnd(buf[offset:]) {{- end}}{{/* define "FieldFastWriteList" */}} ` @@ -1089,15 +1023,12 @@ const FieldListLength = ` {{define "FieldListLength"}} {{- $isBaseVal := .ValCtx.Type | IsBaseType -}} {{- $curFieldMask := .FieldMask}} - l += bthrift.Binary.ListBeginLength(thrift. - {{- .ValCtx.Type | GetTypeIDConstant -}} - , len({{.Target}})) + l += thrift.Binary.ListBeginLength() {{- if and (not Features.WithFieldMask) (IsFixedLengthType .ValCtx.Type)}} - var tmpV {{.ValCtx.TypeName}} - l += {{- $ctx := .ValCtx.WithTarget "tmpV" -}} - {{- template "FieldFixedLengthTypeLength" $ctx -}} * len({{.Target}}) + l += {{- template "FieldFixedLengthTypeLength" .ValCtx -}} * len({{.Target}}) {{- else}} for {{if Features.WithFieldMask}}i{{else}}_{{end}}, v := range {{.Target}} { + _ = v {{- if Features.WithFieldMask}} {{- $curFieldMask = "nfm"}} if {{if $isBaseVal}}_{{else}}{{$curFieldMask}}{{end}}, ex := {{.FieldMask}}.Int(i); !ex { @@ -1111,7 +1042,6 @@ const FieldListLength = ` {{- end}} } {{- end}}{{/* if */}} - l += bthrift.Binary.ListEndLength() {{- end}}{{/* define "FieldListLength" */}} ` From e2ff8eeacaed30e61613bdfb927109885a77e50d Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:11:52 +0800 Subject: [PATCH 22/70] refactor(generic): remove apache thrift.TProtocol from generic (#1450) --- go.mod | 2 +- go.sum | 4 +- {pkg => internal}/generic/proto/json.go | 0 {pkg => internal}/generic/proto/json_test.go | 4 +- {pkg => internal}/generic/proto/protobuf.go | 0 {pkg => internal}/generic/proto/type.go | 0 {pkg => internal}/generic/thrift/http.go | 38 +- .../generic/thrift/http_fallback.go | 5 +- .../generic/thrift/http_go116plus_amd64.go | 46 +- {pkg => internal}/generic/thrift/http_pb.go | 19 +- {pkg => internal}/generic/thrift/json.go | 44 +- .../generic/thrift/json_fallback.go | 5 +- .../generic/thrift/json_go116plus_amd64.go | 90 +- {pkg => internal}/generic/thrift/read.go | 93 +- {pkg => internal}/generic/thrift/read_test.go | 382 +++---- {pkg => internal}/generic/thrift/struct.go | 18 +- {pkg => internal}/generic/thrift/thrift.go | 7 +- {pkg => internal}/generic/thrift/util.go | 2 +- {pkg => internal}/generic/thrift/util_test.go | 8 +- {pkg => internal}/generic/thrift/write.go | 391 +++---- .../generic/thrift/write_test.go | 750 +++----------- internal/mocks/generic/thrift.go | 9 +- pkg/generic/binarythrift_codec_test.go | 2 +- pkg/generic/descriptor/http_mapping.go | 2 +- pkg/generic/descriptor/render.go | 2 +- pkg/generic/descriptor/type.go | 2 +- pkg/generic/generic_service.go | 53 +- pkg/generic/generic_service_test.go | 28 +- pkg/generic/httppbthrift_codec.go | 4 +- pkg/generic/httppbthrift_codec_test.go | 6 +- pkg/generic/httpthrift_codec.go | 2 +- pkg/generic/httpthrift_codec_test.go | 10 +- pkg/generic/json_test/generic_test.go | 8 +- pkg/generic/jsonpb_codec.go | 2 +- pkg/generic/jsonpb_codec_test.go | 2 +- pkg/generic/jsonthrift_codec.go | 2 +- pkg/generic/jsonthrift_codec_test.go | 2 +- pkg/generic/mapthrift_codec.go | 2 +- pkg/generic/mapthrift_codec_test.go | 2 +- pkg/generic/pb_descriptor_provider.go | 2 +- pkg/generic/pbidl_provider.go | 2 +- pkg/generic/thrift/base.go | 962 ------------------ pkg/generic/thrift/parse.go | 3 +- pkg/remote/codec/thrift/thrift.go | 50 +- pkg/remote/codec/thrift/thrift_data.go | 24 +- pkg/remote/codec/thrift/thrift_data_test.go | 26 +- pkg/remote/codec/thrift/thrift_frugal_test.go | 4 +- pkg/remote/codec/thrift/thrift_test.go | 14 +- pkg/serviceinfo/serviceinfo.go | 4 - 49 files changed, 759 insertions(+), 2380 deletions(-) rename {pkg => internal}/generic/proto/json.go (100%) rename {pkg => internal}/generic/proto/json_test.go (97%) rename {pkg => internal}/generic/proto/protobuf.go (100%) rename {pkg => internal}/generic/proto/type.go (100%) rename {pkg => internal}/generic/thrift/http.go (85%) rename {pkg => internal}/generic/thrift/http_fallback.go (78%) rename {pkg => internal}/generic/thrift/http_go116plus_amd64.go (62%) rename {pkg => internal}/generic/thrift/http_pb.go (79%) rename {pkg => internal}/generic/thrift/json.go (83%) rename {pkg => internal}/generic/thrift/json_fallback.go (79%) rename {pkg => internal}/generic/thrift/json_go116plus_amd64.go (59%) rename {pkg => internal}/generic/thrift/read.go (78%) rename {pkg => internal}/generic/thrift/read_test.go (66%) rename {pkg => internal}/generic/thrift/struct.go (79%) rename {pkg => internal}/generic/thrift/thrift.go (80%) rename {pkg => internal}/generic/thrift/util.go (98%) rename {pkg => internal}/generic/thrift/util_test.go (95%) rename {pkg => internal}/generic/thrift/write.go (66%) rename {pkg => internal}/generic/thrift/write_test.go (71%) delete mode 100644 pkg/generic/thrift/base.go diff --git a/go.mod b/go.mod index 438de7dae9..d8f1283377 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cloudwego/dynamicgo v0.2.9 github.com/cloudwego/fastpb v0.0.4 github.com/cloudwego/frugal v0.1.15 - github.com/cloudwego/gopkg v0.0.0-20240722090221-969ae87c75ac + github.com/cloudwego/gopkg v0.0.0-20240725095015-34d5327eebca github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index 3c40301dd4..ab61497ef7 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4m github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= github.com/cloudwego/frugal v0.1.15 h1:LC55UJKhQPMFVjDPbE+LJcF7etZjSx6uokG1tk0wPK0= github.com/cloudwego/frugal v0.1.15/go.mod h1:26kU1r18vA8vRg12c66XPDlfv1GQHDbE1RpusipXfcI= -github.com/cloudwego/gopkg v0.0.0-20240722090221-969ae87c75ac h1:B7iK0zQ34wJkmNixXDHMHB+WrZJYadTAJSJkM21RZ6U= -github.com/cloudwego/gopkg v0.0.0-20240722090221-969ae87c75ac/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= +github.com/cloudwego/gopkg v0.0.0-20240725095015-34d5327eebca h1:xe6SuqnTHcqQlID29RG8gflr5pLKpffDJUusm7rZUPI= +github.com/cloudwego/gopkg v0.0.0-20240725095015-34d5327eebca/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/iasm v0.0.9/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= diff --git a/pkg/generic/proto/json.go b/internal/generic/proto/json.go similarity index 100% rename from pkg/generic/proto/json.go rename to internal/generic/proto/json.go diff --git a/pkg/generic/proto/json_test.go b/internal/generic/proto/json_test.go similarity index 97% rename from pkg/generic/proto/json_test.go rename to internal/generic/proto/json_test.go index ccac9c4388..748a17e5c7 100644 --- a/pkg/generic/proto/json_test.go +++ b/internal/generic/proto/json_test.go @@ -32,8 +32,8 @@ import ( ) var ( - example2IDLPath = "../jsonpb_test/idl/example2.proto" - example2ProtoPath = "../jsonpb_test/data/example2_pb.bin" + example2IDLPath = "../../../pkg/generic/jsonpb_test/idl/example2.proto" + example2ProtoPath = "../../../pkg/generic/jsonpb_test/data/example2_pb.bin" ) func TestRun(t *testing.T) { diff --git a/pkg/generic/proto/protobuf.go b/internal/generic/proto/protobuf.go similarity index 100% rename from pkg/generic/proto/protobuf.go rename to internal/generic/proto/protobuf.go diff --git a/pkg/generic/proto/type.go b/internal/generic/proto/type.go similarity index 100% rename from pkg/generic/proto/type.go rename to internal/generic/proto/type.go diff --git a/pkg/generic/thrift/http.go b/internal/generic/thrift/http.go similarity index 85% rename from pkg/generic/thrift/http.go rename to internal/generic/thrift/http.go index f72df178d3..fd31bfcbce 100644 --- a/pkg/generic/thrift/http.go +++ b/internal/generic/thrift/http.go @@ -19,18 +19,19 @@ package thrift import ( "context" "fmt" + "io" "github.com/bytedance/gopkg/lang/dirtmake" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" dthrift "github.com/cloudwego/dynamicgo/thrift" "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" jsoniter "github.com/json-iterator/go" "github.com/cloudwego/kitex/pkg/generic/descriptor" - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" - cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" ) type HTTPReaderWriter struct { @@ -79,7 +80,7 @@ func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.O } // originalWrite ... -func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out athrift.TProtocol, msg interface{}, requestBase *Base) error { +func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out io.Writer, msg interface{}, requestBase *base.Base) error { req := msg.(*descriptor.HTTPRequest) if req.Body == nil && len(req.RawBody) != 0 { if err := customJson.Unmarshal(req.RawBody, &req.Body); err != nil { @@ -93,7 +94,12 @@ func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out athrift.TProto if !fn.HasRequestBase { requestBase = nil } - return wrapStructWriter(ctx, req, out, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64}) + binaryWriter := thrift.NewBinaryWriter() + if err = wrapStructWriter(ctx, req, binaryWriter, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64}); err != nil { + return err + } + _, err = out.Write(binaryWriter.Bytes()) + return err } // ReadHTTPResponse implement of MessageReaderWithMethod @@ -131,22 +137,26 @@ func (r *ReadHTTPResponse) SetDynamicGo(convOpts *conv.Options) { } // Read ... -func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in athrift.TProtocol) (interface{}, error) { +func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { + buffer, ok := in.(remote.ByteBuffer) + if !ok { + return nil, perrors.NewProtocolErrorWithMsg("io.Reader should be ByteBuffer") + } + binaryReader := thrift.NewBinaryReader(buffer) + // fallback logic if !r.dynamicgoEnabled || dataLen == 0 { - return r.originalRead(ctx, method, in) - } - tProt, ok := in.(*cthrift.BinaryProtocol) - if !ok { - return nil, perrors.NewProtocolErrorWithMsg("TProtocol should be BinaryProtocol") + return r.originalRead(ctx, method, binaryReader) } + + // dynamicgo logic // TODO: support exception field - _, _, id, err := in.ReadFieldBegin() + _, id, err := binaryReader.ReadFieldBegin() if err != nil { return nil, err } - fBeginLen := thrift.Binary.FieldBeginLength() - transBuf, err := tProt.ByteBuffer().ReadBinary(dataLen - fBeginLen) + bProt := &thrift.BinaryProtocol{} + transBuf, err := buffer.ReadBinary(dataLen - bProt.FieldBeginLength()) if err != nil { return nil, err } @@ -176,7 +186,7 @@ func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient boo return resp, nil } -func (r *ReadHTTPResponse) originalRead(ctx context.Context, method string, in athrift.TProtocol) (interface{}, error) { +func (r *ReadHTTPResponse) originalRead(ctx context.Context, method string, in *thrift.BinaryReader) (interface{}, error) { fnDsc, err := r.svc.LookupFunctionByMethod(method) if err != nil { return nil, err diff --git a/pkg/generic/thrift/http_fallback.go b/internal/generic/thrift/http_fallback.go similarity index 78% rename from pkg/generic/thrift/http_fallback.go rename to internal/generic/thrift/http_fallback.go index 4bda30b81b..4fa510bbed 100644 --- a/pkg/generic/thrift/http_fallback.go +++ b/internal/generic/thrift/http_fallback.go @@ -21,11 +21,12 @@ package thrift import ( "context" + "io" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift/base" ) // Write ... -func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { +func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { return w.originalWrite(ctx, out, msg, requestBase) } diff --git a/pkg/generic/thrift/http_go116plus_amd64.go b/internal/generic/thrift/http_go116plus_amd64.go similarity index 62% rename from pkg/generic/thrift/http_go116plus_amd64.go rename to internal/generic/thrift/http_go116plus_amd64.go index e0701ad6b6..c825890d84 100644 --- a/pkg/generic/thrift/http_go116plus_amd64.go +++ b/internal/generic/thrift/http_go116plus_amd64.go @@ -22,21 +22,21 @@ package thrift import ( "context" "fmt" + "io" "unsafe" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/j2t" - "github.com/cloudwego/dynamicgo/thrift/base" + dbase "github.com/cloudwego/dynamicgo/thrift/base" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/cloudwego/kitex/pkg/generic/descriptor" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/remote/codec/perrors" - cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" ) // Write ... -func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { +func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { // fallback logic if !w.dynamicgoEnabled { return w.originalWrite(ctx, out, msg, requestBase) @@ -56,16 +56,14 @@ func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg requestBase = nil } if requestBase != nil { - base := (*base.Base)(unsafe.Pointer(requestBase)) + base := (*dbase.Base)(unsafe.Pointer(requestBase)) ctx = context.WithValue(ctx, conv.CtxKeyThriftReqBase, base) cv = j2t.NewBinaryConv(w.convOptsWithThriftBase) } else { cv = j2t.NewBinaryConv(w.convOpts) } - if err := out.WriteStructBegin(dynamicgoTypeDsc.Struct().Name()); err != nil { - return err - } + binaryWriter := thrift.NewBinaryWriter() ctx = context.WithValue(ctx, conv.CtxKeyHTTPRequest, req) body := req.GetBody() @@ -73,37 +71,21 @@ func (w *WriteHTTPRequest) Write(ctx context.Context, out thrift.TProtocol, msg defer mcache.Free(dbuf) for _, field := range dynamicgoTypeDsc.Struct().Fields() { - if err := out.WriteFieldBegin(field.Name(), field.Type().Type().ToThriftTType(), int16(field.ID())); err != nil { - return err - } + binaryWriter.WriteFieldBegin(thrift.TType(field.Type().Type()), int16(field.ID())) // json []byte to thrift []byte if err := cv.DoInto(ctx, field.Type(), body, &dbuf); err != nil { return err } - - // WriteFieldEnd has no content - // if err := out.WriteFieldEnd(); err != nil { - // return err - // } } - - tProt, ok := out.(*cthrift.BinaryProtocol) - if !ok { - return perrors.NewProtocolErrorWithMsg("TProtocol should be BinaryProtocol") - } - buf, err := tProt.ByteBuffer().Malloc(len(dbuf)) - if err != nil { - return err - } - // TODO: implement MallocAck() to achieve zero copy - copy(buf, dbuf) - - if err := out.WriteFieldStop(); err != nil { + if _, err := out.Write(binaryWriter.Bytes()); err != nil { return err } - if err := out.WriteStructEnd(); err != nil { + if _, err := out.Write(dbuf); err != nil { return err } - return nil + binaryWriter.Reset() + binaryWriter.WriteFieldStop() + _, err := out.Write(binaryWriter.Bytes()) + return err } diff --git a/pkg/generic/thrift/http_pb.go b/internal/generic/thrift/http_pb.go similarity index 79% rename from pkg/generic/thrift/http_pb.go rename to internal/generic/thrift/http_pb.go index 23f5aaa5f5..8e3e4c221c 100644 --- a/pkg/generic/thrift/http_pb.go +++ b/internal/generic/thrift/http_pb.go @@ -20,13 +20,15 @@ import ( "context" "errors" "fmt" + "io" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/dynamic" + "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/proto" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) type HTTPPbReaderWriter struct { @@ -53,7 +55,7 @@ func NewWriteHTTPPbRequest(svc *descriptor.ServiceDescriptor, pbSvc *desc.Servic } // Write ... -func (w *WriteHTTPPbRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { +func (w *WriteHTTPPbRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { req := msg.(*descriptor.HTTPRequest) fn, err := w.svc.Router.Lookup(req) if err != nil { @@ -75,7 +77,12 @@ func (w *WriteHTTPPbRequest) Write(ctx context.Context, out thrift.TProtocol, ms } req.GeneralBody = pbMsg - return wrapStructWriter(ctx, req, out, fn.Request, &writerOption{requestBase: requestBase}) + binaryWriter := thrift.NewBinaryWriter() + if err = wrapStructWriter(ctx, req, binaryWriter, fn.Request, &writerOption{requestBase: requestBase}); err != nil { + return err + } + _, err = out.Write(binaryWriter.Bytes()) + return err } // ReadHTTPPbResponse implement of MessageReaderWithMethod @@ -93,7 +100,7 @@ func NewReadHTTPPbResponse(svc *descriptor.ServiceDescriptor, pbSvc proto.Servic } // Read ... -func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { +func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { fnDsc, err := r.svc.LookupFunctionByMethod(method) if err != nil { return nil, err @@ -104,5 +111,5 @@ func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, isClient b return nil, errors.New("pb method not found") } - return skipStructReader(ctx, in, fDsc, &readerOption{pbDsc: mt.GetOutputType(), http: true}) + return skipStructReader(ctx, thrift.NewBinaryReader(in), fDsc, &readerOption{pbDsc: mt.GetOutputType(), http: true}) } diff --git a/pkg/generic/thrift/json.go b/internal/generic/thrift/json.go similarity index 83% rename from pkg/generic/thrift/json.go rename to internal/generic/thrift/json.go index c697c29f15..d4eb7bf082 100644 --- a/pkg/generic/thrift/json.go +++ b/internal/generic/thrift/json.go @@ -19,19 +19,21 @@ package thrift import ( "context" "fmt" + "io" "strconv" "github.com/bytedance/gopkg/lang/dirtmake" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" dthrift "github.com/cloudwego/dynamicgo/thrift" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" jsoniter "github.com/json-iterator/go" "github.com/tidwall/gjson" "github.com/cloudwego/kitex/pkg/generic/descriptor" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" - cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/utils" ) @@ -81,7 +83,7 @@ func (m *WriteJSON) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) m.dynamicgoEnabled = true } -func (m *WriteJSON) originalWrite(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { +func (m *WriteJSON) originalWrite(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { fnDsc, err := m.svcDsc.LookupFunctionByMethod(method) if err != nil { return fmt.Errorf("missing method: %s in service: %s", method, m.svcDsc.Name) @@ -96,9 +98,15 @@ func (m *WriteJSON) originalWrite(ctx context.Context, out thrift.TProtocol, msg requestBase = nil } + binaryWriter := thrift.NewBinaryWriter() + // msg is void or nil if _, ok := msg.(descriptor.Void); ok || msg == nil { - return wrapStructWriter(ctx, msg, out, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}) + if err = wrapStructWriter(ctx, msg, binaryWriter, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil { + return err + } + _, err = out.Write(binaryWriter.Bytes()) + return err } // msg is string @@ -117,7 +125,11 @@ func (m *WriteJSON) originalWrite(ctx context.Context, out thrift.TProtocol, msg Index: 0, } } - return wrapJSONWriter(ctx, &body, out, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}) + if err = wrapJSONWriter(ctx, &body, binaryWriter, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil { + return err + } + _, err = out.Write(binaryWriter.Bytes()) + return err } // NewReadJSON build ReadJSON according to ServiceDescriptor @@ -154,16 +166,16 @@ func (m *ReadJSON) SetDynamicGo(convOpts, convOptsWithException *conv.Options) { } // Read read data from in thrift.TProtocol and convert to json string -func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { - // fallback logic - if !m.dynamicgoEnabled || dataLen <= 0 { - return m.originalRead(ctx, method, isClient, in) +func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { + buffer, ok := in.(remote.ByteBuffer) + if !ok { + return nil, perrors.NewProtocolErrorWithMsg("io.Reader should be ByteBuffer") } + binaryReader := thrift.NewBinaryReader(in) - // dynamicgo logic - tProt, ok := in.(*cthrift.BinaryProtocol) - if !ok { - return nil, perrors.NewProtocolErrorWithMsg("TProtocol should be BinaryProtocol") + // fallback logic + if !m.dynamicgoEnabled || dataLen <= 0 { + return m.originalRead(ctx, method, isClient, binaryReader) } fnDsc := m.svc.DynamicGoDsc.Functions()[method] @@ -177,12 +189,12 @@ func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataL var resp interface{} if tyDsc.Struct().Fields()[0].Type().Type() == dthrift.VOID { - if _, err := tProt.ByteBuffer().ReadBinary(voidWholeLen); err != nil { + if _, err := buffer.ReadBinary(voidWholeLen); err != nil { return nil, err } resp = descriptor.Void{} } else { - transBuff, err := tProt.ByteBuffer().ReadBinary(dataLen) + transBuff, err := buffer.ReadBinary(dataLen) if err != nil { return nil, err } @@ -213,7 +225,7 @@ func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataL return resp, nil } -func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient bool, in thrift.TProtocol) (interface{}, error) { +func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient bool, in *thrift.BinaryReader) (interface{}, error) { fnDsc, err := m.svc.LookupFunctionByMethod(method) if err != nil { return nil, err diff --git a/pkg/generic/thrift/json_fallback.go b/internal/generic/thrift/json_fallback.go similarity index 79% rename from pkg/generic/thrift/json_fallback.go rename to internal/generic/thrift/json_fallback.go index d4a6a72e13..e1500fa853 100644 --- a/pkg/generic/thrift/json_fallback.go +++ b/internal/generic/thrift/json_fallback.go @@ -21,11 +21,12 @@ package thrift import ( "context" + "io" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift/base" ) // Write write json string to out thrift.TProtocol -func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { +func (m *WriteJSON) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { return m.originalWrite(ctx, out, msg, method, isClient, requestBase) } diff --git a/pkg/generic/thrift/json_go116plus_amd64.go b/internal/generic/thrift/json_go116plus_amd64.go similarity index 59% rename from pkg/generic/thrift/json_go116plus_amd64.go rename to internal/generic/thrift/json_go116plus_amd64.go index 0c1f7b9210..d321ea8afc 100644 --- a/pkg/generic/thrift/json_go116plus_amd64.go +++ b/internal/generic/thrift/json_go116plus_amd64.go @@ -22,23 +22,24 @@ package thrift import ( "context" "fmt" + "io" "unsafe" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/j2t" dthrift "github.com/cloudwego/dynamicgo/thrift" - "github.com/cloudwego/dynamicgo/thrift/base" + dbase "github.com/cloudwego/dynamicgo/thrift/base" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/cloudwego/kitex/pkg/generic/descriptor" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" - cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/utils" ) // Write write json string to out thrift.TProtocol -func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { +func (m *WriteJSON) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { // fallback logic if !m.dynamicgoEnabled { return m.originalWrite(ctx, out, msg, method, isClient, requestBase) @@ -60,22 +61,23 @@ func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interfa requestBase = nil } if requestBase != nil { - base := (*base.Base)(unsafe.Pointer(requestBase)) + base := (*dbase.Base)(unsafe.Pointer(requestBase)) ctx = context.WithValue(ctx, conv.CtxKeyThriftReqBase, base) cv = j2t.NewBinaryConv(m.convOptsWithThriftBase) } else { cv = j2t.NewBinaryConv(m.convOpts) } + binaryWriter := thrift.NewBinaryWriter() + // msg is void or nil if _, ok := msg.(descriptor.Void); ok || msg == nil { - if err := m.writeHead(out, dynamicgoTypeDsc); err != nil { - return err - } if err := m.writeFields(ctx, out, dynamicgoTypeDsc, nil, nil, isClient); err != nil { return err } - return writeTail(out) + binaryWriter.WriteFieldStop() + _, err := out.Write(binaryWriter.Bytes()) + return err } // msg is string @@ -85,26 +87,22 @@ func (m *WriteJSON) Write(ctx context.Context, out thrift.TProtocol, msg interfa } transBuff := utils.StringToSliceByte(s) - if err := m.writeHead(out, dynamicgoTypeDsc); err != nil { - return err - } if err := m.writeFields(ctx, out, dynamicgoTypeDsc, &cv, transBuff, isClient); err != nil { return err } - return writeTail(out) + binaryWriter.WriteFieldStop() + _, err := out.Write(binaryWriter.Bytes()) + return err } type MsgType int -const ( - Void MsgType = iota - String -) - -func (m *WriteJSON) writeFields(ctx context.Context, out thrift.TProtocol, dynamicgoTypeDsc *dthrift.TypeDescriptor, cv *j2t.BinaryConv, transBuff []byte, isClient bool) error { +func (m *WriteJSON) writeFields(ctx context.Context, out io.Writer, dynamicgoTypeDsc *dthrift.TypeDescriptor, cv *j2t.BinaryConv, transBuff []byte, isClient bool) error { dbuf := mcache.Malloc(len(transBuff))[0:0] defer mcache.Free(dbuf) + binaryWriter := thrift.NewBinaryWriter() + for _, field := range dynamicgoTypeDsc.Struct().Fields() { // Exception field if !isClient && field.ID() != 0 { @@ -113,15 +111,12 @@ func (m *WriteJSON) writeFields(ctx context.Context, out thrift.TProtocol, dynam continue } - if err := out.WriteFieldBegin(field.Name(), field.Type().Type().ToThriftTType(), int16(field.ID())); err != nil { - return err - } + binaryWriter.WriteFieldBegin(thrift.TType(field.Type().Type()), int16(field.ID())) // if the field type is void, just write void and return if field.Type().Type() == dthrift.VOID { - if err := writeFieldForVoid(field.Name(), out); err != nil { - return err - } - return nil + binaryWriter.WriteFieldStop() + _, err := out.Write(binaryWriter.Bytes()) + return err } else { // encode using dynamicgo // json []byte to thrift []byte @@ -129,47 +124,10 @@ func (m *WriteJSON) writeFields(ctx context.Context, out thrift.TProtocol, dynam return err } } - // WriteFieldEnd has no content - // if err := out.WriteFieldEnd(); err != nil { - // return err - // } - } - tProt, ok := out.(*cthrift.BinaryProtocol) - if !ok { - return perrors.NewProtocolErrorWithMsg("TProtocol should be BinaryProtocol") - } - buf, err := tProt.ByteBuffer().Malloc(len(dbuf)) - if err != nil { - return err - } - // TODO: implement MallocAck() to achieve zero copy - copy(buf, dbuf) - return nil -} - -func (m *WriteJSON) writeHead(out thrift.TProtocol, dynamicgoTypeDsc *dthrift.TypeDescriptor) error { - if err := out.WriteStructBegin(dynamicgoTypeDsc.Struct().Name()); err != nil { - return err - } - return nil -} - -func writeTail(out thrift.TProtocol) error { - if err := out.WriteFieldStop(); err != nil { - return err - } - return out.WriteStructEnd() -} - -func writeFieldForVoid(name string, out thrift.TProtocol) error { - if err := out.WriteStructBegin(name); err != nil { - return err - } - if err := out.WriteFieldStop(); err != nil { - return err } - if err := out.WriteStructEnd(); err != nil { + if _, err := out.Write(binaryWriter.Bytes()); err != nil { return err } - return nil + _, err := out.Write(dbuf) + return err } diff --git a/pkg/generic/thrift/read.go b/internal/generic/thrift/read.go similarity index 78% rename from pkg/generic/thrift/read.go rename to internal/generic/thrift/read.go index 649b349416..4b4bb3378c 100644 --- a/pkg/generic/thrift/read.go +++ b/internal/generic/thrift/read.go @@ -22,11 +22,11 @@ import ( "fmt" "reflect" + "github.com/cloudwego/gopkg/protocol/thrift" "github.com/jhump/protoreflect/desc" + "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/proto" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) var emptyPbDsc = &desc.MessageDescriptor{} @@ -48,7 +48,7 @@ type readerOption struct { pbDsc proto.MessageDescriptor } -type reader func(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) +type reader func(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) type fieldSetter func(field *descriptor.FieldDescriptor, val interface{}) error @@ -108,18 +108,15 @@ func nextReader(tt descriptor.Type, t *descriptor.TypeDescriptor, opt *readerOpt } } -func skipStructReader(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { - structName, err := in.ReadStructBegin() - if err != nil { - return nil, err - } +// TODO(marina.sakai): Optimize generic reader +func skipStructReader(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { var v interface{} for { - fieldName, fieldType, fieldID, err := in.ReadFieldBegin() + fieldType, fieldID, err := in.ReadFieldBegin() if err != nil { return nil, err } - if fieldType == thrift.STOP { + if fieldType == descriptor.STOP.ToThriftTType() { break } field, ok := t.Struct.FieldsByID[int32(fieldID)] @@ -132,7 +129,7 @@ func skipStructReader(ctx context.Context, in thrift.TProtocol, t *descriptor.Ty _fieldType := descriptor.FromThriftTType(fieldType) reader, err := nextReader(_fieldType, field.Type, opt) if err != nil { - return nil, fmt.Errorf("nextReader of %s/%s/%d error %w", structName, fieldName, fieldID, err) + return nil, fmt.Errorf("nextReader of %s/%d error %w", field.Name, fieldID, err) } if field.IsException && opt != nil && opt.throwException { if v, err = reader(ctx, in, field.Type, opt); err != nil { @@ -147,31 +144,28 @@ func skipStructReader(ctx context.Context, in thrift.TProtocol, t *descriptor.Ty reader = readHTTPResponse } if v, err = reader(ctx, in, field.Type, opt); err != nil { - return nil, fmt.Errorf("reader of %s/%s/%d error %w", structName, fieldName, fieldID, err) + return nil, fmt.Errorf("reader of %s/%d error %w", field.Name, fieldID, err) } } - if err := in.ReadFieldEnd(); err != nil { - return nil, err - } } - return v, in.ReadStructEnd() + return v, nil } -func readVoid(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readVoid(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { _, err := readStruct(ctx, in, t, opt) return descriptor.Void{}, err } -func readDouble(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readDouble(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadDouble() } -func readBool(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readBool(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadBool() } -func readByte(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readByte(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { res, err := in.ReadByte() if err != nil { return nil, err @@ -182,7 +176,7 @@ func readByte(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescri return res, nil } -func readInt16(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readInt16(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { res, err := in.ReadI16() if err != nil { return nil, err @@ -193,19 +187,19 @@ func readInt16(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescr return res, nil } -func readInt32(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readInt32(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadI32() } -func readInt64(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readInt64(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadI64() } -func readString(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readString(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadString() } -func readBinary(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readBinary(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { bytes, err := in.ReadBinary() if err != nil { return "", err @@ -213,7 +207,7 @@ func readBinary(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDesc return bytes, nil } -func readBase64Binary(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readBase64Binary(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { bytes, err := in.ReadBinary() if err != nil { return "", err @@ -221,7 +215,7 @@ func readBase64Binary(ctx context.Context, in thrift.TProtocol, t *descriptor.Ty return base64.StdEncoding.EncodeToString(bytes), nil } -func readList(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readList(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { elemType, length, err := in.ReadListBegin() if err != nil { return nil, err @@ -239,17 +233,17 @@ func readList(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescri } l = append(l, item) } - return l, in.ReadListEnd() + return l, nil } -func readMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readMap(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { if opt != nil && opt.forJSON { return readStringMap(ctx, in, t, opt) } return readInterfaceMap(ctx, in, t, opt) } -func readInterfaceMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readInterfaceMap(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { keyType, elemType, length, err := in.ReadMapBegin() if err != nil { return nil, err @@ -283,10 +277,10 @@ func readInterfaceMap(ctx context.Context, in thrift.TProtocol, t *descriptor.Ty nest() m[key] = elem } - return m, in.ReadMapEnd() + return m, nil } -func readStringMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readStringMap(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { keyType, elemType, length, err := in.ReadMapBegin() if err != nil { return nil, err @@ -316,10 +310,10 @@ func readStringMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeD } m[buildinTypeIntoString(key)] = elem } - return m, in.ReadMapEnd() + return m, nil } -func readStruct(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readStruct(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { var fs fieldSetter var st interface{} if opt == nil || opt.pbDsc == nil { @@ -366,20 +360,16 @@ func readStruct(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDesc } } } - _, err = in.ReadStructBegin() if err != nil { return nil, err } readFields := map[int32]struct{}{} for { - _, fieldType, fieldID, err := in.ReadFieldBegin() + fieldType, fieldID, err := in.ReadFieldBegin() if err != nil { return nil, err } - if fieldType == thrift.STOP { - if err := in.ReadFieldEnd(); err != nil { - return nil, err - } + if fieldType == descriptor.STOP.ToThriftTType() { // check required // void is nil struct if t.Struct != nil { @@ -387,7 +377,7 @@ func readStruct(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDesc return nil, err } } - return st, in.ReadStructEnd() + return st, nil } field, ok := t.Struct.FieldsByID[int32(fieldID)] if !ok { @@ -417,14 +407,11 @@ func readStruct(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDesc return nil, err } } - if err := in.ReadFieldEnd(); err != nil { - return nil, err - } readFields[int32(fieldID)] = struct{}{} } } -func readHTTPResponse(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readHTTPResponse(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { var resp *descriptor.HTTPResponse if opt == nil || opt.pbDsc == nil { if opt == nil { @@ -449,25 +436,18 @@ func readHTTPResponse(ctx context.Context, in thrift.TProtocol, t *descriptor.Ty return nil, err } } - _, err = in.ReadStructBegin() - if err != nil { - return nil, err - } readFields := map[int32]struct{}{} for { - _, fieldType, fieldID, err := in.ReadFieldBegin() + fieldType, fieldID, err := in.ReadFieldBegin() if err != nil { return nil, err } - if fieldType == thrift.STOP { - if err := in.ReadFieldEnd(); err != nil { - return nil, err - } + if fieldType == descriptor.STOP.ToThriftTType() { // check required if err := t.Struct.CheckRequired(readFields); err != nil { return nil, err } - return resp, in.ReadStructEnd() + return resp, nil } field, ok := t.Struct.FieldsByID[int32(fieldID)] if !ok { @@ -499,9 +479,6 @@ func readHTTPResponse(ctx context.Context, in thrift.TProtocol, t *descriptor.Ty return nil, err } } - if err := in.ReadFieldEnd(); err != nil { - return nil, err - } readFields[int32(fieldID)] = struct{}{} } } diff --git a/pkg/generic/thrift/read_test.go b/internal/generic/thrift/read_test.go similarity index 66% rename from pkg/generic/thrift/read_test.go rename to internal/generic/thrift/read_test.go index 52f5e249db..e50030d33a 100644 --- a/pkg/generic/thrift/read_test.go +++ b/internal/generic/thrift/read_test.go @@ -19,18 +19,17 @@ package thrift import ( "context" "encoding/base64" - "errors" "fmt" "reflect" "testing" + "github.com/cloudwego/gopkg/protocol/thrift" "github.com/jhump/protoreflect/desc/protoparse" "github.com/stretchr/testify/require" - "github.com/cloudwego/kitex/internal/mocks" + "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/proto" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/remote" ) var ( @@ -69,12 +68,10 @@ func Test_nextReader(t *testing.T) { func Test_readVoid(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{} tests := []struct { name string args args @@ -82,11 +79,17 @@ func Test_readVoid(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"void", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.VOID}}, descriptor.Void{}, false}, + {"void", args{t: &descriptor.TypeDescriptor{Type: descriptor.VOID, Struct: &descriptor.StructDescriptor{}}}, descriptor.Void{}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readVoid(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeVoid(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeVoid() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readVoid(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readVoid() error = %v, wantErr %v", err, tt.wantErr) return @@ -100,15 +103,9 @@ func Test_readVoid(t *testing.T) { func Test_readDouble(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadDoubleFunc: func() (value float64, err error) { - return 1.0, nil - }, - } tests := []struct { name string args args @@ -116,11 +113,17 @@ func Test_readDouble(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readDouble", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.DOUBLE}}, 1.0, false}, + {"readDouble", args{t: &descriptor.TypeDescriptor{Type: descriptor.DOUBLE}}, 1.0, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readDouble(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeFloat64(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeFloat64() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readDouble(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readDouble() error = %v, wantErr %v", err, tt.wantErr) return @@ -134,13 +137,9 @@ func Test_readDouble(t *testing.T) { func Test_readBool(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadBoolFunc: func() (bool, error) { return true, nil }, - } tests := []struct { name string args args @@ -148,11 +147,17 @@ func Test_readBool(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readBool", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.BOOL}}, true, false}, + {"readBool", args{t: &descriptor.TypeDescriptor{Type: descriptor.BOOL}}, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readBool(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeBool(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeBool() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readBool(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readBool() error = %v, wantErr %v", err, tt.wantErr) return @@ -166,15 +171,9 @@ func Test_readBool(t *testing.T) { func Test_readByte(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadByteFunc: func() (int8, error) { - return 1, nil - }, - } tests := []struct { name string args args @@ -182,11 +181,17 @@ func Test_readByte(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readByte", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.BYTE}, opt: &readerOption{}}, int8(1), false}, + {"readByte", args{t: &descriptor.TypeDescriptor{Type: descriptor.BYTE}, opt: &readerOption{}}, int8(1), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readByte(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeInt8(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeInt8() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readByte(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readByte() error = %v, wantErr %v", err, tt.wantErr) return @@ -200,15 +205,9 @@ func Test_readByte(t *testing.T) { func Test_readInt16(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadI16Func: func() (int16, error) { - return 1, nil - }, - } tests := []struct { name string args args @@ -216,11 +215,17 @@ func Test_readInt16(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readInt16", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.I16}, opt: &readerOption{}}, int16(1), false}, + {"readInt16", args{t: &descriptor.TypeDescriptor{Type: descriptor.I16}, opt: &readerOption{}}, int16(1), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readInt16(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeInt16(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeInt16() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readInt16(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readInt16() error = %v, wantErr %v", err, tt.wantErr) return @@ -234,16 +239,10 @@ func Test_readInt16(t *testing.T) { func Test_readInt32(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadI32Func: func() (int32, error) { - return 1, nil - }, - } tests := []struct { name string args args @@ -251,11 +250,17 @@ func Test_readInt32(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readInt32", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.I32}}, int32(1), false}, + {"readInt32", args{t: &descriptor.TypeDescriptor{Type: descriptor.I32}}, int32(1), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readInt32(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeInt32(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeInt32() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readInt32(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readInt32() error = %v, wantErr %v", err, tt.wantErr) return @@ -269,15 +274,9 @@ func Test_readInt32(t *testing.T) { func Test_readInt64(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadI64Func: func() (int64, error) { - return 1, nil - }, - } tests := []struct { name string args args @@ -285,11 +284,17 @@ func Test_readInt64(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readInt64", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.I64}}, int64(1), false}, + {"readInt64", args{t: &descriptor.TypeDescriptor{Type: descriptor.I64}}, int64(1), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readInt64(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeInt64(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeInt64() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readInt64(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readInt64() error = %v, wantErr %v", err, tt.wantErr) return @@ -303,18 +308,9 @@ func Test_readInt64(t *testing.T) { func Test_readString(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadStringFunc: func() (string, error) { - return stringInput, nil - }, - ReadBinaryFunc: func() ([]byte, error) { - return binaryInput, nil - }, - } tests := []struct { name string args args @@ -322,11 +318,17 @@ func Test_readString(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readString", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, stringInput, false}, + {"readString", args{t: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, stringInput, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readString(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeString(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeString() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readString(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readString() error = %v, wantErr %v", err, tt.wantErr) return @@ -340,18 +342,9 @@ func Test_readString(t *testing.T) { func Test_readBinary64String(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadStringFunc: func() (string, error) { - return stringInput, nil - }, - ReadBinaryFunc: func() ([]byte, error) { - return binaryInput, nil - }, - } tests := []struct { name string args args @@ -359,11 +352,17 @@ func Test_readBinary64String(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readBase64Binary", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Name: "binary", Type: descriptor.STRING}}, base64.StdEncoding.EncodeToString(binaryInput), false}, // read base64 string from binary field + {"readBase64Binary", args{t: &descriptor.TypeDescriptor{Name: "binary", Type: descriptor.STRING}}, base64.StdEncoding.EncodeToString(binaryInput), false}, // read base64 string from binary field } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readBase64Binary(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeBase64Binary(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeBase64Binary() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readBase64Binary(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readString() error = %v, wantErr %v", err, tt.wantErr) return @@ -377,18 +376,9 @@ func Test_readBinary64String(t *testing.T) { func Test_readBinary(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadStringFunc: func() (string, error) { - return stringInput, nil - }, - ReadBinaryFunc: func() ([]byte, error) { - return binaryInput, nil - }, - } tests := []struct { name string args args @@ -396,11 +386,17 @@ func Test_readBinary(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readBinary", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Name: "binary", Type: descriptor.STRING}}, binaryInput, false}, // read base64 string from binary field + {"readBinary", args{t: &descriptor.TypeDescriptor{Name: "binary", Type: descriptor.STRING}}, binaryInput, false}, // read base64 string from binary field } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readBinary(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeBinary(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeBinary() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readBinary(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readString() error = %v, wantErr %v", err, tt.wantErr) return @@ -414,19 +410,9 @@ func Test_readBinary(t *testing.T) { func Test_readList(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - mockTTransport := &mocks.MockThriftTTransport{ - ReadListBeginFunc: func() (elemType thrift.TType, size int, err error) { - return thrift.STRING, 3, nil - }, - - ReadStringFunc: func() (string, error) { - return stringInput, nil - }, - } tests := []struct { name string args args @@ -434,11 +420,17 @@ func Test_readList(t *testing.T) { wantErr bool }{ // TODO: Add test cases. - {"readList", args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}}}, []interface{}{stringInput, stringInput, stringInput}, false}, + {"readList", args{t: &descriptor.TypeDescriptor{Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}}}, []interface{}{stringInput, stringInput, stringInput}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readList(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeList(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeList() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readList(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readList() error = %v, wantErr %v", err, tt.wantErr) return @@ -452,60 +444,48 @@ func Test_readList(t *testing.T) { func Test_readMap(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - count := 0 - mockTTransport := &mocks.MockThriftTTransport{ - ReadMapBeginFunc: func() (keyType, valueType thrift.TType, size int, err error) { - return thrift.STRING, thrift.STRING, 1, nil - }, - ReadStringFunc: func() (string, error) { - defer func() { count++ }() - if count%2 == 0 { - return "hello", nil - } - return "world", nil - }, - } - mockTTransportWithInt16Key := &mocks.MockThriftTTransport{ - ReadMapBeginFunc: func() (keyType, valueType thrift.TType, size int, err error) { - return thrift.I16, thrift.BOOL, 1, nil - }, - ReadI16Func: func() (int16, error) { - return 16, nil - }, - } tests := []struct { name string args args + writer func(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error want interface{} wantErr bool }{ // TODO: Add test cases. { "readMap", - args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, opt: &readerOption{}}, + args{t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{}}, + writeInterfaceMap, map[interface{}]interface{}{"hello": "world"}, false, }, { "readJsonMap", - args{in: mockTTransport, t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, opt: &readerOption{forJSON: true}}, + args{t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{forJSON: true}}, + writeStringMap, map[string]interface{}{"hello": "world"}, false, }, { "readJsonMapWithInt16Key", - args{in: mockTTransportWithInt16Key, t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.I16}, Elem: &descriptor.TypeDescriptor{Type: descriptor.BOOL}}, opt: &readerOption{forJSON: true}}, + args{t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.I16}, Elem: &descriptor.TypeDescriptor{Type: descriptor.BOOL}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{forJSON: true}}, + writeStringMap, map[string]interface{}{"16": false}, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readMap(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := tt.writer(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeInterfaceMap() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readMap(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readMap() error = %v, wantErr %v", err, tt.wantErr) return @@ -519,66 +499,20 @@ func Test_readMap(t *testing.T) { func Test_readStruct(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - read := false - mockTTransport := &mocks.MockThriftTTransport{ - ReadStructBeginFunc: func() (name string, err error) { - return "Demo", nil - }, - ReadFieldBeginFunc: func() (name string, typeID thrift.TType, id int16, err error) { - if !read { - read = true - return "", thrift.STRING, 1, nil - } - return "", thrift.STOP, 0, nil - }, - ReadStringFunc: func() (string, error) { - return "world", nil - }, - } - mockTTransport2 := &mocks.MockThriftTTransport{ - ReadStructBeginFunc: func() (name string, err error) { - return "Demo", nil - }, - ReadFieldBeginFunc: func() (name string, typeID thrift.TType, id int16, err error) { - return "", thrift.STOP, 0, nil - }, - ReadStringFunc: func() (string, error) { - return "world", nil - }, - } - readError := false - mockTTransportError := &mocks.MockThriftTTransport{ - ReadStructBeginFunc: func() (name string, err error) { - return "Demo", nil - }, - ReadFieldBeginFunc: func() (name string, typeID thrift.TType, id int16, err error) { - if !readError { - readError = true - return "", thrift.LIST, 1, nil - } - return "", thrift.STOP, 0, nil - }, - ReadListBeginFunc: func() (elemType thrift.TType, size int, err error) { - return thrift.STRING, 1, nil - }, - ReadStringFunc: func() (string, error) { - return "123", errors.New("need STRING type, but got: I64") - }, - } tests := []struct { name string args args + input interface{} want interface{} wantErr bool }{ // TODO: Add test cases. { "readStruct with setFieldsForEmptyStruct", - args{in: mockTTransport2, t: &descriptor.TypeDescriptor{ + args{t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Struct: &descriptor.StructDescriptor{ FieldsByID: map[int32]*descriptor.FieldDescriptor{ @@ -613,11 +547,26 @@ func Test_readStruct(t *testing.T) { "list-list": [][][]byte(nil), "map-list": map[string][][]byte(nil), }, + map[string]interface{}{ + "string": "", + "byte": byte(0), + "i8": int8(0), + "i16": int16(0), + "i32": int32(0), + "i64": int64(0), + "double": float64(0), + "list": [][]byte(nil), + "set": []bool(nil), + "map": map[string]float64(nil), + "struct": map[string]interface{}(nil), + "list-list": [][][]byte(nil), + "map-list": map[string][][]byte(nil), + }, false, }, { "readStruct with setFieldsForEmptyStruct optional", - args{in: mockTTransport2, t: &descriptor.TypeDescriptor{ + args{t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Struct: &descriptor.StructDescriptor{ FieldsByID: map[int32]*descriptor.FieldDescriptor{ @@ -646,38 +595,63 @@ func Test_readStruct(t *testing.T) { "set": []int16(nil), "map": map[int32]int64(nil), }, + map[string]interface{}{ + "string": "", + "i8": int8(0), + "i16": int16(0), + "i32": int32(0), + "i64": int64(0), + "double": float64(0), + "list": []int8(nil), + "set": []int16(nil), + "map": map[int32]int64(nil), + }, false, }, { "readStruct", - args{in: mockTTransport, t: &descriptor.TypeDescriptor{ + args{t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Struct: &descriptor.StructDescriptor{ FieldsByID: map[int32]*descriptor.FieldDescriptor{ 1: {Name: "hello", Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, }, + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": {Name: "hello", ID: 1, Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, + }, }, }}, map[string]interface{}{"hello": "world"}, + map[string]interface{}{"hello": "world"}, false, }, { "readStructError", - args{in: mockTTransportError, t: &descriptor.TypeDescriptor{ + args{t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Struct: &descriptor.StructDescriptor{ FieldsByID: map[int32]*descriptor.FieldDescriptor{ - 1: {Name: "strList", Type: &descriptor.TypeDescriptor{Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}}}, + 1: {Name: "strList", Type: &descriptor.TypeDescriptor{Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I64}}}, + }, + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "strList": {Name: "strList", ID: 1, Type: &descriptor.TypeDescriptor{Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}}}, }, }, }}, + map[string]interface{}{}, nil, true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readStruct(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + err := writeStruct(context.Background(), tt.input, w, tt.args.t, &writerOption{}) + if err != nil { + t.Errorf("writeStruct() error = %v", err) + } + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readStruct(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readStruct() error = %v, wantErr %v", err, tt.wantErr) return @@ -689,26 +663,9 @@ func Test_readStruct(t *testing.T) { func Test_readHTTPResponse(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - read := false - mockTTransport := &mocks.MockThriftTTransport{ - ReadStructBeginFunc: func() (name string, err error) { - return "Demo", nil - }, - ReadFieldBeginFunc: func() (name string, typeID thrift.TType, id int16, err error) { - if !read { - read = true - return "", thrift.STRING, 1, nil - } - return "", thrift.STOP, 0, nil - }, - ReadStringFunc: func() (string, error) { - return "world", nil - }, - } resp := descriptor.NewHTTPResponse() resp.Body = map[string]interface{}{"hello": "world"} tests := []struct { @@ -720,7 +677,7 @@ func Test_readHTTPResponse(t *testing.T) { // TODO: Add test cases. { "readHTTPResponse", - args{in: mockTTransport, t: &descriptor.TypeDescriptor{ + args{t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Struct: &descriptor.StructDescriptor{ FieldsByID: map[int32]*descriptor.FieldDescriptor{ @@ -738,7 +695,12 @@ func Test_readHTTPResponse(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readHTTPResponse(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + w.WriteFieldBegin(thrift.TType(descriptor.STRING), 1) + w.WriteString("world") + w.WriteFieldStop() + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readHTTPResponse(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readHTTPResponse() error = %v, wantErr %v", err, tt.wantErr) return @@ -752,26 +714,9 @@ func Test_readHTTPResponse(t *testing.T) { func Test_readHTTPResponseWithPbBody(t *testing.T) { type args struct { - in thrift.TProtocol t *descriptor.TypeDescriptor opt *readerOption } - read := false - mockTTransport := &mocks.MockThriftTTransport{ - ReadStructBeginFunc: func() (name string, err error) { - return "BizResp", nil - }, - ReadFieldBeginFunc: func() (name string, typeID thrift.TType, id int16, err error) { - if !read { - read = true - return "", thrift.STRING, 1, nil - } - return "", thrift.STOP, 0, nil - }, - ReadStringFunc: func() (string, error) { - return "hello world", nil - }, - } desc, err := getRespPbDesc() if err != nil { t.Error(err) @@ -786,7 +731,7 @@ func Test_readHTTPResponseWithPbBody(t *testing.T) { // TODO: Add test cases. { "readHTTPResponse", - args{in: mockTTransport, t: &descriptor.TypeDescriptor{ + args{t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Struct: &descriptor.StructDescriptor{ FieldsByID: map[int32]*descriptor.FieldDescriptor{ @@ -807,7 +752,12 @@ func Test_readHTTPResponseWithPbBody(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := readHTTPResponse(context.Background(), tt.args.in, tt.args.t, tt.args.opt) + w := thrift.NewBinaryWriter() + w.WriteFieldBegin(thrift.TType(descriptor.STRING), 1) + w.WriteString("hello world") + w.WriteFieldStop() + in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + got, err := readHTTPResponse(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readHTTPResponse() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/generic/thrift/struct.go b/internal/generic/thrift/struct.go similarity index 79% rename from pkg/generic/thrift/struct.go rename to internal/generic/thrift/struct.go index d8870436f0..eaea4bfe08 100644 --- a/pkg/generic/thrift/struct.go +++ b/internal/generic/thrift/struct.go @@ -18,9 +18,12 @@ package thrift import ( "context" + "io" + + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/cloudwego/kitex/pkg/generic/descriptor" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) type StructReaderWriter struct { @@ -56,7 +59,7 @@ func (m *WriteStruct) SetBinaryWithBase64(enable bool) { } // Write ... -func (m *WriteStruct) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error { +func (m *WriteStruct) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { fnDsc, err := m.svcDsc.LookupFunctionByMethod(method) if err != nil { return err @@ -70,7 +73,12 @@ func (m *WriteStruct) Write(ctx context.Context, out thrift.TProtocol, msg inter if !hasRequestBase { requestBase = nil } - return wrapStructWriter(ctx, msg, out, ty, &writerOption{requestBase: requestBase, binaryWithBase64: m.binaryWithBase64}) + binaryWriter := thrift.NewBinaryWriter() + if err = wrapStructWriter(ctx, msg, binaryWriter, ty, &writerOption{requestBase: requestBase, binaryWithBase64: m.binaryWithBase64}); err != nil { + return err + } + _, err = out.Write(binaryWriter.Bytes()) + return err } // NewReadStruct ... @@ -113,7 +121,7 @@ func (m *ReadStruct) SetSetFieldsForEmptyStruct(mode uint8) { } // Read ... -func (m *ReadStruct) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { +func (m *ReadStruct) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { fnDsc, err := m.svc.LookupFunctionByMethod(method) if err != nil { return nil, err @@ -122,5 +130,5 @@ func (m *ReadStruct) Read(ctx context.Context, method string, isClient bool, dat if !isClient { fDsc = fnDsc.Request } - return skipStructReader(ctx, in, fDsc, &readerOption{throwException: true, forJSON: m.forJSON, binaryWithBase64: m.binaryWithBase64, binaryWithByteSlice: m.binaryWithByteSlice, setFieldsForEmptyStruct: m.setFieldsForEmptyStruct}) + return skipStructReader(ctx, thrift.NewBinaryReader(in), fDsc, &readerOption{throwException: true, forJSON: m.forJSON, binaryWithBase64: m.binaryWithBase64, binaryWithByteSlice: m.binaryWithByteSlice, setFieldsForEmptyStruct: m.setFieldsForEmptyStruct}) } diff --git a/pkg/generic/thrift/thrift.go b/internal/generic/thrift/thrift.go similarity index 80% rename from pkg/generic/thrift/thrift.go rename to internal/generic/thrift/thrift.go index 6fca7568d7..862691642c 100644 --- a/pkg/generic/thrift/thrift.go +++ b/internal/generic/thrift/thrift.go @@ -19,8 +19,9 @@ package thrift import ( "context" + "io" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift/base" ) const ( @@ -29,10 +30,10 @@ const ( // MessageReader read from thrift.TProtocol with method type MessageReader interface { - Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) + Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) } // MessageWriter write to thrift.TProtocol type MessageWriter interface { - Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *Base) error + Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error } diff --git a/pkg/generic/thrift/util.go b/internal/generic/thrift/util.go similarity index 98% rename from pkg/generic/thrift/util.go rename to internal/generic/thrift/util.go index 163ad70766..5a5b020ba2 100644 --- a/pkg/generic/thrift/util.go +++ b/internal/generic/thrift/util.go @@ -33,7 +33,7 @@ func assertType(expected, but descriptor.Type) error { return fmt.Errorf("need %s type, but got: %s", expected, but) } -func splitType(t string) (pkg, name string) { +func SplitType(t string) (pkg, name string) { idx := strings.LastIndex(t, ".") if idx == -1 { return "", t diff --git a/pkg/generic/thrift/util_test.go b/internal/generic/thrift/util_test.go similarity index 95% rename from pkg/generic/thrift/util_test.go rename to internal/generic/thrift/util_test.go index 56ea0e594d..313bbaf912 100644 --- a/pkg/generic/thrift/util_test.go +++ b/internal/generic/thrift/util_test.go @@ -26,19 +26,19 @@ import ( ) func TestSplitType(t *testing.T) { - pkg, name := splitType(".A") + pkg, name := SplitType(".A") test.Assert(t, pkg == "") test.Assert(t, name == "A") - pkg, name = splitType("foo.bar.A") + pkg, name = SplitType("foo.bar.A") test.Assert(t, pkg == "foo.bar") test.Assert(t, name == "A") - pkg, name = splitType("A") + pkg, name = SplitType("A") test.Assert(t, pkg == "") test.Assert(t, name == "A") - pkg, name = splitType("") + pkg, name = SplitType("") test.Assert(t, pkg == "") test.Assert(t, name == "") } diff --git a/pkg/generic/thrift/write.go b/internal/generic/thrift/write.go similarity index 66% rename from pkg/generic/thrift/write.go rename to internal/generic/thrift/write.go index e7bcdc9bd0..bb2c27a9aa 100644 --- a/pkg/generic/thrift/write.go +++ b/internal/generic/thrift/write.go @@ -22,21 +22,22 @@ import ( "encoding/json" "fmt" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/tidwall/gjson" + "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/proto" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) type writerOption struct { - requestBase *Base // request base from metahandler + requestBase *base.Base // request base from metahandler // decoding Base64 to binary binaryWithBase64 bool } -type writer func(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error +type writer func(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error type fieldGetter func(val interface{}, field *descriptor.FieldDescriptor) (interface{}, bool) @@ -192,8 +193,8 @@ func nextWriter(sample interface{}, t *descriptor.TypeDescriptor, opt *writerOpt if err != nil { return nil, err } - if t.Type == thrift.SET && tt == thrift.LIST { - tt = thrift.SET + if t.Type == descriptor.SET && tt == descriptor.LIST { + tt = descriptor.SET } return fn, assertType(t.Type, tt) } @@ -206,54 +207,50 @@ func nextJSONWriter(data *gjson.Result, t *descriptor.TypeDescriptor, opt *write return v, fn, nil } -func writeEmptyValue(out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeEmptyValue(out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { switch t.Type { case descriptor.BOOL: - return out.WriteBool(false) + out.WriteBool(false) + return nil case descriptor.I08: - return out.WriteByte(0) + out.WriteByte(0) + return nil case descriptor.I16: - return out.WriteI16(0) + out.WriteI16(0) + return nil case descriptor.I32: - return out.WriteI32(0) + out.WriteI32(0) + return nil case descriptor.I64: - return out.WriteI64(0) + out.WriteI64(0) + return nil case descriptor.DOUBLE: - return out.WriteDouble(0) + out.WriteDouble(0) + return nil case descriptor.STRING: if t.Name == "binary" && opt.binaryWithBase64 { - return out.WriteBinary([]byte{}) + out.WriteBinary([]byte{}) } else { - return out.WriteString("") + out.WriteString("") } + return nil case descriptor.LIST, descriptor.SET: - if err := out.WriteListBegin(t.Elem.Type.ToThriftTType(), 0); err != nil { - return err - } - return out.WriteListEnd() + out.WriteListBegin(t.Elem.Type.ToThriftTType(), 0) + return nil case descriptor.MAP: - if err := out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), 0); err != nil { - return err - } - return out.WriteMapEnd() + out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), 0) + return nil case descriptor.STRUCT: - if err := out.WriteStructBegin(t.Name); err != nil { - return err - } - if err := out.WriteFieldStop(); err != nil { - return err - } - return out.WriteStructEnd() + out.WriteFieldStop() + return nil case descriptor.VOID: return nil } return fmt.Errorf("unsupported type:%T", t) } -func wrapStructWriter(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { - if err := out.WriteStructBegin(t.Struct.Name); err != nil { - return err - } +// TODO(marina.sakai): Optimize generic struct writer +func wrapStructWriter(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { for name, field := range t.Struct.FieldsByName { if field.IsException { // generic server ignore the exception, because no description for exception @@ -261,9 +258,7 @@ func wrapStructWriter(ctx context.Context, val interface{}, out thrift.TProtocol continue } if val != nil { - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) writer, err := nextWriter(val, field.Type, opt) if err != nil { return fmt.Errorf("nextWriter of field[%s] error %w", name, err) @@ -271,30 +266,21 @@ func wrapStructWriter(ctx context.Context, val interface{}, out thrift.TProtocol if err := writer(ctx, val, out, field.Type, opt); err != nil { return fmt.Errorf("writer of field[%s] error %w", name, err) } - if err := out.WriteFieldEnd(); err != nil { - return err - } } } - if err := out.WriteFieldStop(); err != nil { - return err - } - return out.WriteStructEnd() + out.WriteFieldStop() + return nil } -func wrapJSONWriter(ctx context.Context, val *gjson.Result, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { - if err := out.WriteStructBegin(t.Struct.Name); err != nil { - return err - } +// TODO(marina.sakai): Optimize generic json writer +func wrapJSONWriter(ctx context.Context, val *gjson.Result, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { for name, field := range t.Struct.FieldsByName { if field.IsException { // generic server ignore the exception, because no description for exception // generic handler just return error continue } - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) v, writer, err := nextJSONWriter(val, field.Type, opt) if err != nil { return fmt.Errorf("nextJSONWriter of field[%s] error %w", name, err) @@ -302,25 +288,21 @@ func wrapJSONWriter(ctx context.Context, val *gjson.Result, out thrift.TProtocol if err := writer(ctx, v, out, field.Type, opt); err != nil { return fmt.Errorf("writer of field[%s] error %w", name, err) } - if err := out.WriteFieldEnd(); err != nil { - return err - } - } - if err := out.WriteFieldStop(); err != nil { - return err } - return out.WriteStructEnd() + out.WriteFieldStop() + return nil } -func writeVoid(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeVoid(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { return writeStruct(ctx, map[string]interface{}{}, out, t, opt) } -func writeBool(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { - return out.WriteBool(val.(bool)) +func writeBool(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { + out.WriteBool(val.(bool)) + return nil } -func writeInt8(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInt8(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { var i int8 switch val := val.(type) { case int8: @@ -332,112 +314,128 @@ func writeInt8(ctx context.Context, val interface{}, out thrift.TProtocol, t *de } // compatible with lossless conversion switch t.Type { - case thrift.I08: - return out.WriteByte(i) - case thrift.I16: - return out.WriteI16(int16(i)) - case thrift.I32: - return out.WriteI32(int32(i)) - case thrift.I64: - return out.WriteI64(int64(i)) + case descriptor.I08: + out.WriteByte(i) + return nil + case descriptor.I16: + out.WriteI16(int16(i)) + return nil + case descriptor.I32: + out.WriteI32(int32(i)) + return nil + case descriptor.I64: + out.WriteI64(int64(i)) + return nil } return fmt.Errorf("need int type, but got: %s", t.Type) } -func writeInt16(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInt16(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { // compatible with lossless conversion i := val.(int16) switch t.Type { - case thrift.I08: + case descriptor.I08: if i&0xff != i { return fmt.Errorf("value is beyond range of i8: %v", i) } - return out.WriteByte(int8(i)) - case thrift.I16: - return out.WriteI16(i) - case thrift.I32: - return out.WriteI32(int32(i)) - case thrift.I64: - return out.WriteI64(int64(i)) + out.WriteByte(int8(i)) + return nil + case descriptor.I16: + out.WriteI16(i) + return nil + case descriptor.I32: + out.WriteI32(int32(i)) + return nil + case descriptor.I64: + out.WriteI64(int64(i)) + return nil } return fmt.Errorf("need int type, but got: %s", t.Type) } -func writeInt32(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInt32(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { // compatible with lossless conversion i := val.(int32) switch t.Type { - case thrift.I08: + case descriptor.I08: if i&0xff != i { return fmt.Errorf("value is beyond range of i8: %v", i) } - return out.WriteByte(int8(i)) - case thrift.I16: + out.WriteByte(int8(i)) + return nil + case descriptor.I16: if i&0xffff != i { return fmt.Errorf("value is beyond range of i16: %v", i) } - return out.WriteI16(int16(i)) - case thrift.I32: - return out.WriteI32(i) - case thrift.I64: - return out.WriteI64(int64(i)) + out.WriteI16(int16(i)) + return nil + case descriptor.I32: + out.WriteI32(i) + return nil + case descriptor.I64: + out.WriteI64(int64(i)) + return nil } return fmt.Errorf("need int type, but got: %s", t.Type) } -func writeInt64(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInt64(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { // compatible with lossless conversion i := val.(int64) switch t.Type { - case thrift.I08: + case descriptor.I08: if i&0xff != i { return fmt.Errorf("value is beyond range of i8: %v", i) } - return out.WriteByte(int8(i)) - case thrift.I16: + out.WriteByte(int8(i)) + return nil + case descriptor.I16: if i&0xffff != i { return fmt.Errorf("value is beyond range of i16: %v", i) } - return out.WriteI16(int16(i)) - case thrift.I32: + out.WriteI16(int16(i)) + return nil + case descriptor.I32: if i&0xffffffff != i { return fmt.Errorf("value is beyond range of i32: %v", i) } - return out.WriteI32(int32(i)) - case thrift.I64: - return out.WriteI64(i) + out.WriteI32(int32(i)) + return nil + case descriptor.I64: + out.WriteI64(i) + return nil } return fmt.Errorf("need int type, but got: %s", t.Type) } -func writeJSONNumber(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeJSONNumber(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { jn := val.(json.Number) switch t.Type { - case thrift.I08: + case descriptor.I08: i, err := jn.Int64() if err != nil { return err } return writeInt8(ctx, int8(i), out, t, opt) - case thrift.I16: + case descriptor.I16: i, err := jn.Int64() if err != nil { return err } return writeInt16(ctx, int16(i), out, t, opt) - case thrift.I32: + case descriptor.I32: i, err := jn.Int64() if err != nil { return err } return writeInt32(ctx, int32(i), out, t, opt) - case thrift.I64: + case descriptor.I64: i, err := jn.Int64() if err != nil { return err } return writeInt64(ctx, i, out, t, opt) - case thrift.DOUBLE: + case descriptor.DOUBLE: i, err := jn.Float64() if err != nil { return err @@ -447,65 +445,63 @@ func writeJSONNumber(ctx context.Context, val interface{}, out thrift.TProtocol, return nil } -func writeJSONFloat64(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeJSONFloat64(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { i := val.(float64) switch t.Type { - case thrift.I08: + case descriptor.I08: return writeInt8(ctx, int8(i), out, t, opt) - case thrift.I16: + case descriptor.I16: return writeInt16(ctx, int16(i), out, t, opt) - case thrift.I32: + case descriptor.I32: return writeInt32(ctx, int32(i), out, t, opt) - case thrift.I64: + case descriptor.I64: return writeInt64(ctx, int64(i), out, t, opt) - case thrift.DOUBLE: + case descriptor.DOUBLE: return writeFloat64(ctx, i, out, t, opt) } return fmt.Errorf("need number type, but got: %s", t.Type) } -func writeFloat64(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { - return out.WriteDouble(val.(float64)) +func writeFloat64(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { + out.WriteDouble(val.(float64)) + return nil } -func writeString(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { - return out.WriteString(val.(string)) +func writeString(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { + out.WriteString(val.(string)) + return nil } -func writeBase64Binary(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeBase64Binary(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { bytes, err := base64.StdEncoding.DecodeString(val.(string)) if err != nil { return err } - return out.WriteBinary(bytes) + out.WriteBinary(bytes) + return nil } -func writeBinary(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { - return out.WriteBinary(val.([]byte)) +func writeBinary(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { + out.WriteBinary(val.([]byte)) + return nil } -func writeBinaryList(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeBinaryList(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { l := val.([]byte) length := len(l) - if err := out.WriteListBegin(t.Elem.Type.ToThriftTType(), length); err != nil { - return err - } + out.WriteListBegin(t.Elem.Type.ToThriftTType(), length) for _, b := range l { - if err := out.WriteByte(int8(b)); err != nil { - return err - } + out.WriteByte(int8(b)) } - return out.WriteListEnd() + return nil } -func writeList(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeList(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { l := val.([]interface{}) length := len(l) - if err := out.WriteListBegin(t.Elem.Type.ToThriftTType(), length); err != nil { - return err - } + out.WriteListBegin(t.Elem.Type.ToThriftTType(), length) if length == 0 { - return out.WriteListEnd() + return nil } var ( writer writer @@ -527,17 +523,15 @@ func writeList(ctx context.Context, val interface{}, out thrift.TProtocol, t *de } } } - return out.WriteListEnd() + return nil } -func writeJSONList(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeJSONList(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { l := val.([]gjson.Result) length := len(l) - if err := out.WriteListBegin(t.Elem.Type.ToThriftTType(), length); err != nil { - return err - } + out.WriteListBegin(t.Elem.Type.ToThriftTType(), length) if length == 0 { - return out.WriteListEnd() + return nil } for _, elem := range l { v, writer, err := nextJSONWriter(&elem, t.Elem, opt) @@ -548,17 +542,15 @@ func writeJSONList(ctx context.Context, val interface{}, out thrift.TProtocol, t return err } } - return out.WriteListEnd() + return nil } -func writeInterfaceMap(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInterfaceMap(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { m := val.(map[interface{}]interface{}) length := len(m) - if err := out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length); err != nil { - return err - } + out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length) if length == 0 { - return out.WriteMapEnd() + return nil } var ( keyWriter writer @@ -589,17 +581,15 @@ func writeInterfaceMap(ctx context.Context, val interface{}, out thrift.TProtoco } } } - return out.WriteMapEnd() + return nil } -func writeStringMap(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeStringMap(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { m := val.(map[string]interface{}) length := len(m) - if err := out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length); err != nil { - return err - } + out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length) if length == 0 { - return out.WriteMapEnd() + return nil } var ( @@ -634,17 +624,15 @@ func writeStringMap(ctx context.Context, val interface{}, out thrift.TProtocol, } } } - return out.WriteMapEnd() + return nil } -func writeStringJSONMap(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeStringJSONMap(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { m := val.(map[string]gjson.Result) length := len(m) - if err := out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length); err != nil { - return err - } + out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length) if length == 0 { - return out.WriteMapEnd() + return nil } var ( @@ -673,10 +661,10 @@ func writeStringJSONMap(ctx context.Context, val interface{}, out thrift.TProtoc return err } } - return out.WriteMapEnd() + return nil } -func writeRequestBase(ctx context.Context, val interface{}, out thrift.TProtocol, field *descriptor.FieldDescriptor, opt *writerOption) error { +func writeRequestBase(ctx context.Context, val interface{}, out *thrift.BinaryWriter, field *descriptor.FieldDescriptor, opt *writerOption) error { if st, ok := val.(map[string]interface{}); ok { // copy from user's Extra if ext, ok := st["Extra"]; ok { @@ -710,17 +698,18 @@ func writeRequestBase(ctx context.Context, val interface{}, out thrift.TProtocol } } } - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } - if err := opt.requestBase.Write(out); err != nil { - return err + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + sz := opt.requestBase.BLength() + buf := make([]byte, sz) + opt.requestBase.FastWrite(buf) + for _, b := range buf { + out.WriteByte(int8(b)) } - return out.WriteFieldEnd() + return nil } // writeStruct iter with Descriptor, can check the field's required and others -func writeStruct(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeStruct(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { var fg fieldGetter switch val.(type) { case map[string]interface{}: @@ -729,10 +718,7 @@ func writeStruct(ctx context.Context, val interface{}, out thrift.TProtocol, t * fg = pbGetter } - err := out.WriteStructBegin(t.Struct.Name) - if err != nil { - return err - } + var err error for name, field := range t.Struct.FieldsByName { elem, ok := fg(val, field) if field.Type.IsRequestBase && opt.requestBase != nil { @@ -746,15 +732,10 @@ func writeStruct(ctx context.Context, val interface{}, out thrift.TProtocol, t * if elem == nil || !ok { if !field.Optional { // empty fields don't need value-mapping here, since writeEmptyValue decides zero value based on Thrift type - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) if err := writeEmptyValue(out, field.Type, opt); err != nil { return fmt.Errorf("field (%d/%s) error: %w", field.ID, name, err) } - if err := out.WriteFieldEnd(); err != nil { - return err - } } } else { // normal fields if field.ValueMapping != nil { @@ -763,9 +744,7 @@ func writeStruct(ctx context.Context, val interface{}, out thrift.TProtocol, t * return err } } - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) writer, err := nextWriter(elem, field.Type, opt) if err != nil { return fmt.Errorf("nextWriter of field[%s] error %w", name, err) @@ -773,28 +752,20 @@ func writeStruct(ctx context.Context, val interface{}, out thrift.TProtocol, t * if err := writer(ctx, elem, out, field.Type, opt); err != nil { return fmt.Errorf("writer of field[%s] error %w", name, err) } - if err := out.WriteFieldEnd(); err != nil { - return err - } } } - if err := out.WriteFieldStop(); err != nil { - return err - } - return out.WriteStructEnd() + out.WriteFieldStop() + return nil } -func writeHTTPRequest(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { req := val.(*descriptor.HTTPRequest) defer func() { if req.Params != nil { req.Params.Recycle() } }() - if err := out.WriteStructBegin(t.Struct.Name); err != nil { - return err - } for name, field := range t.Struct.FieldsByName { v, err := requestMappingValue(ctx, req, field) if err != nil { @@ -809,15 +780,10 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out thrift.TProtocol if v == nil { if !field.Optional { - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) if err := writeEmptyValue(out, field.Type, opt); err != nil { return fmt.Errorf("field (%d/%s) error: %w", field.ID, name, err) } - if err := out.WriteFieldEnd(); err != nil { - return err - } } } else { if field.ValueMapping != nil { @@ -826,9 +792,7 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out thrift.TProtocol return err } } - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) writer, err := nextWriter(v, field.Type, opt) if err != nil { return fmt.Errorf("nextWriter of field[%s] error %w", name, err) @@ -836,24 +800,15 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out thrift.TProtocol if err := writer(ctx, v, out, field.Type, opt); err != nil { return fmt.Errorf("writer of field[%s] error %w", name, err) } - if err := out.WriteFieldEnd(); err != nil { - return err - } } } - if err := out.WriteFieldStop(); err != nil { - return err - } - return out.WriteStructEnd() + out.WriteFieldStop() + return nil } -func writeJSON(ctx context.Context, val interface{}, out thrift.TProtocol, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeJSON(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { data := val.(*gjson.Result) - err := out.WriteStructBegin(t.Struct.Name) - if err != nil { - return err - } for name, field := range t.Struct.FieldsByName { elem := data.Get(name) if field.Type.IsRequestBase && opt.requestBase != nil { @@ -866,35 +821,23 @@ func writeJSON(ctx context.Context, val interface{}, out thrift.TProtocol, t *de if elem.Type == gjson.Null { if !field.Optional { - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) if err := writeEmptyValue(out, field.Type, opt); err != nil { return fmt.Errorf("field (%d/%s) error: %w", field.ID, name, err) } - if err := out.WriteFieldEnd(); err != nil { - return err - } } } else { v, writer, err := nextJSONWriter(&elem, field.Type, opt) if err != nil { return fmt.Errorf("nextWriter of field[%s] error %w", name, err) } - if err := out.WriteFieldBegin(field.Name, field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { - return err - } + out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) if err := writer(ctx, v, out, field.Type, opt); err != nil { return fmt.Errorf("writer of field[%s] error %w", name, err) } - if err := out.WriteFieldEnd(); err != nil { - return err - } } } - if err := out.WriteFieldStop(); err != nil { - return err - } - return out.WriteStructEnd() + out.WriteFieldStop() + return nil } diff --git a/pkg/generic/thrift/write_test.go b/internal/generic/thrift/write_test.go similarity index 71% rename from pkg/generic/thrift/write_test.go rename to internal/generic/thrift/write_test.go index d71e40c444..9545fe66c5 100644 --- a/pkg/generic/thrift/write_test.go +++ b/internal/generic/thrift/write_test.go @@ -17,78 +17,29 @@ package thrift import ( - "bytes" "context" "encoding/base64" "encoding/json" - "errors" - "fmt" - "reflect" "testing" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/jhump/protoreflect/desc/protoparse" "github.com/tidwall/gjson" - "github.com/cloudwego/kitex/internal/mocks" + "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/proto" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) func Test_nextWriter(t *testing.T) { // add some testcases type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := func(v interface{}) *mocks.MockThriftTTransport { - toint := func(i interface{}) int64 { - switch ti := i.(type) { - case int8: - return int64(ti) - case int16: - return int64(ti) - case int32: - return int64(ti) - case int64: - return int64(ti) - case int: - return int64(ti) - default: - t.Errorf("type %v not support toint", reflect.TypeOf(v)) - } - return 0 - } - return &mocks.MockThriftTTransport{ - WriteByteFunc: func(val int8) error { - test.Assert(t, val == int8(toint(v))) - return nil - }, - WriteI16Func: func(val int16) error { - test.Assert(t, val == int16(toint(v))) - return nil - }, - WriteI32Func: func(val int32) error { - test.Assert(t, val == int32(toint(v))) - return nil - }, - WriteI64Func: func(val int64) error { - test.Assert(t, val == toint(v)) - return nil - }, - WriteDoubleFunc: func(val float64) error { - test.Assert(t, val == v.(float64)) - return nil - }, - WriteBoolFunc: func(val bool) error { - test.Assert(t, val == v.(bool)) - return nil - }, - } - } tests := []struct { name string @@ -100,13 +51,13 @@ func Test_nextWriter(t *testing.T) { "nextWriteri8 Success", args{ val: int8(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -116,13 +67,13 @@ func Test_nextWriter(t *testing.T) { "nextWriteri16 Success", args{ val: int16(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -132,13 +83,13 @@ func Test_nextWriter(t *testing.T) { "nextWriteri32 Success", args{ val: int32(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -148,13 +99,13 @@ func Test_nextWriter(t *testing.T) { "nextWriteri64 Success", args{ val: int64(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -164,13 +115,13 @@ func Test_nextWriter(t *testing.T) { "nextWriterbool Success", args{ val: true, - out: mockTTransport(true), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.BOOL, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -180,13 +131,13 @@ func Test_nextWriter(t *testing.T) { "nextWriterdouble Success", args{ val: float64(1.0), - out: mockTTransport(float64(1.0)), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -196,13 +147,13 @@ func Test_nextWriter(t *testing.T) { "nextWriteri8 Failed", args{ val: 10000000, - out: mockTTransport(10000000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -212,13 +163,13 @@ func Test_nextWriter(t *testing.T) { "nextWriteri16 Failed", args{ val: 10000000, - out: mockTTransport(10000000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -228,13 +179,13 @@ func Test_nextWriter(t *testing.T) { "nextWriteri32 Failed", args{ val: 10000000, - out: mockTTransport(10000000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -244,13 +195,13 @@ func Test_nextWriter(t *testing.T) { "nextWriteri64 Failed", args{ val: "10000000", - out: mockTTransport(10000000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, }, opt: &writerOption{ - requestBase: &Base{}, + requestBase: &base.Base{}, binaryWithBase64: false, }, }, @@ -281,16 +232,10 @@ func Test_nextWriter(t *testing.T) { func Test_writeVoid(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteStructBeginFunc: func(name string) error { - test.Assert(t, name == "") - return nil - }, - } tests := []struct { name string @@ -302,7 +247,7 @@ func Test_writeVoid(t *testing.T) { "writeVoid", args{ val: 1, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.VOID, Struct: &descriptor.StructDescriptor{}, @@ -323,16 +268,10 @@ func Test_writeVoid(t *testing.T) { func Test_writeBool(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteBoolFunc: func(val bool) error { - test.Assert(t, val) - return nil - }, - } tests := []struct { name string args args @@ -343,7 +282,7 @@ func Test_writeBool(t *testing.T) { "writeBool", args{ val: true, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.BOOL, Struct: &descriptor.StructDescriptor{}, @@ -364,22 +303,10 @@ func Test_writeBool(t *testing.T) { func Test_writeInt8(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := func(v int8) *mocks.MockThriftTTransport { - return &mocks.MockThriftTTransport{ - WriteByteFunc: func(val int8) error { - test.Assert(t, val == v) - return nil - }, - WriteI16Func: func(val int16) error { - test.Assert(t, val == int16(v)) - return nil - }, - } - } tests := []struct { name string args args @@ -390,7 +317,7 @@ func Test_writeInt8(t *testing.T) { "writeInt8", args{ val: int8(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -401,8 +328,8 @@ func Test_writeInt8(t *testing.T) { { name: "writeInt8 byte", args: args{ - val: byte(128), - out: mockTTransport(-128), // overflow + val: byte(128), // overflow + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -414,7 +341,7 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 error", args: args{ val: int16(2), - out: mockTTransport(2), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -426,7 +353,7 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 to i16", args: args{ val: int8(2), - out: mockTTransport(2), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -438,7 +365,7 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 to i32", args: args{ val: int8(2), - out: mockTTransport(2), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -450,7 +377,7 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 to i64", args: args{ val: int8(2), - out: mockTTransport(2), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -462,7 +389,7 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 to i64", args: args{ val: int8(2), - out: mockTTransport(2), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -483,16 +410,10 @@ func Test_writeInt8(t *testing.T) { func Test_writeJSONNumber(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteByteFunc: func(val int8) error { - test.Assert(t, val == 1) - return nil - }, - } tests := []struct { name string args args @@ -503,7 +424,7 @@ func Test_writeJSONNumber(t *testing.T) { "writeJSONNumber", args{ val: json.Number("1"), - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -524,16 +445,10 @@ func Test_writeJSONNumber(t *testing.T) { func Test_writeJSONFloat64(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteByteFunc: func(val int8) error { - test.Assert(t, val == 1) - return nil - }, - } tests := []struct { name string args args @@ -544,7 +459,7 @@ func Test_writeJSONFloat64(t *testing.T) { "writeJSONFloat64", args{ val: 1.0, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -556,7 +471,7 @@ func Test_writeJSONFloat64(t *testing.T) { "writeJSONFloat64 bool Failed", args{ val: 1.0, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.BOOL, Struct: &descriptor.StructDescriptor{}, @@ -577,26 +492,10 @@ func Test_writeJSONFloat64(t *testing.T) { func Test_writeInt16(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := func(v int64) *mocks.MockThriftTTransport { - return &mocks.MockThriftTTransport{ - WriteI32Func: func(val int32) error { - test.Assert(t, val == int32(v)) - return nil - }, - WriteI16Func: func(val int16) error { - test.Assert(t, val == int16(v)) - return nil - }, - WriteByteFunc: func(val int8) error { - test.Assert(t, val == int8(v)) - return nil - }, - } - } tests := []struct { name string args args @@ -607,7 +506,7 @@ func Test_writeInt16(t *testing.T) { "writeInt16", args{ val: int16(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -619,7 +518,7 @@ func Test_writeInt16(t *testing.T) { "writeInt16toInt8 Success", args{ val: int16(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -631,7 +530,7 @@ func Test_writeInt16(t *testing.T) { "writeInt16toInt8 Failed", args{ val: int16(10000), - out: mockTTransport(10000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -643,7 +542,7 @@ func Test_writeInt16(t *testing.T) { "writeInt16toInt32 Success", args{ val: int16(10000), - out: mockTTransport(10000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -655,7 +554,7 @@ func Test_writeInt16(t *testing.T) { "writeInt16toInt64 Success", args{ val: int16(10000), - out: mockTTransport(10000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -667,7 +566,7 @@ func Test_writeInt16(t *testing.T) { "writeInt16 Failed", args{ val: int16(10000), - out: mockTTransport(10000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -688,30 +587,10 @@ func Test_writeInt16(t *testing.T) { func Test_writeInt32(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := func(v int64) *mocks.MockThriftTTransport { - return &mocks.MockThriftTTransport{ - WriteI64Func: func(val int64) error { - test.Assert(t, val == v) - return nil - }, - WriteI32Func: func(val int32) error { - test.Assert(t, val == int32(v)) - return nil - }, - WriteI16Func: func(val int16) error { - test.Assert(t, val == int16(v)) - return nil - }, - WriteByteFunc: func(val int8) error { - test.Assert(t, val == int8(v)) - return nil - }, - } - } tests := []struct { name string @@ -723,7 +602,7 @@ func Test_writeInt32(t *testing.T) { "writeInt32 Success", args{ val: int32(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -735,7 +614,7 @@ func Test_writeInt32(t *testing.T) { "writeInt32 Failed", args{ val: int32(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -747,7 +626,7 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt8 Success", args{ val: int32(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -759,7 +638,7 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt8 Failed", args{ val: int32(100000), - out: mockTTransport(100000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -771,7 +650,7 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt16 success", args{ val: int32(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -783,7 +662,7 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt16 Failed", args{ val: int32(100000), - out: mockTTransport(100000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -795,7 +674,7 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt64 Success", args{ val: int32(10000000), - out: mockTTransport(10000000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -816,30 +695,10 @@ func Test_writeInt32(t *testing.T) { func Test_writeInt64(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := func(v int64) *mocks.MockThriftTTransport { - return &mocks.MockThriftTTransport{ - WriteI64Func: func(val int64) error { - test.Assert(t, val == v) - return nil - }, - WriteI32Func: func(val int32) error { - test.Assert(t, val == int32(v)) - return nil - }, - WriteI16Func: func(val int16) error { - test.Assert(t, val == int16(v)) - return nil - }, - WriteByteFunc: func(val int8) error { - test.Assert(t, val == int8(v)) - return nil - }, - } - } tests := []struct { name string args args @@ -850,7 +709,7 @@ func Test_writeInt64(t *testing.T) { "writeInt64 Success", args{ val: int64(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -862,7 +721,7 @@ func Test_writeInt64(t *testing.T) { "writeInt64 Failed", args{ val: int64(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -874,7 +733,7 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt8 Success", args{ val: int64(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -886,7 +745,7 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt8 failed", args{ val: int64(1000), - out: mockTTransport(1000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -898,7 +757,7 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt16 Success", args{ val: int64(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -910,7 +769,7 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt16 failed", args{ val: int64(100000000000), - out: mockTTransport(100000000000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -922,7 +781,7 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt32 Success", args{ val: int64(1), - out: mockTTransport(1), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -934,7 +793,7 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt32 failed", args{ val: int64(100000000000), - out: mockTTransport(100000000000), + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -955,16 +814,10 @@ func Test_writeInt64(t *testing.T) { func Test_writeFloat64(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteDoubleFunc: func(val float64) error { - test.Assert(t, val == 1.0) - return nil - }, - } tests := []struct { name string args args @@ -975,7 +828,7 @@ func Test_writeFloat64(t *testing.T) { "writeFloat64", args{ val: 1.0, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -996,20 +849,10 @@ func Test_writeFloat64(t *testing.T) { func Test_writeString(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteStringFunc: func(val string) error { - test.Assert(t, val == stringInput) - return nil - }, - WriteBinaryFunc: func(val []byte) error { - test.DeepEqual(t, val, binaryInput) - return nil - }, - } tests := []struct { name string args args @@ -1020,7 +863,7 @@ func Test_writeString(t *testing.T) { "writeString", args{ val: stringInput, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRING, Struct: &descriptor.StructDescriptor{}, @@ -1041,20 +884,10 @@ func Test_writeString(t *testing.T) { func Test_writeBase64String(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteStringFunc: func(val string) error { - test.Assert(t, val == stringInput) - return nil - }, - WriteBinaryFunc: func(val []byte) error { - test.DeepEqual(t, val, binaryInput) - return nil - }, - } tests := []struct { name string args args @@ -1065,7 +898,7 @@ func Test_writeBase64String(t *testing.T) { "writeBase64Binary", // write to binary field with base64 string args{ val: base64.StdEncoding.EncodeToString(binaryInput), - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Name: "binary", Type: descriptor.STRING, @@ -1088,16 +921,10 @@ func Test_writeBase64String(t *testing.T) { func Test_writeBinary(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteBinaryFunc: func(val []byte) error { - test.Assert(t, reflect.DeepEqual(val, []byte(stringInput))) - return nil - }, - } tests := []struct { name string args args @@ -1108,7 +935,7 @@ func Test_writeBinary(t *testing.T) { "writeBinary", args{ val: []byte(stringInput), - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRING, Struct: &descriptor.StructDescriptor{}, @@ -1132,11 +959,6 @@ func Test_writeBinaryList(t *testing.T) { t *descriptor.TypeDescriptor opt *writerOption } - type params struct { - listBeginErr error - writeByteErr error - listEndErr error - } commonArgs := args{ val: []byte(stringInput), t: &descriptor.TypeDescriptor{ @@ -1148,7 +970,6 @@ func Test_writeBinaryList(t *testing.T) { tests := []struct { name string args args - params params wantErr bool }{ { @@ -1156,30 +977,6 @@ func Test_writeBinaryList(t *testing.T) { args: commonArgs, wantErr: false, }, - { - name: "list begin error", - args: commonArgs, - params: params{ - listBeginErr: errors.New("test error"), - }, - wantErr: true, - }, - { - name: "write byte error", - args: commonArgs, - params: params{ - writeByteErr: errors.New("test error"), - }, - wantErr: true, - }, - { - name: "list end error", - args: commonArgs, - params: params{ - listEndErr: errors.New("test error"), - }, - wantErr: true, - }, { name: "empty slice", args: args{ @@ -1195,31 +992,13 @@ func Test_writeBinaryList(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var writtenData []byte - var endHit bool - mockTTransport := &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.BYTE) - test.Assert(t, size == len(tt.args.val)) - return tt.params.listBeginErr - }, - WriteByteFunc: func(val int8) error { - writtenData = append(writtenData, byte(val)) - return tt.params.writeByteErr - }, - WriteListEndFunc: func() error { - endHit = true - return tt.params.listEndErr - }, - } - - if err := writeBinaryList(context.Background(), tt.args.val, mockTTransport, tt.args.t, + binaryWriter := thrift.NewBinaryWriter() + if err := writeBinaryList(context.Background(), tt.args.val, binaryWriter, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeBinary() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr { - test.Assert(t, bytes.Equal(tt.args.val, writtenData)) - test.Assert(t, endHit == true) + test.Assert(t, len(tt.args.val)+5 == len(binaryWriter.Bytes())) } }) } @@ -1228,7 +1007,7 @@ func Test_writeBinaryList(t *testing.T) { func Test_writeList(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -1242,13 +1021,7 @@ func Test_writeList(t *testing.T) { "writeList", args{ val: []interface{}{stringInput}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1261,13 +1034,7 @@ func Test_writeList(t *testing.T) { "writeListWithNil", args{ val: []interface{}{stringInput, nil, stringInput}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.STRING) - test.Assert(t, size == 3) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1280,13 +1047,7 @@ func Test_writeList(t *testing.T) { "writeListWithNilOnly", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1299,13 +1060,7 @@ func Test_writeList(t *testing.T) { "writeListWithNextWriterError", args{ val: []interface{}{stringInput}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.I08) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I08}, @@ -1327,7 +1082,7 @@ func Test_writeList(t *testing.T) { func Test_writeInterfaceMap(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -1341,14 +1096,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMap", args{ val: map[interface{}]interface{}{"hello": "world"}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.STRING) - test.Assert(t, valueType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1362,14 +1110,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMapWithNil", args{ val: map[interface{}]interface{}{"hello": "world", "hi": nil, "hey": "kitex"}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.STRING) - test.Assert(t, valueType == thrift.STRING) - test.Assert(t, size == 3) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1383,14 +1124,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMapWithNilOnly", args{ val: map[interface{}]interface{}{"hello": nil}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.STRING) - test.Assert(t, valueType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1404,14 +1138,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMapWithElemNextWriterError", args{ val: map[interface{}]interface{}{"hello": "world"}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.STRING) - test.Assert(t, valueType == thrift.BOOL) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1425,14 +1152,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMapWithKeyWriterError", args{ val: map[interface{}]interface{}{"hello": "world"}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.I08) - test.Assert(t, valueType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.I08}, @@ -1456,7 +1176,7 @@ func Test_writeInterfaceMap(t *testing.T) { func Test_writeStringMap(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -1470,14 +1190,7 @@ func Test_writeStringMap(t *testing.T) { "writeStringMap", args{ val: map[string]interface{}{"hello": "world"}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.STRING) - test.Assert(t, valueType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1491,14 +1204,7 @@ func Test_writeStringMap(t *testing.T) { "writeStringMapWithNil", args{ val: map[string]interface{}{"hello": "world", "hi": nil, "hey": "kitex"}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.STRING) - test.Assert(t, valueType == thrift.STRING) - test.Assert(t, size == 3) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1512,14 +1218,7 @@ func Test_writeStringMap(t *testing.T) { "writeStringMapWithNilOnly", args{ val: map[string]interface{}{"hello": nil}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.STRING) - test.Assert(t, valueType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1533,14 +1232,7 @@ func Test_writeStringMap(t *testing.T) { "writeStringMapWithElemNextWriterError", args{ val: map[string]interface{}{"hello": "world"}, - out: &mocks.MockThriftTTransport{ - WriteMapBeginFunc: func(keyType, valueType thrift.TType, size int) error { - test.Assert(t, keyType == thrift.STRING) - test.Assert(t, valueType == thrift.BOOL) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1563,41 +1255,10 @@ func Test_writeStringMap(t *testing.T) { func Test_writeStruct(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteStructBeginFunc: func(name string) error { - test.Assert(t, name == "Demo") - return nil - }, - WriteFieldBeginFunc: func(name string, typeID thrift.TType, id int16) error { - test.Assert(t, name == "hello") - test.Assert(t, typeID == thrift.STRING) - test.Assert(t, id == 1) - return nil - }, - } - mockTTransportError := &mocks.MockThriftTTransport{ - WriteStructBeginFunc: func(name string) error { - test.Assert(t, name == "Demo") - return nil - }, - WriteFieldBeginFunc: func(name string, typeID thrift.TType, id int16) error { - test.Assert(t, name == "strList") - test.Assert(t, typeID == thrift.LIST) - test.Assert(t, id == 1) - return nil - }, - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.STRING) - return nil - }, - WriteStringFunc: func(value string) error { - return errors.New("need STRING type, but got: I64") - }, - } tests := []struct { name string args args @@ -1608,7 +1269,7 @@ func Test_writeStruct(t *testing.T) { "writeStruct", args{ val: map[string]interface{}{"hello": "world"}, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1630,7 +1291,7 @@ func Test_writeStruct(t *testing.T) { "writeStructRequired", args{ val: map[string]interface{}{"hello": nil}, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1652,7 +1313,7 @@ func Test_writeStruct(t *testing.T) { "writeStructOptional", args{ val: map[string]interface{}{}, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1671,7 +1332,7 @@ func Test_writeStruct(t *testing.T) { "writeStructError", args{ val: map[string]interface{}{"strList": []interface{}{int64(123)}}, - out: mockTTransportError, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1699,22 +1360,10 @@ func Test_writeStruct(t *testing.T) { func Test_writeHTTPRequest(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteStructBeginFunc: func(name string) error { - test.Assert(t, name == "Demo") - return nil - }, - WriteFieldBeginFunc: func(name string, typeID thrift.TType, id int16) error { - test.Assert(t, name == "hello") - test.Assert(t, typeID == thrift.STRING) - test.Assert(t, id == 1) - return nil - }, - } tests := []struct { name string args args @@ -1727,7 +1376,7 @@ func Test_writeHTTPRequest(t *testing.T) { val: &descriptor.HTTPRequest{ Body: map[string]interface{}{"hello": "world"}, }, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1753,7 +1402,7 @@ func Test_writeHTTPRequest(t *testing.T) { val: &descriptor.HTTPRequest{ Body: map[string]interface{}{"hello": nil}, }, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1780,7 +1429,7 @@ func Test_writeHTTPRequest(t *testing.T) { val: &descriptor.HTTPRequest{ Body: map[string]interface{}{"hello": nil}, }, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1807,7 +1456,7 @@ func Test_writeHTTPRequest(t *testing.T) { val: &descriptor.HTTPRequest{ Body: map[string]interface{}{}, }, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1841,7 +1490,7 @@ func Test_writeHTTPRequest(t *testing.T) { func Test_writeHTTPRequestWithPbBody(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -1885,40 +1534,11 @@ func Test_writeHTTPRequestWithPbBody(t *testing.T) { "writeStructSuccess", args{ val: req, - out: &mocks.MockThriftTTransport{ - WriteI32Func: func(value int32) error { - test.Assert(t, value == 1234) - return nil - }, - WriteStringFunc: func(value string) error { - test.Assert(t, value == "John") - return nil - }, - }, - t: typeDescriptor, + out: thrift.NewBinaryWriter(), + t: typeDescriptor, }, false, }, - { - "writeStructFail", - args{ - val: req, - out: &mocks.MockThriftTTransport{ - WriteI32Func: func(value int32) error { - test.Assert(t, value == 1234) - return nil - }, - WriteStringFunc: func(value string) error { - if value == "John" { - return fmt.Errorf("MakeSureThisExecuted") - } - return nil - }, - }, - t: typeDescriptor, - }, - true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1964,26 +1584,10 @@ func Test_writeRequestBase(t *testing.T) { type args struct { ctx context.Context val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter field *descriptor.FieldDescriptor opt *writerOption } - depth := 1 - mockTTransport := &mocks.MockThriftTTransport{ - WriteStructBeginFunc: func(name string) error { - test.Assert(t, name == "Base", name) - return nil - }, - WriteFieldBeginFunc: func(name string, typeID thrift.TType, id int16) error { - if depth == 1 { - test.Assert(t, name == "base", name) - test.Assert(t, typeID == thrift.STRUCT, typeID) - test.Assert(t, id == 255) - depth++ - } - return nil - }, - } tests := []struct { name string args args @@ -1995,13 +1599,13 @@ func Test_writeRequestBase(t *testing.T) { "writeStruct", args{ val: map[string]interface{}{"Extra": map[string]interface{}{"hello": "world"}}, - out: mockTTransport, + out: thrift.NewBinaryWriter(), field: &descriptor.FieldDescriptor{ Name: "base", ID: 255, Type: &descriptor.TypeDescriptor{Type: descriptor.STRUCT, Name: "base.Base"}, }, - opt: &writerOption{requestBase: &Base{}}, + opt: &writerOption{requestBase: &base.Base{}}, }, false, }, @@ -2018,22 +1622,10 @@ func Test_writeRequestBase(t *testing.T) { func Test_writeJSON(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteStructBeginFunc: func(name string) error { - test.Assert(t, name == "Demo") - return nil - }, - WriteFieldBeginFunc: func(name string, typeID thrift.TType, id int16) error { - test.Assert(t, name == "hello") - test.Assert(t, typeID == thrift.STRING) - test.Assert(t, id == 1) - return nil - }, - } data := gjson.Parse(`{"hello": "world"}`) dataEmpty := gjson.Parse(`{"hello": nil}`) tests := []struct { @@ -2046,7 +1638,7 @@ func Test_writeJSON(t *testing.T) { "writeJSON", args{ val: &data, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -2068,7 +1660,7 @@ func Test_writeJSON(t *testing.T) { "writeJSONRequired", args{ val: &dataEmpty, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -2090,7 +1682,7 @@ func Test_writeJSON(t *testing.T) { "writeJSONOptional", args{ val: &dataEmpty, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -2118,18 +1710,10 @@ func Test_writeJSON(t *testing.T) { func Test_writeJSONBase(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } - mockTTransport := &mocks.MockThriftTTransport{ - WriteStructBeginFunc: func(name string) error { - return nil - }, - WriteFieldBeginFunc: func(name string, typeID thrift.TType, id int16) error { - return nil - }, - } data := gjson.Parse(`{"hello":"world", "base": {"Extra": {"hello":"world"}}}`) tests := []struct { name string @@ -2140,7 +1724,7 @@ func Test_writeJSONBase(t *testing.T) { "writeJSONBase", args{ val: &data, - out: mockTTransport, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -2160,7 +1744,7 @@ func Test_writeJSONBase(t *testing.T) { }, }, opt: &writerOption{ - requestBase: &Base{ + requestBase: &base.Base{ LogID: "logID-12345", Caller: "Caller.Name", }, @@ -2182,7 +1766,7 @@ func Test_writeJSONBase(t *testing.T) { func Test_getDefaultValueAndWriter(t *testing.T) { type args struct { val interface{} - out thrift.TProtocol + out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -2196,13 +1780,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "bool", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.BOOL) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.BOOL}, @@ -2215,13 +1793,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "i08", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.I08) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I08}, @@ -2234,13 +1806,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "i16", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.I16) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I16}, @@ -2253,13 +1819,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "i32", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.I32) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I32}, @@ -2272,13 +1832,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "i64", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.I64) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I64}, @@ -2291,13 +1845,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "double", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.DOUBLE) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.DOUBLE}, @@ -2310,13 +1858,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "stringBinary", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), opt: &writerOption{ binaryWithBase64: true, }, @@ -2335,13 +1877,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "stringNonBinary", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.STRING) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -2354,11 +1890,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "list", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -2374,11 +1906,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "set", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -2394,13 +1922,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "map", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.MAP) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -2417,13 +1939,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "struct", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.STRUCT) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -2449,13 +1965,7 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "void", args{ val: []interface{}{nil}, - out: &mocks.MockThriftTTransport{ - WriteListBeginFunc: func(elemType thrift.TType, size int) error { - test.Assert(t, elemType == thrift.VOID) - test.Assert(t, size == 1) - return nil - }, - }, + out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ diff --git a/internal/mocks/generic/thrift.go b/internal/mocks/generic/thrift.go index 05da3f79c0..142fd9f483 100644 --- a/internal/mocks/generic/thrift.go +++ b/internal/mocks/generic/thrift.go @@ -23,10 +23,11 @@ package generic import ( context "context" + "io" reflect "reflect" - thrift "github.com/apache/thrift/lib/go/thrift" - thrift0 "github.com/cloudwego/kitex/pkg/generic/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/base" + gomock "github.com/golang/mock/gomock" ) @@ -54,7 +55,7 @@ func (m *MockMessageReader) EXPECT() *MockMessageReaderMockRecorder { } // Read mocks base method. -func (m *MockMessageReader) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) (interface{}, error) { +func (m *MockMessageReader) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Read", ctx, method, in) ret0, _ := ret[0].(interface{}) @@ -92,7 +93,7 @@ func (m *MockMessageWriter) EXPECT() *MockMessageWriterMockRecorder { } // Write mocks base method. -func (m *MockMessageWriter) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, method string, isClient bool, requestBase *thrift0.Base) error { +func (m *MockMessageWriter) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Write", ctx, out, msg, requestBase) ret0, _ := ret[0].(error) diff --git a/pkg/generic/binarythrift_codec_test.go b/pkg/generic/binarythrift_codec_test.go index 1c79a1df20..48fbda0fea 100644 --- a/pkg/generic/binarythrift_codec_test.go +++ b/pkg/generic/binarythrift_codec_test.go @@ -122,7 +122,7 @@ func TestBinaryThriftCodecExceptionError(t *testing.T) { // empty method err = btc.Marshal(ctx, cliMsg, rwbuf) - test.Assert(t, err.Error() == "rawThriftBinaryCodec Marshal exception failed, err: empty methodName in thrift Marshal") + test.Assert(t, err.Error() == "rawThriftBinaryCodec Marshal exception failed, err: empty methodName in thrift Marshal", err) cliMsg.RPCInfoFunc = func() rpcinfo.RPCInfo { return newMockRPCInfo() diff --git a/pkg/generic/descriptor/http_mapping.go b/pkg/generic/descriptor/http_mapping.go index eef97d7597..416fde4839 100644 --- a/pkg/generic/descriptor/http_mapping.go +++ b/pkg/generic/descriptor/http_mapping.go @@ -20,7 +20,7 @@ import ( "context" "errors" - "github.com/cloudwego/kitex/pkg/generic/proto" + "github.com/cloudwego/kitex/internal/generic/proto" ) // HTTPMapping http mapping annotation diff --git a/pkg/generic/descriptor/render.go b/pkg/generic/descriptor/render.go index decbefa416..52e8ce4502 100644 --- a/pkg/generic/descriptor/render.go +++ b/pkg/generic/descriptor/render.go @@ -20,7 +20,7 @@ import ( "encoding/json" "net/http" - "github.com/cloudwego/kitex/pkg/generic/proto" + "github.com/cloudwego/kitex/internal/generic/proto" ) type Renderer interface { diff --git a/pkg/generic/descriptor/type.go b/pkg/generic/descriptor/type.go index 28a49a6bef..437da7da05 100644 --- a/pkg/generic/descriptor/type.go +++ b/pkg/generic/descriptor/type.go @@ -16,7 +16,7 @@ package descriptor -import thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" +import "github.com/cloudwego/gopkg/protocol/thrift" // Type constants in the Thrift protocol type Type byte diff --git a/pkg/generic/generic_service.go b/pkg/generic/generic_service.go index 9d5f09bdb4..74d94264b9 100644 --- a/pkg/generic/generic_service.go +++ b/pkg/generic/generic_service.go @@ -19,12 +19,13 @@ package generic import ( "context" "fmt" + "io" - gproto "github.com/cloudwego/kitex/pkg/generic/proto" - gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/gopkg/protocol/thrift/base" + + "github.com/cloudwego/kitex/internal/generic/proto" + "github.com/cloudwego/kitex/internal/generic/thrift" codecProto "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" - codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) @@ -108,16 +109,14 @@ type WithCodec interface { type Args struct { Request interface{} Method string - base *gthrift.Base + base *base.Base inner interface{} } var ( - _ codecThrift.MessageReaderWithMethodWithContext = (*Args)(nil) - _ codecThrift.MessageWriterWithMethodWithContext = (*Args)(nil) - _ codecProto.MessageWriterWithContext = (*Args)(nil) - _ codecProto.MessageReaderWithMethodWithContext = (*Args)(nil) - _ WithCodec = (*Args)(nil) + _ codecProto.MessageWriterWithContext = (*Args)(nil) + _ codecProto.MessageReaderWithMethodWithContext = (*Args)(nil) + _ WithCodec = (*Args)(nil) ) // SetCodec ... @@ -127,17 +126,17 @@ func (g *Args) SetCodec(inner interface{}) { func (g *Args) GetOrSetBase() interface{} { if g.base == nil { - g.base = gthrift.NewBase() + g.base = base.NewBase() } return g.base } // Write ... -func (g *Args) Write(ctx context.Context, method string, out thrift.TProtocol) error { +func (g *Args) Write(ctx context.Context, method string, out io.Writer) error { if err, ok := g.inner.(error); ok { return err } - if w, ok := g.inner.(gthrift.MessageWriter); ok { + if w, ok := g.inner.(thrift.MessageWriter); ok { return w.Write(ctx, out, g.Request, method, true, g.base) } return fmt.Errorf("unexpected Args writer type: %T", g.inner) @@ -147,18 +146,18 @@ func (g *Args) WritePb(ctx context.Context, method string) (interface{}, error) if err, ok := g.inner.(error); ok { return nil, err } - if w, ok := g.inner.(gproto.MessageWriter); ok { + if w, ok := g.inner.(proto.MessageWriter); ok { return w.Write(ctx, g.Request, method, true) } return nil, fmt.Errorf("unexpected Args writer type: %T", g.inner) } // Read ... -func (g *Args) Read(ctx context.Context, method string, dataLen int, in thrift.TProtocol) error { +func (g *Args) Read(ctx context.Context, method string, dataLen int, in io.Reader) error { if err, ok := g.inner.(error); ok { return err } - if rw, ok := g.inner.(gthrift.MessageReader); ok { + if rw, ok := g.inner.(thrift.MessageReader); ok { g.Method = method var err error g.Request, err = rw.Read(ctx, method, false, dataLen, in) @@ -171,7 +170,7 @@ func (g *Args) ReadPb(ctx context.Context, method string, in []byte) error { if err, ok := g.inner.(error); ok { return err } - if w, ok := g.inner.(gproto.MessageReader); ok { + if w, ok := g.inner.(proto.MessageReader); ok { g.Method = method var err error g.Request, err = w.Read(ctx, method, false, in) @@ -192,11 +191,9 @@ type Result struct { } var ( - _ codecThrift.MessageReaderWithMethodWithContext = (*Result)(nil) - _ codecThrift.MessageWriterWithMethodWithContext = (*Result)(nil) - _ codecProto.MessageWriterWithContext = (*Result)(nil) - _ codecProto.MessageReaderWithMethodWithContext = (*Result)(nil) - _ WithCodec = (*Result)(nil) + _ codecProto.MessageWriterWithContext = (*Result)(nil) + _ codecProto.MessageReaderWithMethodWithContext = (*Result)(nil) + _ WithCodec = (*Result)(nil) ) // SetCodec ... @@ -205,11 +202,11 @@ func (r *Result) SetCodec(inner interface{}) { } // Write ... -func (r *Result) Write(ctx context.Context, method string, out thrift.TProtocol) error { +func (r *Result) Write(ctx context.Context, method string, out io.Writer) error { if err, ok := r.inner.(error); ok { return err } - if w, ok := r.inner.(gthrift.MessageWriter); ok { + if w, ok := r.inner.(thrift.MessageWriter); ok { return w.Write(ctx, out, r.Success, method, false, nil) } return fmt.Errorf("unexpected Result writer type: %T", r.inner) @@ -219,18 +216,18 @@ func (r *Result) WritePb(ctx context.Context, method string) (interface{}, error if err, ok := r.inner.(error); ok { return nil, err } - if w, ok := r.inner.(gproto.MessageWriter); ok { + if w, ok := r.inner.(proto.MessageWriter); ok { return w.Write(ctx, r.Success, method, false) } return nil, fmt.Errorf("unexpected Result writer type: %T", r.inner) } // Read ... -func (r *Result) Read(ctx context.Context, method string, dataLen int, in thrift.TProtocol) error { +func (r *Result) Read(ctx context.Context, method string, dataLen int, in io.Reader) error { if err, ok := r.inner.(error); ok { return err } - if w, ok := r.inner.(gthrift.MessageReader); ok { + if w, ok := r.inner.(thrift.MessageReader); ok { var err error r.Success, err = w.Read(ctx, method, true, dataLen, in) return err @@ -242,7 +239,7 @@ func (r *Result) ReadPb(ctx context.Context, method string, in []byte) error { if err, ok := r.inner.(error); ok { return err } - if w, ok := r.inner.(gproto.MessageReader); ok { + if w, ok := r.inner.(proto.MessageReader); ok { var err error r.Success, err = w.Read(ctx, method, true, in) return err diff --git a/pkg/generic/generic_service_test.go b/pkg/generic/generic_service_test.go index 5ffb910db5..4b1f30ed9c 100644 --- a/pkg/generic/generic_service_test.go +++ b/pkg/generic/generic_service_test.go @@ -22,13 +22,12 @@ import ( "strings" "testing" + gbase "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/golang/mock/gomock" mocks "github.com/cloudwego/kitex/internal/mocks/generic" "github.com/cloudwego/kitex/internal/test" - gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" - codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/utils" ) @@ -36,8 +35,7 @@ import ( func TestGenericService(t *testing.T) { ctx := context.Background() method := "test" - out := remote.NewWriterBuffer(256) - tProto := codecThrift.NewBinaryProtocol(out) + buffer := remote.NewReaderWriterBuffer(256) ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -45,7 +43,7 @@ func TestGenericService(t *testing.T) { argWriteInner, resultWriteInner := mocks.NewMockMessageWriter(ctrl), mocks.NewMockMessageWriter(ctrl) rInner := mocks.NewMockMessageReader(ctrl) // Read expect - rInner.EXPECT().Read(ctx, method, tProto).Return("test", nil).AnyTimes() + rInner.EXPECT().Read(ctx, method, buffer).Return("test", nil).AnyTimes() // Args... arg := newGenericServiceCallArgs() @@ -56,21 +54,21 @@ func TestGenericService(t *testing.T) { test.Assert(t, base != nil) a.SetCodec(struct{}{}) // write not ok - err := a.Write(ctx, method, tProto) + err := a.Write(ctx, method, buffer) test.Assert(t, err.Error() == "unexpected Args writer type: struct {}") // Write expect - argWriteInner.EXPECT().Write(ctx, tProto, a.Request, a.GetOrSetBase()).Return(nil) + argWriteInner.EXPECT().Write(ctx, buffer, a.Request, a.GetOrSetBase()).Return(nil) a.SetCodec(argWriteInner) // write ok - err = a.Write(ctx, method, tProto) + err = a.Write(ctx, method, buffer) test.Assert(t, err == nil, err) // read not ok - err = a.Read(ctx, method, 0, tProto) + err = a.Read(ctx, method, 0, buffer) test.Assert(t, strings.Contains(err.Error(), "unexpected Args reader type")) // read ok a.SetCodec(rInner) - err = a.Read(ctx, method, 0, tProto) + err = a.Read(ctx, method, 0, buffer) test.Assert(t, err == nil, err) // Result... @@ -79,20 +77,20 @@ func TestGenericService(t *testing.T) { test.Assert(t, ok == true) // write not ok - err = r.Write(ctx, method, tProto) + err = r.Write(ctx, method, buffer) test.Assert(t, err.Error() == "unexpected Result writer type: ") // Write expect - resultWriteInner.EXPECT().Write(ctx, tProto, r.Success, (*gthrift.Base)(nil)).Return(nil).AnyTimes() + resultWriteInner.EXPECT().Write(ctx, buffer, r.Success, (*gbase.Base)(nil)).Return(nil).AnyTimes() r.SetCodec(resultWriteInner) // write ok - err = r.Write(ctx, method, tProto) + err = r.Write(ctx, method, buffer) test.Assert(t, err == nil) // read not ok - err = r.Read(ctx, method, 0, tProto) + err = r.Read(ctx, method, 0, buffer) test.Assert(t, strings.Contains(err.Error(), "unexpected Result reader type")) // read ok r.SetCodec(rInner) - err = r.Read(ctx, method, 0, tProto) + err = r.Read(ctx, method, 0, buffer) test.Assert(t, err == nil) r.SetSuccess(nil) diff --git a/pkg/generic/httppbthrift_codec.go b/pkg/generic/httppbthrift_codec.go index ea59513f81..503f6bbcf6 100644 --- a/pkg/generic/httppbthrift_codec.go +++ b/pkg/generic/httppbthrift_codec.go @@ -26,9 +26,9 @@ import ( "github.com/jhump/protoreflect/desc" + "github.com/cloudwego/kitex/internal/generic/proto" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/proto" - "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/serviceinfo" diff --git a/pkg/generic/httppbthrift_codec_test.go b/pkg/generic/httppbthrift_codec_test.go index 2af104805f..63025854a5 100644 --- a/pkg/generic/httppbthrift_codec_test.go +++ b/pkg/generic/httppbthrift_codec_test.go @@ -24,8 +24,8 @@ import ( "reflect" "testing" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/internal/test" - gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) @@ -70,8 +70,8 @@ func TestHTTPPbThriftCodec(t *testing.T) { test.Assert(t, htc.svcName == "ExampleService") rw := htc.getMessageReaderWriter() - _, ok := rw.(gthrift.MessageWriter) + _, ok := rw.(thrift.MessageWriter) test.Assert(t, ok) - _, ok = rw.(gthrift.MessageReader) + _, ok = rw.(thrift.MessageReader) test.Assert(t, ok) } diff --git a/pkg/generic/httpthrift_codec.go b/pkg/generic/httpthrift_codec.go index b66a37ed11..85ca2d4128 100644 --- a/pkg/generic/httpthrift_codec.go +++ b/pkg/generic/httpthrift_codec.go @@ -26,8 +26,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/serviceinfo" diff --git a/pkg/generic/httpthrift_codec_test.go b/pkg/generic/httpthrift_codec_test.go index 245b1906fc..9f69ba5b26 100644 --- a/pkg/generic/httpthrift_codec_test.go +++ b/pkg/generic/httpthrift_codec_test.go @@ -24,8 +24,8 @@ import ( "github.com/bytedance/sonic" "github.com/cloudwego/dynamicgo/conv" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/internal/test" - gthrift "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) @@ -71,9 +71,9 @@ func TestHttpThriftCodec(t *testing.T) { test.Assert(t, !ok) rw = htc.getMessageReaderWriter() - _, ok = rw.(gthrift.MessageWriter) + _, ok = rw.(thrift.MessageWriter) test.Assert(t, ok) - _, ok = rw.(gthrift.MessageReader) + _, ok = rw.(thrift.MessageReader) test.Assert(t, ok) } @@ -103,9 +103,9 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) { test.Assert(t, htc.svcName == "ExampleService") rw := htc.getMessageReaderWriter() - _, ok := rw.(gthrift.MessageWriter) + _, ok := rw.(thrift.MessageWriter) test.Assert(t, ok) - _, ok = rw.(gthrift.MessageReader) + _, ok = rw.(thrift.MessageReader) test.Assert(t, ok) } diff --git a/pkg/generic/json_test/generic_test.go b/pkg/generic/json_test/generic_test.go index 7d23cdd104..f955789773 100644 --- a/pkg/generic/json_test/generic_test.go +++ b/pkg/generic/json_test/generic_test.go @@ -74,7 +74,7 @@ func testThrift(t *testing.T) { test.Assert(t, err == nil, err) respStr, ok := resp.(string) test.Assert(t, ok) - test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), "world") + test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), gjson.Get(respStr, "Msg").String()) // extend method resp, err = cli.GenericCall(context.Background(), "ExtendMethod", reqExtendMsg, callopt.WithRPCTimeout(100*time.Second)) @@ -97,7 +97,7 @@ func testThriftWithDynamicGo(t *testing.T) { test.Assert(t, err == nil, err) respStr, ok := resp.(string) test.Assert(t, ok) - test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), "world") + test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), gjson.Get(respStr, "Msg").String()) // client without dynamicgo @@ -109,7 +109,7 @@ func testThriftWithDynamicGo(t *testing.T) { test.Assert(t, err == nil, err) respStr, ok = resp.(string) test.Assert(t, ok) - test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), "world") + test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), gjson.Get(respStr, "Msg").String()) // server side: // write: dynamicgo (amd64 && go1.16), fallback (arm || !go1.16) @@ -119,7 +119,7 @@ func testThriftWithDynamicGo(t *testing.T) { test.Assert(t, err == nil, err) respStr, ok = resp.(string) test.Assert(t, ok) - test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), "world") + test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), gjson.Get(respStr, "Msg").String()) svr.Stop() } diff --git a/pkg/generic/jsonpb_codec.go b/pkg/generic/jsonpb_codec.go index ac58e21ed1..9cbeb735fb 100644 --- a/pkg/generic/jsonpb_codec.go +++ b/pkg/generic/jsonpb_codec.go @@ -25,7 +25,7 @@ import ( "github.com/cloudwego/dynamicgo/conv" dproto "github.com/cloudwego/dynamicgo/proto" - "github.com/cloudwego/kitex/pkg/generic/proto" + "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" diff --git a/pkg/generic/jsonpb_codec_test.go b/pkg/generic/jsonpb_codec_test.go index 8913d3369b..7778aca1cb 100644 --- a/pkg/generic/jsonpb_codec_test.go +++ b/pkg/generic/jsonpb_codec_test.go @@ -23,8 +23,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" dproto "github.com/cloudwego/dynamicgo/proto" + gproto "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/internal/test" - gproto "github.com/cloudwego/kitex/pkg/generic/proto" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/jsonthrift_codec.go b/pkg/generic/jsonthrift_codec.go index 377d507ca1..20d2b5e0f6 100644 --- a/pkg/generic/jsonthrift_codec.go +++ b/pkg/generic/jsonthrift_codec.go @@ -23,8 +23,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" diff --git a/pkg/generic/jsonthrift_codec_test.go b/pkg/generic/jsonthrift_codec_test.go index 82ff0bb1fe..6d00d9a57a 100644 --- a/pkg/generic/jsonthrift_codec_test.go +++ b/pkg/generic/jsonthrift_codec_test.go @@ -21,8 +21,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/mapthrift_codec.go b/pkg/generic/mapthrift_codec.go index 78e2a449e4..5284820c59 100644 --- a/pkg/generic/mapthrift_codec.go +++ b/pkg/generic/mapthrift_codec.go @@ -21,8 +21,8 @@ import ( "errors" "sync/atomic" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/serviceinfo" diff --git a/pkg/generic/mapthrift_codec_test.go b/pkg/generic/mapthrift_codec_test.go index 86630b842d..635f59963b 100644 --- a/pkg/generic/mapthrift_codec_test.go +++ b/pkg/generic/mapthrift_codec_test.go @@ -19,8 +19,8 @@ package generic import ( "testing" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/pb_descriptor_provider.go b/pkg/generic/pb_descriptor_provider.go index 2a2e115453..27c7e29d22 100644 --- a/pkg/generic/pb_descriptor_provider.go +++ b/pkg/generic/pb_descriptor_provider.go @@ -19,7 +19,7 @@ package generic import ( dproto "github.com/cloudwego/dynamicgo/proto" - "github.com/cloudwego/kitex/pkg/generic/proto" + "github.com/cloudwego/kitex/internal/generic/proto" ) // PbDescriptorProvider provide service descriptor diff --git a/pkg/generic/pbidl_provider.go b/pkg/generic/pbidl_provider.go index d115c51054..587f4e9314 100644 --- a/pkg/generic/pbidl_provider.go +++ b/pkg/generic/pbidl_provider.go @@ -24,7 +24,7 @@ import ( dproto "github.com/cloudwego/dynamicgo/proto" "github.com/jhump/protoreflect/desc/protoparse" - "github.com/cloudwego/kitex/pkg/generic/proto" + "github.com/cloudwego/kitex/internal/generic/proto" ) type PbContentProvider struct { diff --git a/pkg/generic/thrift/base.go b/pkg/generic/thrift/base.go deleted file mode 100644 index 5eb02672a3..0000000000 --- a/pkg/generic/thrift/base.go +++ /dev/null @@ -1,962 +0,0 @@ -/* - * Copyright 2021 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "fmt" - - bthrift "github.com/cloudwego/gopkg/protocol/thrift" - - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -type TrafficEnv struct { - Open bool `thrift:"Open,1" json:"Open"` - - Env string `thrift:"Env,2" json:"Env"` -} - -func NewTrafficEnv() *TrafficEnv { - return &TrafficEnv{ - Open: false, - Env: "", - } -} - -func (p *TrafficEnv) GetOpen() bool { - return p.Open -} - -func (p *TrafficEnv) GetEnv() string { - return p.Env -} - -func (p *TrafficEnv) SetOpen(val bool) { - p.Open = val -} - -func (p *TrafficEnv) SetEnv(val string) { - p.Env = val -} - -var fieldIDToName_TrafficEnv = map[int16]string{ - 1: "Open", - 2: "Env", -} - -func (p *TrafficEnv) Read(iprot thrift.TProtocol) (err error) { - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.BOOL { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 2: - if fieldTypeId == thrift.STRING { - if err = p.ReadField2(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_TrafficEnv[fieldId]), err) -SkipFieldError: - return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *TrafficEnv) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadBool(); err != nil { - return err - } else { - p.Open = v - } - return nil -} - -func (p *TrafficEnv) ReadField2(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Env = v - } - return nil -} - -func (p *TrafficEnv) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("TrafficEnv"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField2(oprot); err != nil { - fieldId = 2 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *TrafficEnv) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Open", thrift.BOOL, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteBool(p.Open); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *TrafficEnv) writeField2(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Env", thrift.STRING, 2); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Env); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) -} - -func (p *TrafficEnv) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("TrafficEnv(%+v)", *p) -} - -type Base struct { - LogID string `thrift:"LogID,1" json:"LogID"` - - Caller string `thrift:"Caller,2" json:"Caller"` - - Addr string `thrift:"Addr,3" json:"Addr"` - - Client string `thrift:"Client,4" json:"Client"` - - TrafficEnv *TrafficEnv `thrift:"TrafficEnv,5" json:"TrafficEnv,omitempty"` - - Extra map[string]string `thrift:"Extra,6" json:"Extra,omitempty"` -} - -func NewBase() *Base { - return &Base{ - LogID: "", - Caller: "", - Addr: "", - Client: "", - } -} - -func (p *Base) GetLogID() string { - return p.LogID -} - -func (p *Base) GetCaller() string { - return p.Caller -} - -func (p *Base) GetAddr() string { - return p.Addr -} - -func (p *Base) GetClient() string { - return p.Client -} - -var Base_TrafficEnv_DEFAULT *TrafficEnv - -func (p *Base) GetTrafficEnv() *TrafficEnv { - if !p.IsSetTrafficEnv() { - return Base_TrafficEnv_DEFAULT - } - return p.TrafficEnv -} - -var Base_Extra_DEFAULT map[string]string - -func (p *Base) GetExtra() map[string]string { - if !p.IsSetExtra() { - return Base_Extra_DEFAULT - } - return p.Extra -} - -func (p *Base) SetLogID(val string) { - p.LogID = val -} - -func (p *Base) SetCaller(val string) { - p.Caller = val -} - -func (p *Base) SetAddr(val string) { - p.Addr = val -} - -func (p *Base) SetClient(val string) { - p.Client = val -} - -func (p *Base) SetTrafficEnv(val *TrafficEnv) { - p.TrafficEnv = val -} - -func (p *Base) SetExtra(val map[string]string) { - p.Extra = val -} - -var fieldIDToName_Base = map[int16]string{ - 1: "LogID", - 2: "Caller", - 3: "Addr", - 4: "Client", - 5: "TrafficEnv", - 6: "Extra", -} - -func (p *Base) IsSetTrafficEnv() bool { - return p.TrafficEnv != nil -} - -func (p *Base) IsSetExtra() bool { - return p.Extra != nil -} - -func (p *Base) Read(iprot thrift.TProtocol) (err error) { - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.STRING { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 2: - if fieldTypeId == thrift.STRING { - if err = p.ReadField2(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.STRING { - if err = p.ReadField3(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 4: - if fieldTypeId == thrift.STRING { - if err = p.ReadField4(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 5: - if fieldTypeId == thrift.STRUCT { - if err = p.ReadField5(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 6: - if fieldTypeId == thrift.MAP { - if err = p.ReadField6(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Base[fieldId]), err) -SkipFieldError: - return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *Base) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.LogID = v - } - return nil -} - -func (p *Base) ReadField2(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Caller = v - } - return nil -} - -func (p *Base) ReadField3(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Addr = v - } - return nil -} - -func (p *Base) ReadField4(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.Client = v - } - return nil -} - -func (p *Base) ReadField5(iprot thrift.TProtocol) error { - p.TrafficEnv = NewTrafficEnv() - if err := p.TrafficEnv.Read(iprot); err != nil { - return err - } - return nil -} - -func (p *Base) ReadField6(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.Extra = make(map[string]string, size) - for i := 0; i < size; i++ { - var _key string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } - - var _val string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _val = v - } - - p.Extra[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *Base) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("Base"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField2(oprot); err != nil { - fieldId = 2 - goto WriteFieldError - } - if err = p.writeField3(oprot); err != nil { - fieldId = 3 - goto WriteFieldError - } - if err = p.writeField4(oprot); err != nil { - fieldId = 4 - goto WriteFieldError - } - if err = p.writeField5(oprot); err != nil { - fieldId = 5 - goto WriteFieldError - } - if err = p.writeField6(oprot); err != nil { - fieldId = 6 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *Base) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("LogID", thrift.STRING, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.LogID); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *Base) writeField2(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Caller", thrift.STRING, 2); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Caller); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) -} - -func (p *Base) writeField3(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Addr", thrift.STRING, 3); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Addr); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) -} - -func (p *Base) writeField4(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("Client", thrift.STRING, 4); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.Client); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) -} - -func (p *Base) writeField5(oprot thrift.TProtocol) (err error) { - if p.IsSetTrafficEnv() { - if err = oprot.WriteFieldBegin("TrafficEnv", thrift.STRUCT, 5); err != nil { - goto WriteFieldBeginError - } - if err := p.TrafficEnv.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) -} - -func (p *Base) writeField6(oprot thrift.TProtocol) (err error) { - if p.IsSetExtra() { - if err = oprot.WriteFieldBegin("Extra", thrift.MAP, 6); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.Extra)); err != nil { - return err - } - for k, v := range p.Extra { - - if err := oprot.WriteString(k); err != nil { - return err - } - - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) -} - -func (p *Base) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("Base(%+v)", *p) -} - -type BaseResp struct { - StatusMessage string `thrift:"StatusMessage,1" json:"StatusMessage"` - - StatusCode int32 `thrift:"StatusCode,2" json:"StatusCode"` - - Extra map[string]string `thrift:"Extra,3" json:"Extra,omitempty"` -} - -func NewBaseResp() *BaseResp { - return &BaseResp{ - StatusMessage: "", - StatusCode: 0, - } -} - -func (p *BaseResp) GetStatusMessage() string { - return p.StatusMessage -} - -func (p *BaseResp) GetStatusCode() int32 { - return p.StatusCode -} - -var BaseResp_Extra_DEFAULT map[string]string - -func (p *BaseResp) GetExtra() map[string]string { - if !p.IsSetExtra() { - return BaseResp_Extra_DEFAULT - } - return p.Extra -} - -func (p *BaseResp) SetStatusMessage(val string) { - p.StatusMessage = val -} - -func (p *BaseResp) SetStatusCode(val int32) { - p.StatusCode = val -} - -func (p *BaseResp) SetExtra(val map[string]string) { - p.Extra = val -} - -var fieldIDToName_BaseResp = map[int16]string{ - 1: "StatusMessage", - 2: "StatusCode", - 3: "Extra", -} - -func (p *BaseResp) IsSetExtra() bool { - return p.Extra != nil -} - -func (p *BaseResp) Read(iprot thrift.TProtocol) (err error) { - var fieldTypeId thrift.TType - var fieldId int16 - - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } - - for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() - if err != nil { - goto ReadFieldBeginError - } - if fieldTypeId == thrift.STOP { - break - } - - switch fieldId { - case 1: - if fieldTypeId == thrift.STRING { - if err = p.ReadField1(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 2: - if fieldTypeId == thrift.I32 { - if err = p.ReadField2(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - case 3: - if fieldTypeId == thrift.MAP { - if err = p.ReadField3(iprot); err != nil { - goto ReadFieldError - } - } else { - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - default: - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldError - } - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_BaseResp[fieldId]), err) -SkipFieldError: - return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *BaseResp) ReadField1(iprot thrift.TProtocol) error { - if v, err := iprot.ReadString(); err != nil { - return err - } else { - p.StatusMessage = v - } - return nil -} - -func (p *BaseResp) ReadField2(iprot thrift.TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { - return err - } else { - p.StatusCode = v - } - return nil -} - -func (p *BaseResp) ReadField3(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - p.Extra = make(map[string]string, size) - for i := 0; i < size; i++ { - var _key string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } - - var _val string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _val = v - } - - p.Extra[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { - return err - } - return nil -} - -func (p *BaseResp) Write(oprot thrift.TProtocol) (err error) { - var fieldId int16 - if err = oprot.WriteStructBegin("BaseResp"); err != nil { - goto WriteStructBeginError - } - if p != nil { - if err = p.writeField1(oprot); err != nil { - fieldId = 1 - goto WriteFieldError - } - if err = p.writeField2(oprot); err != nil { - fieldId = 2 - goto WriteFieldError - } - if err = p.writeField3(oprot); err != nil { - fieldId = 3 - goto WriteFieldError - } - - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldError: - return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) -WriteFieldStopError: - return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *BaseResp) writeField1(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("StatusMessage", thrift.STRING, 1); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(p.StatusMessage); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) -} - -func (p *BaseResp) writeField2(oprot thrift.TProtocol) (err error) { - if err = oprot.WriteFieldBegin("StatusCode", thrift.I32, 2); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteI32(p.StatusCode); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) -} - -func (p *BaseResp) writeField3(oprot thrift.TProtocol) (err error) { - if p.IsSetExtra() { - if err = oprot.WriteFieldBegin("Extra", thrift.MAP, 3); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.Extra)); err != nil { - return err - } - for k, v := range p.Extra { - - if err := oprot.WriteString(k); err != nil { - return err - } - - if err := oprot.WriteString(v); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return bthrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) -WriteFieldEndError: - return bthrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) -} - -func (p *BaseResp) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("BaseResp(%+v)", *p) -} diff --git a/pkg/generic/thrift/parse.go b/pkg/generic/thrift/parse.go index 7576a7953b..5d3520ead2 100644 --- a/pkg/generic/thrift/parse.go +++ b/pkg/generic/thrift/parse.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/thriftgo/parser" "github.com/cloudwego/thriftgo/semantic" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" @@ -307,7 +308,7 @@ func parseType(t *parser.Type, tree *parser.Thrift, cache map[string]*descriptor if ty, ok := cache[t.Name]; ok { return ty, nil } - typePkg, typeName := splitType(t.Name) + typePkg, typeName := thrift.SplitType(t.Name) if typePkg != "" { ref, ok := tree.GetReference(typePkg) if !ok { diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index beafd77638..8986abf03c 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -143,8 +143,13 @@ func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out re } } + // generic call + if msg, ok := data.(genericWriter); ok { + return encodeGenericThrift(out, ctx, methodName, msgType, seqID, msg) + } + // fallback to old thrift way (slow) - if err := encodeBasicThrift(out, ctx, methodName, msgType, seqID, data, message.RPCRole()); err == nil || err != errEncodeMismatchMsgType { + if err := encodeBasicThrift(out, ctx, methodName, msgType, seqID, data); err == nil || err != errEncodeMismatchMsgType { return err } @@ -182,8 +187,20 @@ func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.M return nw.MallocAck(mallocLen) } -// encodeBasicThrift encode with the old thrift way (slow) -func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}, rpcRole remote.RPCRole) error { +func encodeGenericThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, msg genericWriter) error { + binaryWriter := thrift.NewBinaryWriter() + binaryWriter.WriteMessageBegin(method, thrift.TMessageType(msgType), seqID) + if _, err := out.Write(binaryWriter.Bytes()); err != nil { + return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) + } + if err := msg.Write(ctx, method, out); err != nil { + return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) + } + return nil +} + +// encodeBasicThrift encode with the old athrift way (slow) +func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}) error { if err := verifyMarshalBasicThriftDataType(data); err != nil { return err } @@ -191,7 +208,7 @@ func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string if err := tProt.WriteMessageBegin(method, athrift.TMessageType(msgType), seqID); err != nil { return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageBegin failed: %s", err.Error())) } - if err := marshalBasicThriftData(ctx, tProt, data, method, rpcRole); err != nil { + if err := marshalBasicThriftData(tProt, data); err != nil { return err } if err := tProt.WriteMessageEnd(); err != nil { @@ -234,7 +251,14 @@ func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in r ri := message.RPCInfo() rpcinfo.Record(ctx, ri, stats.WaitReadStart, nil) - err = c.unmarshalThriftData(ctx, tProt, methodName, data, message.RPCRole(), dataLen) + if msg, ok := data.(genericReader); ok { + err = msg.Read(ctx, methodName, dataLen, in) + if err != nil { + err = remote.NewTransError(remote.ProtocolError, err) + } + } else { + err = c.unmarshalThriftData(tProt, data, dataLen) + } rpcinfo.Record(ctx, ri, stats.WaitReadFinish, err) if err != nil { return err @@ -269,12 +293,12 @@ func (c thriftCodec) Name() string { return serviceinfo.Thrift.String() } -// MessageWriter write to thrift.TProtocol +// MessageWriter write to athrift.TProtocol type MessageWriter interface { Write(oprot athrift.TProtocol) error } -// MessageReader read from thrift.TProtocol +// MessageReader read from athrift.TProtocol type MessageReader interface { Read(oprot athrift.TProtocol) error } @@ -287,18 +311,6 @@ type genericReader interface { // used by pkg/generic Read(ctx context.Context, method string, dataLen int, r io.Reader) error } -// MessageWriterWithMethodWithContext write to thrift.TProtocol -// TODO(marina.sakai): remove it after we use the new genericWriter interface -type MessageWriterWithMethodWithContext interface { - Write(ctx context.Context, method string, oprot athrift.TProtocol) error -} - -// MessageReaderWithMethodWithContext read from thrift.TProtocol with method -// TODO(marina.sakai): remove it after we use the new genericReader interface -type MessageReaderWithMethodWithContext interface { - Read(ctx context.Context, method string, dataLen int, iprot athrift.TProtocol) error -} - // ThriftMsgFastCodec ... // Deprecated: use `github.com/cloudwego/gopkg/protocol/thrift.FastCodec` type ThriftMsgFastCodec = thrift.FastCodec diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index 673141e482..db0ccc3503 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -72,7 +72,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ // fallback to old thrift way (slow) transport := athrift.NewTMemoryBufferLen(marshalThriftBufferSize) tProt := athrift.NewTBinaryProtocol(transport, true, true) - if err := marshalBasicThriftData(ctx, tProt, data, "", -1); err != nil { + if err := marshalBasicThriftData(tProt, data); err != nil { return nil, err } return transport.Bytes(), nil @@ -82,8 +82,6 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ func verifyMarshalBasicThriftDataType(data interface{}) error { switch data.(type) { case MessageWriter: - case MessageWriterWithMethodWithContext: - case genericWriter: default: return errEncodeMismatchMsgType } @@ -92,15 +90,11 @@ func verifyMarshalBasicThriftDataType(data interface{}) error { // marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId) // It uses the old thrift way which is much slower than FastCodec and Frugal -func marshalBasicThriftData(ctx context.Context, tProt athrift.TProtocol, data interface{}, method string, rpcRole remote.RPCRole) error { +func marshalBasicThriftData(tProt athrift.TProtocol, data interface{}) error { var err error switch msg := data.(type) { case MessageWriter: err = msg.Write(tProt) - case MessageWriterWithMethodWithContext: - err = msg.Write(ctx, method, tProt) - case genericWriter: - err = msg.Write(ctx, method, tProt.Transport()) default: return errEncodeMismatchMsgType } @@ -136,7 +130,7 @@ func UnmarshalThriftData(ctx context.Context, codec remote.PayloadCodec, method c = defaultCodec } tProt := NewBinaryProtocol(remote.NewReaderBuffer(buf)) - err := c.unmarshalThriftData(ctx, tProt, method, data, -1, len(buf)) + err := c.unmarshalThriftData(tProt, data, len(buf)) if err == nil { tProt.Recycle() } @@ -177,7 +171,7 @@ func (c thriftCodec) fastUnmarshal(tProt *BinaryProtocol, data interface{}, data // unmarshalThriftData only decodes the data (after methodName, msgType and seqId) // method is only used for generic calls -func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProtocol, method string, data interface{}, rpcRole remote.RPCRole, dataLen int) error { +func (c thriftCodec) unmarshalThriftData(tProt *BinaryProtocol, data interface{}, dataLen int) error { // decode with hyper unmarshal if c.IsSet(FrugalRead) && c.hyperMessageUnmarshalAvailable(data, dataLen) { return c.hyperUnmarshal(tProt, data, dataLen) @@ -203,7 +197,7 @@ func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProto } // fallback to old thrift way (slow) - return decodeBasicThriftData(ctx, tProt, method, rpcRole, dataLen, data) + return decodeBasicThriftData(tProt, data) } func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error { @@ -232,8 +226,6 @@ func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dat func verifyUnmarshalBasicThriftDataType(data interface{}) error { switch data.(type) { case MessageReader: - case MessageReaderWithMethodWithContext: - case genericReader: default: return errDecodeMismatchMsgType } @@ -241,15 +233,11 @@ func verifyUnmarshalBasicThriftDataType(data interface{}) error { } // decodeBasicThriftData decode thrift body the old way (slow) -func decodeBasicThriftData(ctx context.Context, tProt athrift.TProtocol, method string, rpcRole remote.RPCRole, dataLen int, data interface{}) error { +func decodeBasicThriftData(tProt athrift.TProtocol, data interface{}) error { var err error switch t := data.(type) { case MessageReader: err = t.Read(tProt) - case MessageReaderWithMethodWithContext: - err = t.Read(ctx, method, dataLen, tProt) - case genericReader: - err = t.Read(ctx, method, dataLen, tProt.Transport()) default: return errDecodeMismatchMsgType } diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index a75b88e3c8..9f23735cf6 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -44,13 +44,13 @@ var ( func TestMarshalBasicThriftData(t *testing.T) { t.Run("invalid-data", func(t *testing.T) { - err := marshalBasicThriftData(context.Background(), nil, 0, "", -1) + err := marshalBasicThriftData(nil, 0) test.Assert(t, err == errEncodeMismatchMsgType, err) }) t.Run("valid-data", func(t *testing.T) { transport := athrift.NewTMemoryBufferLen(1024) tProt := athrift.NewTBinaryProtocol(transport, true, true) - err := marshalBasicThriftData(context.Background(), tProt, mocks.ToApacheCodec(mockReq), "", -1) + err := marshalBasicThriftData(tProt, mocks.ToApacheCodec(mockReq)) test.Assert(t, err == nil, err) result := transport.Bytes() test.Assert(t, reflect.DeepEqual(result, mockReqThrift), result) @@ -80,19 +80,19 @@ func Test_decodeBasicThriftData(t *testing.T) { t.Run("empty-input", func(t *testing.T) { req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{})) - err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, mocks.ToApacheCodec(req)) + err := decodeBasicThriftData(tProt, mocks.ToApacheCodec(req)) test.Assert(t, err != nil, err) }) t.Run("invalid-input", func(t *testing.T) { req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{0xff})) - err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, mocks.ToApacheCodec(req)) + err := decodeBasicThriftData(tProt, mocks.ToApacheCodec(req)) test.Assert(t, err != nil, err) }) t.Run("normal-input", func(t *testing.T) { req := &mocks.MockReq{} tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) - err := decodeBasicThriftData(context.Background(), tProt, "mock", -1, 0, mocks.ToApacheCodec(req)) + err := decodeBasicThriftData(tProt, mocks.ToApacheCodec(req)) checkDecodeResult(t, err, req) }) } @@ -130,7 +130,7 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) + err := codec.unmarshalThriftData(tProt, req, 0) checkDecodeResult(t, err, &mocks.MockReq{ Msg: req.Msg, StrList: req.StrList, @@ -154,7 +154,7 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(faultMockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) + err := codec.unmarshalThriftData(tProt, req, 0) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), "caught in FastCodec using SkipDecoder Buffer")) }) @@ -177,18 +177,6 @@ func TestUnmarshalThriftException(t *testing.T) { test.Assert(t, transErr.Error() == errMessage, transErr) } -func Test_verifyMarshalBasicThriftDataType(t *testing.T) { - err := verifyMarshalBasicThriftDataType(&mockWithContext{}) - test.Assert(t, err == nil, err) - // data that is not part of basic thrift: in thrift_frugal_amd64_test.go: Test_verifyMarshalThriftDataFrugal -} - -func Test_verifyUnmarshalBasicThriftDataType(t *testing.T) { - err := verifyUnmarshalBasicThriftDataType(&mockWithContext{}) - test.Assert(t, err == nil, err) - // data that is not part of basic thrift: in thrift_frugal_amd64_test.go: Test_verifyUnmarshalThriftDataFrugal -} - func Test_getSkippedStructBuffer(t *testing.T) { // string length is 6 but only got "hello" faultThrift := []byte{ diff --git a/pkg/remote/codec/thrift/thrift_frugal_test.go b/pkg/remote/codec/thrift/thrift_frugal_test.go index 9c4e2dab60..3985268171 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_test.go @@ -231,7 +231,7 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) + err := codec.unmarshalThriftData(tProt, req, 0) checkDecodeResult(t, err, &mocks.MockReq{ Msg: req.Msg, StrList: req.StrList, @@ -255,7 +255,7 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { tProt := NewBinaryProtocol(remote.NewReaderBuffer(faultMockReqThrift)) defer tProt.Recycle() // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(context.Background(), tProt, "mock", req, -1, 0) + err := codec.unmarshalThriftData(tProt, req, 0) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), "caught in Frugal using SkipDecoder Buffer")) }) diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index dd84489d99..6e7a18b5e1 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -19,6 +19,7 @@ package thrift import ( "context" "errors" + "io" "testing" "github.com/cloudwego/gopkg/protocol/thrift" @@ -27,7 +28,6 @@ import ( "github.com/cloudwego/kitex/internal/mocks" mt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -63,18 +63,18 @@ func init() { } type mockWithContext struct { - ReadFunc func(ctx context.Context, method string, dataLen int, oprot athrift.TProtocol) error - WriteFunc func(ctx context.Context, method string, oprot athrift.TProtocol) error + ReadFunc func(ctx context.Context, method string, dataLen int, oprot io.Reader) error + WriteFunc func(ctx context.Context, method string, oprot io.Writer) error } -func (m *mockWithContext) Read(ctx context.Context, method string, dataLen int, oprot athrift.TProtocol) error { +func (m *mockWithContext) Read(ctx context.Context, method string, dataLen int, oprot io.Reader) error { if m.ReadFunc != nil { return m.ReadFunc(ctx, method, dataLen, oprot) } return nil } -func (m *mockWithContext) Write(ctx context.Context, method string, oprot athrift.TProtocol) error { +func (m *mockWithContext) Write(ctx context.Context, method string, oprot io.Writer) error { if m.WriteFunc != nil { return m.WriteFunc(ctx, method, oprot) } @@ -86,7 +86,7 @@ func TestWithContext(t *testing.T) { t.Run(tb.Name, func(t *testing.T) { ctx := context.Background() - req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, oprot athrift.TProtocol) error { + req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, oprot io.Writer) error { return nil }} ink := rpcinfo.NewInvocation("", "mock") @@ -99,7 +99,7 @@ func TestWithContext(t *testing.T) { buf.Flush() { - resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, dataLen int, oprot athrift.TProtocol) error { + resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, dataLen int, oprot io.Reader) error { return nil }} ink := rpcinfo.NewInvocation("", "mock") diff --git a/pkg/serviceinfo/serviceinfo.go b/pkg/serviceinfo/serviceinfo.go index c93126e48c..592c492d12 100644 --- a/pkg/serviceinfo/serviceinfo.go +++ b/pkg/serviceinfo/serviceinfo.go @@ -35,10 +35,6 @@ const ( GenericService = "$GenericService" // private as "$" // GenericMethod name GenericMethod = "$GenericCall" - // CombineService name - CombineService = "CombineService" - // CombineService_ is used when idl has a service named "CombineService" - CombineService_ = "CombineService_" // GenericClientStreamingMethod name GenericClientStreamingMethod = "$GenericClientStreamingMethod" // GenericServerStreamingMethod name From 34793dfd3ecdb38a2f0cd825417f76cd2b7e8c0b Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Thu, 1 Aug 2024 12:57:21 +0800 Subject: [PATCH 23/70] refactor: rm apache thrift in pkg/generic & netpollmux (#1470) --- go.mod | 2 +- go.sum | 4 +- pkg/generic/binary_test/generic_init.go | 2 +- pkg/generic/binary_test/generic_test.go | 2 +- pkg/generic/binarythrift_codec.go | 22 +++-- pkg/generic/reflect_test/reflect_test.go | 23 +++--- pkg/remote/codec/thrift/thrift.go | 4 +- pkg/remote/codec/thrift/thrift_frugal.go | 2 +- pkg/remote/trans/netpollmux/control_frame.go | 86 +++++++------------- pkg/utils/thrift.go | 4 +- 10 files changed, 59 insertions(+), 92 deletions(-) diff --git a/go.mod b/go.mod index d8f1283377..1afb570b3e 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cloudwego/dynamicgo v0.2.9 github.com/cloudwego/fastpb v0.0.4 github.com/cloudwego/frugal v0.1.15 - github.com/cloudwego/gopkg v0.0.0-20240725095015-34d5327eebca + github.com/cloudwego/gopkg v0.1.0 github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index ab61497ef7..8fe9f91469 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4m github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= github.com/cloudwego/frugal v0.1.15 h1:LC55UJKhQPMFVjDPbE+LJcF7etZjSx6uokG1tk0wPK0= github.com/cloudwego/frugal v0.1.15/go.mod h1:26kU1r18vA8vRg12c66XPDlfv1GQHDbE1RpusipXfcI= -github.com/cloudwego/gopkg v0.0.0-20240725095015-34d5327eebca h1:xe6SuqnTHcqQlID29RG8gflr5pLKpffDJUusm7rZUPI= -github.com/cloudwego/gopkg v0.0.0-20240725095015-34d5327eebca/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= +github.com/cloudwego/gopkg v0.1.0 h1:N7CE4FS5crkZg3w7shw3UR3TG4+uofXXabGuBNmSrlE= +github.com/cloudwego/gopkg v0.1.0/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/iasm v0.0.9/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= diff --git a/pkg/generic/binary_test/generic_init.go b/pkg/generic/binary_test/generic_init.go index 05aece2460..e0ec11e595 100644 --- a/pkg/generic/binary_test/generic_init.go +++ b/pkg/generic/binary_test/generic_init.go @@ -191,7 +191,7 @@ func (m *MockImpl) ExceptionTest(ctx context.Context, req *kt.MockReq) (r string func genBinaryResp(method string) []byte { // no idea for respMsg part, it's not binary protocol. // DO NOT TOUCH IT or you may need to change the tests as well - n := thrift.Binary.MessageBeginLength(method, 0, 0) + len(respMsg) + n := thrift.Binary.MessageBeginLength(method) + len(respMsg) b := make([]byte, 0, n) b = thrift.Binary.AppendMessageBegin(b, method, 0, 100) b = append(b, respMsg...) diff --git a/pkg/generic/binary_test/generic_test.go b/pkg/generic/binary_test/generic_test.go index edfffd6fd5..da30c62a25 100644 --- a/pkg/generic/binary_test/generic_test.go +++ b/pkg/generic/binary_test/generic_test.go @@ -185,7 +185,7 @@ func initMockServer(handler kt.Mock) server.Server { func genBinaryReqBuf(method string) []byte { // no idea for reqMsg part, it's not binary protocol. // DO NOT TOUCH IT or you may need to change the tests as well - n := thrift.Binary.MessageBeginLength(method, 0, 0) + len(reqMsg) + n := thrift.Binary.MessageBeginLength(method) + len(reqMsg) b := make([]byte, 0, n) b = thrift.Binary.AppendMessageBegin(b, method, 0, 100) b = append(b, reqMsg...) diff --git a/pkg/generic/binarythrift_codec.go b/pkg/generic/binarythrift_codec.go index 99090248a5..c6aedd5862 100644 --- a/pkg/generic/binarythrift_codec.go +++ b/pkg/generic/binarythrift_codec.go @@ -21,12 +21,11 @@ import ( "encoding/binary" "fmt" - athrift "github.com/apache/thrift/lib/go/thrift" + gthrift "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" - "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) @@ -56,17 +55,16 @@ func (c *binaryThriftCodec) Marshal(ctx context.Context, msg remote.Message, out transBinary := gResult.Success // handle biz error if transBinary == nil { - tProt := thrift.NewBinaryProtocol(out) - if err := tProt.WriteMessageBegin(msg.RPCInfo().Invocation().MethodName(), athrift.TMessageType(msg.MessageType()), msg.RPCInfo().Invocation().SeqID()); err != nil { - return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("binary thrift generic marshal, WriteMessageBegin failed: %s", err.Error())) + sz := gthrift.Binary.MessageBeginLength(msg.RPCInfo().Invocation().MethodName()) + sz += gthrift.Binary.FieldStopLength() + b, err := out.Malloc(sz) + if err != nil { + return perrors.NewProtocolError(fmt.Errorf("binary thrift generic marshal, remote.ByteBuffer Malloc err: %w", err)) } - if err := tProt.WriteFieldStop(); err != nil { - return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("binary thrift generic marshal, WriteFieldStop failed: %s", err.Error())) - } - if err := tProt.WriteMessageEnd(); err != nil { - return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("binary thrift generic marshal, WriteMessageEnd failed: %s", err.Error())) - } - tProt.Recycle() + b = gthrift.Binary.AppendMessageBegin(b[:0], + msg.RPCInfo().Invocation().MethodName(), gthrift.TMessageType(msg.MessageType()), msg.RPCInfo().Invocation().SeqID()) + b = gthrift.Binary.AppendFieldStop(b) + _ = b return nil } else if transBuff, ok = transBinary.(binaryReqType); !ok { return perrors.NewProtocolErrorWithMsg("invalid marshal result in rawThriftBinaryCodec: must be []byte") diff --git a/pkg/generic/reflect_test/reflect_test.go b/pkg/generic/reflect_test/reflect_test.go index 36ec671475..714cebf2c4 100644 --- a/pkg/generic/reflect_test/reflect_test.go +++ b/pkg/generic/reflect_test/reflect_test.go @@ -34,7 +34,6 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/klog" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/genericserver" @@ -205,7 +204,7 @@ func initClient(addr string) genericclient.Client { // Except msg, require_field and logid, which are reset everytime func makeExampleRespBinary(msg, require_field, logid string) ([]byte, error) { dom := &dg.PathNode{ - Node: dg.NewTypedNode(thrift.STRUCT, 0, 0), + Node: dg.NewTypedNode(dt.STRUCT, 0, 0), Next: []dg.PathNode{ { Path: dg.NewPathFieldId(1), @@ -217,7 +216,7 @@ func makeExampleRespBinary(msg, require_field, logid string) ([]byte, error) { }, { Path: dg.NewPathFieldId(255), - Node: dg.NewTypedNode(thrift.STRUCT, 0, 0), + Node: dg.NewTypedNode(dt.STRUCT, 0, 0), Next: []dg.PathNode{ { Path: dg.NewPathFieldId(1), @@ -236,7 +235,7 @@ func makeExampleReqBinary(B bool, A, logid string) ([]byte, error) { list := make([]dg.PathNode, SampleListSize+1) list[0] = dg.PathNode{ Path: dg.NewPathIndex(0), - Node: dg.NewTypedNode(thrift.STRUCT, 0, 0), + Node: dg.NewTypedNode(dt.STRUCT, 0, 0), Next: []dg.PathNode{ { Path: dg.NewPathFieldId(1), @@ -247,7 +246,7 @@ func makeExampleReqBinary(B bool, A, logid string) ([]byte, error) { for i := 1; i < len(list); i++ { list[i] = dg.PathNode{ Path: dg.NewPathIndex(i), - Node: dg.NewTypedNode(thrift.STRUCT, 0, 0), + Node: dg.NewTypedNode(dt.STRUCT, 0, 0), Next: []dg.PathNode{ { Path: dg.NewPathFieldId(1), @@ -259,7 +258,7 @@ func makeExampleReqBinary(B bool, A, logid string) ([]byte, error) { m := make([]dg.PathNode, SampleListSize+1) m[0] = dg.PathNode{ Path: dg.NewPathStrKey("a"), - Node: dg.NewTypedNode(thrift.STRUCT, 0, 0), + Node: dg.NewTypedNode(dt.STRUCT, 0, 0), Next: []dg.PathNode{ { Path: dg.NewPathFieldId(1), @@ -270,7 +269,7 @@ func makeExampleReqBinary(B bool, A, logid string) ([]byte, error) { for i := 1; i < len(list); i++ { list[i] = dg.PathNode{ Path: dg.NewPathStrKey(strconv.Itoa(i)), - Node: dg.NewTypedNode(thrift.STRUCT, 0, 0), + Node: dg.NewTypedNode(dt.STRUCT, 0, 0), Next: []dg.PathNode{ { Path: dg.NewPathFieldId(1), @@ -281,7 +280,7 @@ func makeExampleReqBinary(B bool, A, logid string) ([]byte, error) { } dom := dg.PathNode{ - Node: dg.NewTypedNode(thrift.STRUCT, 0, 0), + Node: dg.NewTypedNode(dt.STRUCT, 0, 0), Next: []dg.PathNode{ { Path: dg.NewPathFieldId(1), @@ -293,17 +292,17 @@ func makeExampleReqBinary(B bool, A, logid string) ([]byte, error) { }, { Path: dg.NewPathFieldId(3), - Node: dg.NewTypedNode(thrift.LIST, thrift.STRUCT, 0), + Node: dg.NewTypedNode(dt.LIST, dt.STRUCT, 0), Next: list, }, { Path: dg.NewPathFieldId(4), - Node: dg.NewTypedNode(thrift.MAP, thrift.STRUCT, thrift.STRING), + Node: dg.NewTypedNode(dt.MAP, dt.STRUCT, dt.STRING), Next: m, }, { Path: dg.NewPathFieldId(6), - Node: dg.NewTypedNode(thrift.LIST, thrift.I64, 0), + Node: dg.NewTypedNode(dt.LIST, dt.I64, 0), Next: []dg.PathNode{ { Path: dg.NewPathIndex(0), @@ -325,7 +324,7 @@ func makeExampleReqBinary(B bool, A, logid string) ([]byte, error) { }, { Path: dg.NewPathFieldId(255), - Node: dg.NewTypedNode(thrift.STRUCT, 0, 0), + Node: dg.NewTypedNode(dt.STRUCT, 0, 0), Next: []dg.PathNode{ { Path: dg.NewPathFieldId(1), diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index 8986abf03c..829767cd3b 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -170,7 +170,7 @@ func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out re func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, msg thrift.FastCodec) error { nw, _ := out.(remote.NocopyWrite) // nocopy write is a special implementation of linked buffer, only bytebuffer implement NocopyWrite do FastWrite - msgBeginLen := thrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) + msgBeginLen := thrift.Binary.MessageBeginLength(methodName) buf, err := out.Malloc(msgBeginLen + msg.BLength()) if err != nil { return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error())) @@ -242,7 +242,7 @@ func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in r // decode thrift data data := message.Data() - msgBeginLen := thrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) + msgBeginLen := thrift.Binary.MessageBeginLength(methodName) dataLen := message.PayloadLen() - msgBeginLen // For Buffer Protocol, dataLen would be negative. Set it to zero so as not to confuse if dataLen < 0 { diff --git a/pkg/remote/codec/thrift/thrift_frugal.go b/pkg/remote/codec/thrift/thrift_frugal.go index 6cda88f700..13be190133 100644 --- a/pkg/remote/codec/thrift/thrift_frugal.go +++ b/pkg/remote/codec/thrift/thrift_frugal.go @@ -62,7 +62,7 @@ func (c thriftCodec) hyperMarshal(out remote.ByteBuffer, methodName string, msgT seqID int32, data interface{}, ) error { // calculate and malloc message buffer - msgBeginLen := thrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) + msgBeginLen := thrift.Binary.MessageBeginLength(methodName) objectLen := frugal.EncodedSize(data) buf, err := out.Malloc(msgBeginLen + objectLen) if err != nil { diff --git a/pkg/remote/trans/netpollmux/control_frame.go b/pkg/remote/trans/netpollmux/control_frame.go index a913a656c6..942adb0835 100644 --- a/pkg/remote/trans/netpollmux/control_frame.go +++ b/pkg/remote/trans/netpollmux/control_frame.go @@ -26,7 +26,6 @@ import ( "fmt" "github.com/cloudwego/gopkg/protocol/thrift" - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) type ControlFrame struct{} @@ -35,74 +34,45 @@ func NewControlFrame() *ControlFrame { return &ControlFrame{} } -var fieldIDToName_ControlFrame = map[int16]string{} +func (p *ControlFrame) BLength() int { + return 1 +} -func (p *ControlFrame) Read(iprot athrift.TProtocol) (err error) { - var fieldTypeId athrift.TType - var fieldId int16 +func (p *ControlFrame) FastWrite(b []byte) int { return p.FastWriteNocopy(b, nil) } - if _, err = iprot.ReadStructBegin(); err != nil { - goto ReadStructBeginError - } +func (p *ControlFrame) FastWriteNocopy(b []byte, w thrift.NocopyWriter) int { + b[0] = 0 + return 1 +} +func (p *ControlFrame) FastRead(b []byte) (off int, err error) { + var ftyp thrift.TType + var fid int16 + var l int + x := thrift.BinaryProtocol{} for { - _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + ftyp, fid, l, err = x.ReadFieldBegin(b[off:]) + off += l if err != nil { goto ReadFieldBeginError } - if fieldTypeId == athrift.STOP { + if ftyp == thrift.STOP { break } - if err = iprot.Skip(fieldTypeId); err != nil { - goto SkipFieldTypeError - } - - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError + switch uint32(fid)<<8 | uint32(ftyp) { + default: + l, err = x.Skip(b[off:], ftyp) + off += l + if err != nil { + goto SkipFieldError + } } } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -SkipFieldTypeError: - return thrift.PrependError(fmt.Sprintf("%T skip field type %d error", p, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return off, thrift.PrependError(fmt.Sprintf("%T read field begin error: ", p), err) +SkipFieldError: + return off, thrift.PrependError(fmt.Sprintf("%T skip field %d type %d error: ", p, fid, ftyp), err) } -func (p *ControlFrame) Write(oprot athrift.TProtocol) (err error) { - if err = oprot.WriteStructBegin("ControlFrame"); err != nil { - goto WriteStructBeginError - } - if p != nil { - } - if err = oprot.WriteFieldStop(); err != nil { - goto WriteFieldStopError - } - if err = oprot.WriteStructEnd(); err != nil { - goto WriteStructEndError - } - return nil -WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) -WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) -WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) -} - -func (p *ControlFrame) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("ControlFrame(%+v)", *p) -} +var _ thrift.FastCodec = &ControlFrame{} diff --git a/pkg/utils/thrift.go b/pkg/utils/thrift.go index 96a2a645e1..e8256337e3 100644 --- a/pkg/utils/thrift.go +++ b/pkg/utils/thrift.go @@ -76,7 +76,7 @@ func (t *ThriftMessageCodec) Decode(b []byte, msg athrift.TStruct) (method strin return } if msgType == athrift.EXCEPTION { - b = b[thrift.Binary.MessageBeginLength(method, 0, 0):] // for reusing fast read + b = b[thrift.Binary.MessageBeginLength(method):] // for reusing fast read ex := thrift.NewApplicationException(athrift.UNKNOWN_APPLICATION_EXCEPTION, "") if _, err = ex.FastRead(b); err != nil { return @@ -122,7 +122,7 @@ func (t *ThriftMessageCodec) Deserialize(msg athrift.TStruct, b []byte) (err err // MarshalError convert go error to thrift exception, and encode exception over buffered binary transport. func MarshalError(method string, err error) []byte { ex := thrift.NewApplicationException(athrift.INTERNAL_ERROR, err.Error()) - n := thrift.Binary.MessageBeginLength(method, 0, 0) + n := thrift.Binary.MessageBeginLength(method) n += ex.BLength() b := make([]byte, n) // Write message header From 30e34deb7a51d68f75a7c36129dbae7d10d76a68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ihc=E7=AB=A5=E9=9E=8B=40=E6=8F=90=E4=B8=8D=E8=B5=B7?= =?UTF-8?q?=E5=8A=B2?= Date: Thu, 1 Aug 2024 16:35:52 +0800 Subject: [PATCH 24/70] fix: allow HEADERS frame with empty header block fragment (#1466) --- pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go b/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go index 0fa6b4ce24..603bd81380 100644 --- a/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go +++ b/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go @@ -170,7 +170,8 @@ func parseHeadersFrame(_ *frameCache, fh http2.FrameHeader, p []byte) (_ http2.F return nil, err } } - if len(p)-int(padLength) <= 0 { + // note: len(p)-int(padLength) == 0 is valid + if len(p)-int(padLength) < 0 { return nil, http2.StreamError{StreamID: fh.StreamID, Code: http2.ErrCodeProtocol} } hf.headerFragBuf = p[:len(p)-int(padLength)] From 92d07aee2aaacd002cd3ae07cc82e8b4695f3b33 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 2 Aug 2024 20:05:16 +0800 Subject: [PATCH 25/70] refactor: rm apache thrift in internal/mocks (#1474) also: fix(server): invoker return err if no apache codec fix(server): listening on loopback addr --- client/mocks_test.go | 28 +-- internal/mocks/serviceinfo.go | 32 +-- internal/mocks/thrift_ttransport.go | 378 ---------------------------- internal/test/port.go | 2 +- pkg/remote/remotesvr/server_test.go | 2 +- server/invoke.go | 48 +++- server/invoke_test.go | 9 +- server/option_advanced_test.go | 8 +- server/option_test.go | 40 +-- server/server_test.go | 50 ++-- 10 files changed, 120 insertions(+), 477 deletions(-) delete mode 100644 internal/mocks/thrift_ttransport.go diff --git a/client/mocks_test.go b/client/mocks_test.go index 4064c70703..616047547e 100644 --- a/client/mocks_test.go +++ b/client/mocks_test.go @@ -16,28 +16,6 @@ package client -import ( - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -// MockTStruct implements the thrift.TStruct interface. -type MockTStruct struct { - WriteFunc func(p thrift.TProtocol) (e error) - ReadFunc func(p thrift.TProtocol) (e error) -} - -// Write implements the thrift.TStruct interface. -func (m MockTStruct) Write(p thrift.TProtocol) (e error) { - if m.WriteFunc != nil { - return m.WriteFunc(p) - } - return -} - -// Read implements the thrift.TStruct interface. -func (m MockTStruct) Read(p thrift.TProtocol) (e error) { - if m.ReadFunc != nil { - return m.ReadFunc(p) - } - return -} +// MockTStruct was implemented the thrift.TStruct interface. +// But actually Read/Write are not in use, so removed... only empty struct left for testing +type MockTStruct struct{} diff --git a/internal/mocks/serviceinfo.go b/internal/mocks/serviceinfo.go index 7d6f0020b5..a2f82c185c 100644 --- a/internal/mocks/serviceinfo.go +++ b/internal/mocks/serviceinfo.go @@ -21,7 +21,7 @@ import ( "errors" "fmt" - "github.com/apache/thrift/lib/go/thrift" + "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) @@ -195,13 +195,9 @@ type myServiceMockArgs struct { Req *MyRequest `thrift:"req,1" json:"req"` } -func (p *myServiceMockArgs) Read(iprot thrift.TProtocol) error { - return nil -} - -func (p *myServiceMockArgs) Write(oprot thrift.TProtocol) error { - return nil -} +func (p *myServiceMockArgs) BLength() int { return 1 } +func (p *myServiceMockArgs) FastWriteNocopy(buf []byte, bw thrift.NocopyWriter) int { return 1 } +func (p *myServiceMockArgs) FastRead(buf []byte) (int, error) { return 1, nil } // MyRequest . type MyRequest struct { @@ -212,13 +208,9 @@ type myServiceMockResult struct { Success *MyResponse `thrift:"success,0" json:"success,omitempty"` } -func (p *myServiceMockResult) Read(iprot thrift.TProtocol) error { - return nil -} - -func (p *myServiceMockResult) Write(oprot thrift.TProtocol) error { - return nil -} +func (p *myServiceMockResult) BLength() int { return 1 } +func (p *myServiceMockResult) FastWriteNocopy(buf []byte, bw thrift.NocopyWriter) int { return 1 } +func (p *myServiceMockResult) FastRead(buf []byte) (int, error) { return 1, nil } // MyResponse . type MyResponse struct { @@ -230,13 +222,11 @@ type myServiceMockExceptionResult struct { MyException *MyException `thrift:"stException,1" json:"stException,omitempty"` } -func (p *myServiceMockExceptionResult) Read(iprot thrift.TProtocol) error { - return nil -} - -func (p *myServiceMockExceptionResult) Write(oprot thrift.TProtocol) error { - return nil +func (p *myServiceMockExceptionResult) BLength() int { return 1 } +func (p *myServiceMockExceptionResult) FastWriteNocopy(buf []byte, bw thrift.NocopyWriter) int { + return 1 } +func (p *myServiceMockExceptionResult) FastRead(buf []byte) (int, error) { return 1, nil } // MyException . type MyException struct { diff --git a/internal/mocks/thrift_ttransport.go b/internal/mocks/thrift_ttransport.go deleted file mode 100644 index 58920453f1..0000000000 --- a/internal/mocks/thrift_ttransport.go +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Copyright 2021 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mocks - -import ( - "context" - - "github.com/apache/thrift/lib/go/thrift" -) - -type MockThriftTTransport struct { - WriteMessageBeginFunc func(name string, typeID thrift.TMessageType, seqID int32) error - WriteMessageEndFunc func() error - WriteStructBeginFunc func(name string) error - WriteStructEndFunc func() error - WriteFieldBeginFunc func(name string, typeID thrift.TType, id int16) error - WriteFieldEndFunc func() error - WriteFieldStopFunc func() error - WriteMapBeginFunc func(keyType, valueType thrift.TType, size int) error - WriteMapEndFunc func() error - WriteListBeginFunc func(elemType thrift.TType, size int) error - WriteListEndFunc func() error - WriteSetBeginFunc func(elemType thrift.TType, size int) error - WriteSetEndFunc func() error - WriteBoolFunc func(value bool) error - WriteByteFunc func(value int8) error - WriteI16Func func(value int16) error - WriteI32Func func(value int32) error - WriteI64Func func(value int64) error - WriteDoubleFunc func(value float64) error - WriteStringFunc func(value string) error - WriteBinaryFunc func(value []byte) error - ReadMessageBeginFunc func() (name string, typeID thrift.TMessageType, seqID int32, err error) - ReadMessageEndFunc func() error - ReadStructBeginFunc func() (name string, err error) - ReadStructEndFunc func() error - ReadFieldBeginFunc func() (name string, typeID thrift.TType, id int16, err error) - ReadFieldEndFunc func() error - ReadMapBeginFunc func() (keyType, valueType thrift.TType, size int, err error) - ReadMapEndFunc func() error - ReadListBeginFunc func() (elemType thrift.TType, size int, err error) - ReadListEndFunc func() error - ReadSetBeginFunc func() (elemType thrift.TType, size int, err error) - ReadSetEndFunc func() error - ReadBoolFunc func() (value bool, err error) - ReadByteFunc func() (value int8, err error) - ReadI16Func func() (value int16, err error) - ReadI32Func func() (value int32, err error) - ReadI64Func func() (value int64, err error) - ReadDoubleFunc func() (value float64, err error) - ReadStringFunc func() (value string, err error) - ReadBinaryFunc func() (value []byte, err error) - SkipFunc func(fieldType thrift.TType) (err error) - FlushFunc func(ctx context.Context) (err error) - TransportFunc func() thrift.TTransport -} - -func (m *MockThriftTTransport) WriteMessageBegin(name string, typeID thrift.TMessageType, seqID int32) error { - if m.WriteMessageBeginFunc != nil { - return m.WriteMessageBeginFunc(name, typeID, seqID) - } - return nil -} - -func (m *MockThriftTTransport) WriteMessageEnd() error { - if m.WriteMessageEndFunc != nil { - return m.WriteMessageEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteStructBegin(name string) error { - if m.WriteStructBeginFunc != nil { - return m.WriteStructBeginFunc(name) - } - return nil -} - -func (m *MockThriftTTransport) WriteStructEnd() error { - if m.WriteStructEndFunc != nil { - return m.WriteStructEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteFieldBegin(name string, typeID thrift.TType, id int16) error { - if m.WriteFieldBeginFunc != nil { - return m.WriteFieldBeginFunc(name, typeID, id) - } - return nil -} - -func (m *MockThriftTTransport) WriteFieldEnd() error { - if m.WriteFieldEndFunc != nil { - return m.WriteFieldEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteFieldStop() error { - if m.WriteFieldStopFunc != nil { - return m.WriteFieldStopFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteMapBegin(keyType, valueType thrift.TType, size int) error { - if m.WriteMapBeginFunc != nil { - return m.WriteMapBeginFunc(keyType, valueType, size) - } - return nil -} - -func (m *MockThriftTTransport) WriteMapEnd() error { - if m.WriteMapEndFunc != nil { - return m.WriteMapEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteListBegin(elemType thrift.TType, size int) error { - if m.WriteListBeginFunc != nil { - return m.WriteListBeginFunc(elemType, size) - } - return nil -} - -func (m *MockThriftTTransport) WriteListEnd() error { - if m.WriteListEndFunc != nil { - return m.WriteListEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteSetBegin(elemType thrift.TType, size int) error { - if m.WriteSetBeginFunc != nil { - return m.WriteSetBeginFunc(elemType, size) - } - return nil -} - -func (m *MockThriftTTransport) WriteSetEnd() error { - if m.WriteSetEndFunc != nil { - return m.WriteSetEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) WriteBool(value bool) error { - if m.WriteBoolFunc != nil { - return m.WriteBoolFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteByte(value int8) error { - if m.WriteByteFunc != nil { - return m.WriteByteFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteI16(value int16) error { - if m.WriteI16Func != nil { - return m.WriteI16Func(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteI32(value int32) error { - if m.WriteI32Func != nil { - return m.WriteI32Func(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteI64(value int64) error { - if m.WriteI64Func != nil { - return m.WriteI64Func(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteDouble(value float64) error { - if m.WriteDoubleFunc != nil { - return m.WriteDoubleFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteString(value string) error { - if m.WriteStringFunc != nil { - return m.WriteStringFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) WriteBinary(value []byte) error { - if m.WriteBinaryFunc != nil { - return m.WriteBinaryFunc(value) - } - return nil -} - -func (m *MockThriftTTransport) ReadMessageBegin() (name string, typeID thrift.TMessageType, seqID int32, err error) { - if m.ReadMessageBeginFunc != nil { - return m.ReadMessageBeginFunc() - } - return "", thrift.INVALID_TMESSAGE_TYPE, 0, nil -} - -func (m *MockThriftTTransport) ReadMessageEnd() error { - if m.ReadMessageEndFunc != nil { - return m.ReadMessageEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadStructBegin() (name string, err error) { - if m.ReadStructBeginFunc != nil { - return m.ReadStructBeginFunc() - } - return "", nil -} - -func (m *MockThriftTTransport) ReadStructEnd() error { - if m.ReadStructEndFunc != nil { - return m.ReadStructEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadFieldBegin() (name string, typeID thrift.TType, id int16, err error) { - if m.ReadFieldBeginFunc != nil { - return m.ReadFieldBeginFunc() - } - return "", thrift.STOP, 0, nil -} - -func (m *MockThriftTTransport) ReadFieldEnd() error { - if m.ReadFieldEndFunc != nil { - return m.ReadFieldEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadMapBegin() (keyType, valueType thrift.TType, size int, err error) { - if m.ReadMapBeginFunc != nil { - return m.ReadMapBeginFunc() - } - return thrift.STOP, thrift.STOP, 0, nil -} - -func (m *MockThriftTTransport) ReadMapEnd() error { - if m.ReadMapEndFunc != nil { - return m.ReadMapEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadListBegin() (elemType thrift.TType, size int, err error) { - if m.ReadListBeginFunc != nil { - return m.ReadListBeginFunc() - } - return thrift.STOP, 0, nil -} - -func (m *MockThriftTTransport) ReadListEnd() error { - if m.ReadListEndFunc != nil { - return m.ReadListEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadSetBegin() (elemType thrift.TType, size int, err error) { - if m.ReadSetBeginFunc != nil { - return m.ReadSetBeginFunc() - } - return thrift.STOP, 0, nil -} - -func (m *MockThriftTTransport) ReadSetEnd() error { - if m.ReadSetEndFunc != nil { - return m.ReadSetEndFunc() - } - return nil -} - -func (m *MockThriftTTransport) ReadBool() (value bool, err error) { - if m.ReadBoolFunc != nil { - return m.ReadBoolFunc() - } - return false, nil -} - -func (m *MockThriftTTransport) ReadByte() (value int8, err error) { - if m.ReadByteFunc != nil { - return m.ReadByteFunc() - } - return 0, nil -} - -func (m *MockThriftTTransport) ReadI16() (value int16, err error) { - if m.ReadI16Func != nil { - return m.ReadI16Func() - } - return 0, nil -} - -func (m *MockThriftTTransport) ReadI32() (value int32, err error) { - if m.ReadI32Func != nil { - return m.ReadI32Func() - } - return 0, nil -} - -func (m *MockThriftTTransport) ReadI64() (value int64, err error) { - if m.ReadI64Func != nil { - return m.ReadI64Func() - } - return 0, nil -} - -func (m *MockThriftTTransport) ReadDouble() (value float64, err error) { - if m.ReadDoubleFunc != nil { - return m.ReadDoubleFunc() - } - return 0.0, nil -} - -func (m *MockThriftTTransport) ReadString() (value string, err error) { - if m.ReadStringFunc != nil { - return m.ReadStringFunc() - } - return "", nil -} - -func (m *MockThriftTTransport) ReadBinary() (value []byte, err error) { - if m.ReadBinaryFunc != nil { - return m.ReadBinaryFunc() - } - return nil, nil -} - -func (m *MockThriftTTransport) Skip(fieldType thrift.TType) (err error) { - if m.SkipFunc != nil { - return m.SkipFunc(fieldType) - } - return nil -} - -func (m *MockThriftTTransport) Flush(ctx context.Context) (err error) { - if m.FlushFunc != nil { - return m.FlushFunc(ctx) - } - return nil -} - -func (m *MockThriftTTransport) Transport() thrift.TTransport { - if m.TransportFunc != nil { - return m.TransportFunc() - } - return nil -} diff --git a/internal/test/port.go b/internal/test/port.go index 5a620c13f9..4193c67cd3 100644 --- a/internal/test/port.go +++ b/internal/test/port.go @@ -47,7 +47,7 @@ func GetLocalAddress() string { for { time.Sleep(time.Millisecond * time.Duration(1+rand.Intn(10))) port := atomic.AddUint32(&curPort, 1+uint32(rand.Intn(10))) - addr := "127.0.0.1:" + strconv.Itoa(int(port)) + addr := "localhost:" + strconv.Itoa(int(port)) if !IsAddressInUse(addr) { trace := strings.Split(string(debug.Stack()), "\n") if len(trace) > 6 { diff --git a/pkg/remote/remotesvr/server_test.go b/pkg/remote/remotesvr/server_test.go index 1c9a1f8a72..078e50dfbd 100644 --- a/pkg/remote/remotesvr/server_test.go +++ b/pkg/remote/remotesvr/server_test.go @@ -35,7 +35,7 @@ func TestServerStart(t *testing.T) { transSvr := &mocks.MockTransServer{ CreateListenerFunc: func(addr net.Addr) (listener net.Listener, err error) { isCreateListener = true - ln, err = net.Listen("tcp", ":8888") + ln, err = net.Listen("tcp", "localhost:8888") return ln, err }, BootstrapServerFunc: func(net.Listener) (err error) { diff --git a/server/invoke.go b/server/invoke.go index 8fada9d0c9..d26c8b2aa7 100644 --- a/server/invoke.go +++ b/server/invoke.go @@ -19,6 +19,7 @@ package server // Invoker is for calling handler function wrapped by Kitex suites without connection. import ( + "context" "errors" internal_server "github.com/cloudwego/kitex/internal/server" @@ -41,8 +42,42 @@ type Invoker interface { } type tInvoker struct { - invoke.Handler *server + + h invoke.Handler +} + +// invokerMetaDecoder is used to update `PayloadLen` of `remote.Message`. +// It fixes kitex returning err when apache codec is not available due to msg.PayloadLen() == 0. +// Because users may not add transport header like transport.Framed +// to invoke.Message when calling msg.SetRequestBytes. +// This is NOT expected and it's caused by kitex design fault. +type invokerMetaDecoder struct { + remote.Codec + + d remote.MetaDecoder +} + +func (d *invokerMetaDecoder) DecodeMeta(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + err := d.d.DecodeMeta(ctx, msg, in) + if err != nil { + return err + } + // cool ... no need to do anything. + // added transport header? + if msg.PayloadLen() > 0 { + return nil + } + // use the whole buffer + // coz for invoker remote.ByteBuffer always contains the whole msg payload + if n := in.ReadableLen(); n > 0 { + msg.SetPayloadLen(n) + } + return nil +} + +func (d *invokerMetaDecoder) DecodePayload(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + return d.d.DecodePayload(ctx, msg, in) } // NewInvoker creates new Invoker. @@ -51,6 +86,13 @@ func NewInvoker(opts ...Option) Invoker { opt: internal_server.NewOptions(opts), svcs: newServices(), } + if codec, ok := s.opt.RemoteOpt.Codec.(remote.MetaDecoder); ok { + // see comment on type `invokerMetaDecoder` + s.opt.RemoteOpt.Codec = &invokerMetaDecoder{ + Codec: s.opt.RemoteOpt.Codec, + d: codec, + } + } s.init() return &tInvoker{ server: s, @@ -69,7 +111,7 @@ func (s *tInvoker) Init() (err error) { doAddBoundHandler(transInfoHdlr, s.server.opt.RemoteOpt) } s.Lock() - s.Handler, err = s.newInvokeHandler() + s.h, err = s.newInvokeHandler() s.Unlock() if err != nil { return err @@ -82,7 +124,7 @@ func (s *tInvoker) Init() (err error) { // Call implements the InvokeCaller interface. func (s *tInvoker) Call(msg invoke.Message) error { - return s.Handler.Call(msg) + return s.h.Call(msg) } func (s *tInvoker) newInvokeHandler() (handler invoke.Handler, err error) { diff --git a/server/invoke_test.go b/server/invoke_test.go index 77089b4389..30a5178df7 100644 --- a/server/invoke_test.go +++ b/server/invoke_test.go @@ -22,11 +22,11 @@ import ( "sync/atomic" "testing" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote/trans/invoke" - "github.com/cloudwego/kitex/pkg/utils" ) // TestInvokerCall tests Invoker, call Kitex server just like SDK. @@ -47,10 +47,9 @@ func TestInvokerCall(t *testing.T) { } args := mocks.NewMockArgs() - codec := utils.NewThriftMessageCodec() // call success - b, _ := codec.Encode("mock", thrift.CALL, 0, args.(thrift.TStruct)) + b, _ := thrift.MarshalFastMsg("mock", thrift.CALL, 0, args.(thrift.FastCodec)) msg := invoke.NewMessage(nil, nil) err = msg.SetRequestBytes(b) test.Assert(t, err == nil) @@ -66,7 +65,7 @@ func TestInvokerCall(t *testing.T) { test.Assert(t, gotErr.Load() == nil) // call fails - b, _ = codec.Encode("mockError", thrift.CALL, 0, args.(thrift.TStruct)) + b, _ = thrift.MarshalFastMsg("mockError", thrift.CALL, 0, args.(thrift.FastCodec)) msg = invoke.NewMessage(nil, nil) err = msg.SetRequestBytes(b) test.Assert(t, err == nil) diff --git a/server/option_advanced_test.go b/server/option_advanced_test.go index 75ae1f64c1..66ba46de80 100644 --- a/server/option_advanced_test.go +++ b/server/option_advanced_test.go @@ -55,7 +55,7 @@ func TestACLRulesOption(t *testing.T) { return nil }) - svr := NewServer(WithACLRules(rules...)) + svr, _ := NewTestServer(WithACLRules(rules...)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -98,7 +98,7 @@ func (m *myLimitReporter) QPSOverloadReport() { // TestLimitReporterOption tests the creation of a server with LimitReporter option func TestLimitReporterOption(t *testing.T) { my := &myLimitReporter{} - svr := NewServer(WithLimitReporter(my)) + svr, _ := NewTestServer(WithLimitReporter(my)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -122,7 +122,7 @@ func TestGenericOptionPanic(t *testing.T) { // TestGenericOption tests the creation of a server with RemoteOpt.PayloadCodec option func TestGenericOption(t *testing.T) { g := generic.BinaryThriftGeneric() - svr := NewServer(WithGeneric(g)) + svr, _ := NewTestServer(WithGeneric(g)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -190,7 +190,7 @@ func TestWithBoundHandler(t *testing.T) { func TestExitSignalOption(t *testing.T) { stopSignal := make(chan error, 1) stopErr := errors.New("stop signal") - svr := NewServer(WithExitSignal(func() <-chan error { + svr, _ := NewTestServer(WithExitSignal(func() <-chan error { return stopSignal })) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) diff --git a/server/option_test.go b/server/option_test.go index e0d5302d10..ea7818a9cb 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -88,7 +88,7 @@ func TestOptionDebugInfo(t *testing.T) { // TestProxyOption tests the creation of a server with Proxy option func TestProxyOption(t *testing.T) { var opts []Option - addr, err := net.ResolveTCPAddr("tcp", ":8888") + addr, err := net.ResolveTCPAddr("tcp", "localhost:8888") test.Assert(t, err == nil, err) opts = append(opts, WithServiceAddr(addr)) opts = append(opts, WithProxy(&proxyMock{})) @@ -136,7 +136,7 @@ func (m *mockDiagnosis) ProbePairs() map[diagnosis.ProbeName]diagnosis.ProbeFunc func TestExitWaitTimeOption(t *testing.T) { // random timeout value testTimeOut := time.Duration(time.Now().Unix()) * time.Microsecond - svr := NewServer(WithExitWaitTime(testTimeOut)) + svr, _ := NewTestServer(WithExitWaitTime(testTimeOut)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -153,7 +153,7 @@ func TestExitWaitTimeOption(t *testing.T) { func TestMaxConnIdleTimeOption(t *testing.T) { // random timeout value testTimeOut := time.Duration(time.Now().Unix()) - svr := NewServer(WithMaxConnIdleTime(testTimeOut)) + svr, _ := NewTestServer(WithMaxConnIdleTime(testTimeOut)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -177,7 +177,7 @@ func (t *myTracer) Finish(ctx context.Context) { // TestTracerOption tests the creation of a server with TracerCtl option func TestTracerOption(t *testing.T) { - svr1 := NewServer() + svr1, _ := NewTestServer() err := svr1.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -191,7 +191,7 @@ func TestTracerOption(t *testing.T) { test.Assert(t, iSvr1.opt.TracerCtl.HasTracer() != true) tracer := &myTracer{} - svr2 := NewServer(WithTracer(tracer)) + svr2, _ := NewTestServer(WithTracer(tracer)) err = svr2.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -207,7 +207,7 @@ func TestTracerOption(t *testing.T) { // TestStatsLevelOption tests the creation of a server with StatsLevel option func TestStatsLevelOption(t *testing.T) { - svr1 := NewServer() + svr1, _ := NewTestServer() err := svr1.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -220,7 +220,7 @@ func TestStatsLevelOption(t *testing.T) { test.Assert(t, iSvr1.opt.StatsLevel != nil) test.Assert(t, *iSvr1.opt.StatsLevel == stats.LevelDisabled) - svr2 := NewServer(WithStatsLevel(stats.LevelDetailed)) + svr2, _ := NewTestServer(WithStatsLevel(stats.LevelDetailed)) err = svr2.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -243,7 +243,7 @@ func (s *mySuiteOption) Options() []Option { // TestSuiteOption tests the creation of a server with SuiteOption option func TestSuiteOption(t *testing.T) { - svr1 := NewServer() + svr1, _ := NewTestServer() time.AfterFunc(100*time.Millisecond, func() { err := svr1.Stop() test.Assert(t, err == nil, err) @@ -264,7 +264,7 @@ func TestSuiteOption(t *testing.T) { WithExitWaitTime(tmpWaitTime), WithMaxConnIdleTime(tmpConnIdleTime), }} - svr2 := NewServer(WithSuite(suiteOpt)) + svr2, _ := NewTestServer(WithSuite(suiteOpt)) time.AfterFunc(100*time.Millisecond, func() { err := svr2.Stop() test.Assert(t, err == nil, err) @@ -283,7 +283,7 @@ func TestSuiteOption(t *testing.T) { // TestMuxTransportOption tests the creation of a server,with netpollmux remote.ServerTransHandlerFactory option, func TestMuxTransportOption(t *testing.T) { - svr1 := NewServer() + svr1, _ := NewTestServer() time.AfterFunc(100*time.Millisecond, func() { err := svr1.Stop() test.Assert(t, err == nil, err) @@ -295,7 +295,7 @@ func TestMuxTransportOption(t *testing.T) { iSvr1 := svr1.(*server) test.DeepEqual(t, iSvr1.opt.RemoteOpt.SvrHandlerFactory, detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory())) - svr2 := NewServer(WithMuxTransport()) + svr2, _ := NewTestServer(WithMuxTransport()) time.AfterFunc(100*time.Millisecond, func() { err := svr2.Stop() test.Assert(t, err == nil, err) @@ -312,7 +312,7 @@ func TestMuxTransportOption(t *testing.T) { // TestPayloadCodecOption tests the creation of a server with RemoteOpt.PayloadCodec option func TestPayloadCodecOption(t *testing.T) { t.Run("NotSetPayloadCodec", func(t *testing.T) { - svr := NewServer() + svr, _ := NewTestServer() time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -337,7 +337,7 @@ func TestPayloadCodecOption(t *testing.T) { test.Assert(t, protobuf.IsProtobufCodec(pc)) }) t.Run("SetPreRegisteredProtobufCodec", func(t *testing.T) { - svr := NewServer(WithPayloadCodec(protobuf.NewProtobufCodec())) + svr, _ := NewTestServer(WithPayloadCodec(protobuf.NewProtobufCodec())) time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -364,7 +364,7 @@ func TestPayloadCodecOption(t *testing.T) { t.Run("SetPreRegisteredThriftCodec", func(t *testing.T) { thriftCodec := thrift.NewThriftCodecDisableFastMode(false, true) - svr := NewServer(WithPayloadCodec(thriftCodec)) + svr, _ := NewTestServer(WithPayloadCodec(thriftCodec)) time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -393,7 +393,7 @@ func TestPayloadCodecOption(t *testing.T) { t.Run("SetNonPreRegisteredCodec", func(t *testing.T) { // generic.BinaryThriftGeneric().PayloadCodec() is not the pre registered codec, RemoteOpt.PayloadCodec won't be nil binaryThriftCodec := generic.BinaryThriftGeneric().PayloadCodec() - svr := NewServer(WithPayloadCodec(binaryThriftCodec)) + svr, _ := NewTestServer(WithPayloadCodec(binaryThriftCodec)) time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -431,7 +431,7 @@ func TestRemoteOptGRPCCfgUintValueOption(t *testing.T) { randUint3 := uint32(rand.Int31n(100)) + 1 randUint4 := uint32(rand.Int31n(100)) + 1 - svr1 := NewServer( + svr1, _ := NewTestServer( WithGRPCInitialWindowSize(randUint1), WithGRPCInitialConnWindowSize(randUint2), WithGRPCMaxConcurrentStreams(randUint3), @@ -462,7 +462,7 @@ func TestGRPCKeepaliveEnforcementPolicyOption(t *testing.T) { MinTime: time.Duration(randInt) * time.Second, PermitWithoutStream: true, } - svr1 := NewServer( + svr1, _ := NewTestServer( WithGRPCKeepaliveEnforcementPolicy(kep), ) @@ -493,7 +493,7 @@ func TestGRPCKeepaliveParamsOption(t *testing.T) { Time: randTimeDuration4, Timeout: randTimeDuration5, } - svr1 := NewServer( + svr1, _ := NewTestServer( WithGRPCKeepaliveParams(kp), ) @@ -516,7 +516,7 @@ func TestWithProfilerMessageTagging(t *testing.T) { var msgTagging2 remote.MessageTagging = func(ctx context.Context, msg remote.Message) (context.Context, []string) { return context.WithValue(ctx, "ctx2", 2), []string{"b", "2", "c", "2"} } - svr := NewServer(WithProfilerMessageTagging(msgTagging1), WithProfilerMessageTagging(msgTagging2)) + svr, _ := NewTestServer(WithProfilerMessageTagging(msgTagging1), WithProfilerMessageTagging(msgTagging2)) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { @@ -543,7 +543,7 @@ func TestWithProfilerMessageTagging(t *testing.T) { } func TestRefuseTrafficWithoutServiceNamOption(t *testing.T) { - svr := NewServer(WithRefuseTrafficWithoutServiceName()) + svr, _ := NewTestServer(WithRefuseTrafficWithoutServiceName()) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil, err) time.AfterFunc(100*time.Millisecond, func() { diff --git a/server/server_test.go b/server/server_test.go index e9b71f5e94..ede109d170 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -55,10 +55,26 @@ import ( var svcInfo = mocks.ServiceInfo() +// NOTE: always use this method to get addr for server listening +// should be used with `WithServiceAddr(addr)` +func getAddrForListener() net.Addr { + addr := test.GetLocalAddress() + ret, _ := net.ResolveTCPAddr("tcp", addr) + return ret +} + +// NewTestServer calls NewServer with a random addr +// DO NOT USE `NewServer` and `s.Run()` without specifying addr, it listens on :8888 ... +func NewTestServer(ops ...Option) (Server, net.Addr) { + addr := getAddrForListener() + svr := NewServer(append(ops, WithServiceAddr(addr))...) + return svr, addr +} + func TestServerRun(t *testing.T) { var opts []Option opts = append(opts, WithMetaHandler(noopMetahandler{})) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) time.AfterFunc(time.Millisecond*500, func() { err := svr.Stop() @@ -84,8 +100,7 @@ func TestServerRun(t *testing.T) { } func TestReusePortServerRun(t *testing.T) { - hostPort := test.GetLocalAddress() - addr, _ := net.ResolveTCPAddr("tcp", hostPort) + addr := getAddrForListener() var opts []Option opts = append(opts, WithReusePort(true)) opts = append(opts, WithServiceAddr(addr), WithExitWaitTime(time.Microsecond*10)) @@ -261,7 +276,7 @@ func TestServiceRegisterFailed(t *testing.T) { } var opts []Option opts = append(opts, WithRegistry(mockRegistry)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -287,7 +302,7 @@ func TestServiceDeregisterFailed(t *testing.T) { } var opts []Option opts = append(opts, WithRegistry(mockRegistry)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -308,7 +323,6 @@ func TestServiceRegistryInfo(t *testing.T) { checkInfo := func(info *registry.Info) { test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) test.Assert(t, info.Weight == registryInfo.Weight, info.Addr) - test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) test.Assert(t, len(info.Tags) == len(registryInfo.Tags), info.Tags) test.Assert(t, info.Tags["aa"] == registryInfo.Tags["aa"], info.Tags) } @@ -329,7 +343,7 @@ func TestServiceRegistryInfo(t *testing.T) { var opts []Option opts = append(opts, WithRegistry(mockRegistry)) opts = append(opts, WithRegistryInfo(registryInfo)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -346,7 +360,6 @@ func TestServiceRegistryInfo(t *testing.T) { func TestServiceRegistryNoInitInfo(t *testing.T) { checkInfo := func(info *registry.Info) { test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) - test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) } var rCount int var drCount int @@ -364,7 +377,7 @@ func TestServiceRegistryNoInitInfo(t *testing.T) { } var opts []Option opts = append(opts, WithRegistry(mockRegistry)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -386,7 +399,6 @@ func TestServiceRegistryInfoWithNilTags(t *testing.T) { } checkInfo := func(info *registry.Info) { test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) - test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) test.Assert(t, info.Weight == registryInfo.Weight, info.Weight) test.Assert(t, info.Tags["aa"] == "bb", info.Tags) } @@ -410,7 +422,7 @@ func TestServiceRegistryInfoWithNilTags(t *testing.T) { opts = append(opts, WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{ Tags: map[string]string{"aa": "bb"}, })) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -453,7 +465,7 @@ func TestServiceRegistryInfoWithSkipListenAddr(t *testing.T) { var opts []Option opts = append(opts, WithRegistry(mockRegistry)) opts = append(opts, WithRegistryInfo(registryInfo)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -495,7 +507,7 @@ func TestServiceRegistryInfoWithoutSkipListenAddr(t *testing.T) { var opts []Option opts = append(opts, WithRegistry(mockRegistry)) opts = append(opts, WithRegistryInfo(registryInfo)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) @@ -512,7 +524,7 @@ func TestServiceRegistryInfoWithoutSkipListenAddr(t *testing.T) { func TestGRPCServerMultipleServices(t *testing.T) { var opts []Option opts = append(opts, withGRPCTransport()) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) test.Assert(t, err == nil) err = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler()) @@ -672,7 +684,7 @@ func TestServerBoundHandler(t *testing.T) { } for _, tcase := range cases { opts := append(tcase.opts, WithExitWaitTime(time.Millisecond*10)) - svr := NewServer(opts...) + svr, _ := NewTestServer(opts...) time.AfterFunc(100*time.Millisecond, func() { err := svr.Stop() @@ -691,9 +703,9 @@ func TestServerBoundHandler(t *testing.T) { } func TestInvokeHandlerWithContextBackup(t *testing.T) { - testInvokeHandlerWithSession(t, true, ":8888") + testInvokeHandlerWithSession(t, true, "localhost:8888") os.Setenv(localsession.SESSION_CONFIG_KEY, "true,100,1h") - testInvokeHandlerWithSession(t, false, ":8889") + testInvokeHandlerWithSession(t, false, "localhost:8889") } func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) { @@ -840,7 +852,7 @@ func TestInvokeHandlerExec(t *testing.T) { }, CreateListenerFunc: func(addr net.Addr) (net.Listener, error) { var err error - ln, err = net.Listen("tcp", ":8888") + ln, err = net.Listen("tcp", "localhost:8888") return ln, err }, } @@ -903,7 +915,7 @@ func TestInvokeHandlerPanic(t *testing.T) { }, CreateListenerFunc: func(addr net.Addr) (net.Listener, error) { var err error - ln, err = net.Listen("tcp", ":8888") + ln, err = net.Listen("tcp", "localhost:8888") return ln, err }, } From 1727014086796971dcc0c81c32b25fa515d9d1f9 Mon Sep 17 00:00:00 2001 From: YangruiEmma Date: Fri, 2 Aug 2024 20:21:22 +0800 Subject: [PATCH 26/70] chore: remove github.com/stretchr/testify direct dependency (#1475) --- CREDITS | 1 - go.mod | 2 +- internal/generic/proto/json_test.go | 12 +++++------- internal/generic/thrift/read_test.go | 4 ++-- pkg/utils/json_fuzz_test.go | 25 +++++++++++++------------ 5 files changed, 21 insertions(+), 23 deletions(-) delete mode 100644 CREDITS diff --git a/CREDITS b/CREDITS deleted file mode 100644 index fa01335116..0000000000 --- a/CREDITS +++ /dev/null @@ -1 +0,0 @@ -github.com/stretchr/testify \ No newline at end of file diff --git a/go.mod b/go.mod index 1afb570b3e..c86898ad6f 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 github.com/jhump/protoreflect v1.8.2 github.com/json-iterator/go v1.1.12 - github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.9.3 golang.org/x/net v0.17.0 golang.org/x/sync v0.1.0 @@ -45,6 +44,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/oleiade/lane v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/internal/generic/proto/json_test.go b/internal/generic/proto/json_test.go index 748a17e5c7..1bec69eb12 100644 --- a/internal/generic/proto/json_test.go +++ b/internal/generic/proto/json_test.go @@ -20,12 +20,12 @@ import ( "context" "encoding/json" "io/ioutil" + "reflect" "testing" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/proto" "github.com/cloudwego/dynamicgo/testdata/kitex_gen/pb/example2" - "github.com/stretchr/testify/require" goprotowire "google.golang.org/protobuf/encoding/protowire" "github.com/cloudwego/kitex/internal/test" @@ -76,14 +76,13 @@ func TestWrite(t *testing.T) { l += tagLen buf = buf[tagLen:] offset, err := act.FastRead(buf, int8(wtyp), int32(id)) - require.Nil(t, err) + test.Assert(t, err == nil) buf = buf[offset:] l += offset } test.Assert(t, err == nil) - // compare exp and act struct - require.Equal(t, exp, act) + test.Assert(t, reflect.DeepEqual(exp, act)) } // Check NewReadJSON converting protobuf wire format to JSON @@ -114,7 +113,7 @@ func TestRead(t *testing.T) { l += tagLen in = in[tagLen:] offset, err := exp.FastRead(in, int8(wtyp), int32(id)) - require.Nil(t, err) + test.Assert(t, err == nil) in = in[offset:] l += offset } @@ -127,9 +126,8 @@ func TestRead(t *testing.T) { str, ok := out.(string) test.Assert(t, ok) json.Unmarshal([]byte(str), &act) - // compare exp and act struct - require.Equal(t, exp, act) + test.Assert(t, reflect.DeepEqual(exp, act)) } // helper methods diff --git a/internal/generic/thrift/read_test.go b/internal/generic/thrift/read_test.go index e50030d33a..2a77e120ee 100644 --- a/internal/generic/thrift/read_test.go +++ b/internal/generic/thrift/read_test.go @@ -25,9 +25,9 @@ import ( "github.com/cloudwego/gopkg/protocol/thrift" "github.com/jhump/protoreflect/desc/protoparse" - "github.com/stretchr/testify/require" "github.com/cloudwego/kitex/internal/generic/proto" + "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/remote" ) @@ -656,7 +656,7 @@ func Test_readStruct(t *testing.T) { t.Errorf("readStruct() error = %v, wantErr %v", err, tt.wantErr) return } - require.Equal(t, tt.want, got) + test.Assert(t, reflect.DeepEqual(tt.want, got)) }) } } diff --git a/pkg/utils/json_fuzz_test.go b/pkg/utils/json_fuzz_test.go index c09851947f..40b5b3b720 100644 --- a/pkg/utils/json_fuzz_test.go +++ b/pkg/utils/json_fuzz_test.go @@ -21,9 +21,10 @@ package utils import ( "encoding/json" + "reflect" "testing" - "github.com/stretchr/testify/require" + "github.com/cloudwego/kitex/internal/test" ) func FuzzJSONStr2Map(f *testing.F) { @@ -42,10 +43,9 @@ func FuzzJSONStr2Map(f *testing.F) { } map1, err1 := JSONStr2Map(data) map2, err2 := _JSONStr2Map(data) - require.Equal(t, err2 == nil, err1 == nil, "json:%v", data) - if err2 == nil { - require.Equal(t, map2, map1, "json:%v", data) - } + test.Assert(t, err1 == nil, data) + test.Assert(t, err2 == nil, data) + test.Assert(t, reflect.DeepEqual(map1, map2), data) }) } @@ -64,13 +64,14 @@ func FuzzMap2JSON(f *testing.F) { if err := json.Unmarshal([]byte(data), &m); err != nil { return } - map1, err1 := Map2JSONStr(m) - map2, err2 := _Map2JSONStr(m) - require.Equal(t, err2 == nil, err1 == nil, "json:%v", data) - require.Equal(t, len(map2), len(map1), "json:%v", data) + str1, err1 := Map2JSONStr(m) + str2, err2 := _Map2JSONStr(m) + test.Assert(t, err1 == nil, data) + test.Assert(t, err2 == nil, data) + test.Assert(t, len(str1) == len(str2)) var m1, m2 map[string]string - require.NoError(t, json.Unmarshal([]byte(map1), &m1)) - require.NoError(t, json.Unmarshal([]byte(map2), &m2)) - require.Equal(t, m2, m1) + test.Assert(t, json.Unmarshal([]byte(str1), &m1) == nil) + test.Assert(t, json.Unmarshal([]byte(str2), &m2) == nil) + test.Assert(t, reflect.DeepEqual(m1, m2)) }) } From 19d1a489a8904ffe9508581361d99d2a14f72837 Mon Sep 17 00:00:00 2001 From: Joway Date: Mon, 5 Aug 2024 16:47:27 +0800 Subject: [PATCH 27/70] chore: upgrade gopkg to v0.1.0 (#1477) --- go.mod | 2 +- go.sum | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index c86898ad6f..ab45d6f291 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( github.com/apache/thrift v0.13.0 - github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8 + github.com/bytedance/gopkg v0.1.0 github.com/bytedance/sonic v1.11.8 github.com/cloudwego/configmanager v0.2.2 github.com/cloudwego/dynamicgo v0.2.9 diff --git a/go.sum b/go.sum index 8fe9f91469..586c0212f6 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,9 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= -github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8 h1:rDwLxYTMoKHaw4cS0bQhaTZnkXp5e6ediCggGcRD/CA= github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= +github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.11.8 h1:Zw/j1KfiS+OYTi9lyB3bb0CFxPJVkM17k1wyDG32LRA= github.com/bytedance/sonic v1.11.8/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= From d768d53b32deb0ba78d48bc21c24da54fcc9dc6f Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Tue, 6 Aug 2024 14:09:14 +0800 Subject: [PATCH 28/70] chore(generic): add generic base using gopkg base (#1482) --- go.mod | 2 +- go.sum | 4 +-- internal/mocks/generic/thrift.go | 1 - pkg/generic/thrift/base.go | 49 ++++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 pkg/generic/thrift/base.go diff --git a/go.mod b/go.mod index ab45d6f291..9576b5cad0 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cloudwego/dynamicgo v0.2.9 github.com/cloudwego/fastpb v0.0.4 github.com/cloudwego/frugal v0.1.15 - github.com/cloudwego/gopkg v0.1.0 + github.com/cloudwego/gopkg v0.1.1-0.20240805070331-9ef090c57b1f github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index 586c0212f6..c1730cb068 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,8 @@ github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4m github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= github.com/cloudwego/frugal v0.1.15 h1:LC55UJKhQPMFVjDPbE+LJcF7etZjSx6uokG1tk0wPK0= github.com/cloudwego/frugal v0.1.15/go.mod h1:26kU1r18vA8vRg12c66XPDlfv1GQHDbE1RpusipXfcI= -github.com/cloudwego/gopkg v0.1.0 h1:N7CE4FS5crkZg3w7shw3UR3TG4+uofXXabGuBNmSrlE= -github.com/cloudwego/gopkg v0.1.0/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= +github.com/cloudwego/gopkg v0.1.1-0.20240805070331-9ef090c57b1f h1:65h7Qmcnnw/90O1U56uuSicKztJn4Is6eK2KcktST2Y= +github.com/cloudwego/gopkg v0.1.1-0.20240805070331-9ef090c57b1f/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/iasm v0.0.9/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= diff --git a/internal/mocks/generic/thrift.go b/internal/mocks/generic/thrift.go index 142fd9f483..ae32354f04 100644 --- a/internal/mocks/generic/thrift.go +++ b/internal/mocks/generic/thrift.go @@ -27,7 +27,6 @@ import ( reflect "reflect" "github.com/cloudwego/gopkg/protocol/thrift/base" - gomock "github.com/golang/mock/gomock" ) diff --git a/pkg/generic/thrift/base.go b/pkg/generic/thrift/base.go new file mode 100644 index 0000000000..ae930c3307 --- /dev/null +++ b/pkg/generic/thrift/base.go @@ -0,0 +1,49 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import "github.com/cloudwego/gopkg/protocol/thrift/base" + +// TrafficEnv ... +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base +type TrafficEnv = base.TrafficEnv + +// NewTrafficEnv ... +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base +func NewTrafficEnv() *TrafficEnv { + return base.NewTrafficEnv() +} + +// Base ... +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base +type Base = base.Base + +// NewBase ... +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base +func NewBase() *Base { + return base.NewBase() +} + +// BaseResp ... +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base +type BaseResp = base.BaseResp + +// NewBaseResp ... +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base +func NewBaseResp() *BaseResp { + return base.NewBaseResp() +} From 5e1c3f5d75807369dd71a5131a4b098aa134ec18 Mon Sep 17 00:00:00 2001 From: Jayant Date: Tue, 6 Aug 2024 15:34:18 +0800 Subject: [PATCH 29/70] fix(gonet): adjust gonet server read timeout to avoid read error (#1481) --- pkg/remote/trans/gonet/trans_server.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/remote/trans/gonet/trans_server.go b/pkg/remote/trans/gonet/trans_server.go index 095b2dca2f..aae4f8811a 100644 --- a/pkg/remote/trans/gonet/trans_server.go +++ b/pkg/remote/trans/gonet/trans_server.go @@ -138,8 +138,12 @@ func (ts *transServer) onError(ctx context.Context, err error, conn net.Conn) { } func (ts *transServer) refreshDeadline(ri rpcinfo.RPCInfo, conn net.Conn) { - readTimeout := ri.Config().ReadWriteTimeout() - _ = conn.SetReadDeadline(time.Now().Add(readTimeout)) + // ReadWriteTimeout indicates the maximum time to read a message from the connection since it received a read event, + // so it's not suitable for bio mode like gonet, modify the default setting to 2 minutes to make sure it's greater + // than the client idle timeout. + + // readTimeout := ri.Config().ReadWriteTimeout() + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Minute)) } // bufioConn implements the net.Conn interface. From 56f75f8d4330b23ca57bfcf2aea81bad92a8bc9e Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Tue, 6 Aug 2024 18:57:08 +0800 Subject: [PATCH 30/70] refactor: use github.com/cloudwego/gopkg/protocol/thrift/apache (#1483) * added README.md for pkg/protocol/bthrift * pkg/protocol/bthrift: not import kitex other pkgs * pkg/protocol/bthrift/apache: clear up -> thrift.go * pkg/generic: no apache codec * pkg/remote/codec/thrift: deprecate BinaryProtocol * pkg/utils: deprecate apache codec * tool: add `thrift_import_path` by default --- .gitignore | 4 + go.mod | 2 +- go.sum | 4 +- pkg/generic/map_test/generic_init.go | 11 +- pkg/protocol/bthrift/README.md | 24 ++++ pkg/protocol/bthrift/apache/apache.go | 30 ++--- .../bthrift/apache/application_exception.go | 39 ------ .../bthrift/apache/binary_protocol.go | 23 ---- pkg/protocol/bthrift/apache/memory_buffer.go | 57 --------- pkg/protocol/bthrift/apache/messagetype.go | 32 ----- pkg/protocol/bthrift/apache/protocol.go | 33 ----- .../bthrift/apache/protocol_exception.go | 33 ----- pkg/protocol/bthrift/apache/serializer.go | 24 ---- pkg/protocol/bthrift/apache/thrift.go | 84 +++++++++++++ pkg/protocol/bthrift/apache/transport.go | 23 ---- pkg/protocol/bthrift/apache/type.go | 43 ------- pkg/protocol/bthrift/binary.go | 119 ++++++++++-------- pkg/protocol/bthrift/binary_test.go | 2 +- .../protocol/bthrift/compat.go | 6 +- pkg/protocol/bthrift/internal/test/README.md | 6 + pkg/protocol/bthrift/internal/test/assert.go | 92 ++++++++++++++ .../bthrift/internal/test/assert_test.go | 105 ++++++++++++++++ pkg/protocol/bthrift/unknown_test.go | 2 +- .../{binary_protocol.go => deprecated.go} | 95 ++++++++------ ...ry_protocol_test.go => deprecated_test.go} | 0 pkg/remote/codec/thrift/thrift.go | 43 +++---- pkg/remote/codec/thrift/thrift_data.go | 95 +++++--------- pkg/remote/codec/thrift/thrift_data_test.go | 50 +++----- pkg/remote/codec/thrift/thrift_frugal_test.go | 10 +- pkg/remote/codec/thrift/thrift_test.go | 11 +- pkg/utils/thrift.go | 93 +++++--------- pkg/utils/thrift_test.go | 15 +-- tool/cmd/kitex/args/args.go | 5 + 33 files changed, 580 insertions(+), 635 deletions(-) create mode 100644 pkg/protocol/bthrift/README.md delete mode 100644 pkg/protocol/bthrift/apache/application_exception.go delete mode 100644 pkg/protocol/bthrift/apache/binary_protocol.go delete mode 100644 pkg/protocol/bthrift/apache/memory_buffer.go delete mode 100644 pkg/protocol/bthrift/apache/messagetype.go delete mode 100644 pkg/protocol/bthrift/apache/protocol.go delete mode 100644 pkg/protocol/bthrift/apache/protocol_exception.go delete mode 100644 pkg/protocol/bthrift/apache/serializer.go create mode 100644 pkg/protocol/bthrift/apache/thrift.go delete mode 100644 pkg/protocol/bthrift/apache/transport.go delete mode 100644 pkg/protocol/bthrift/apache/type.go rename internal/mocks/thrift/utils.go => pkg/protocol/bthrift/compat.go (95%) create mode 100644 pkg/protocol/bthrift/internal/test/README.md create mode 100644 pkg/protocol/bthrift/internal/test/assert.go create mode 100644 pkg/protocol/bthrift/internal/test/assert_test.go rename pkg/remote/codec/thrift/{binary_protocol.go => deprecated.go} (79%) rename pkg/remote/codec/thrift/{binary_protocol_test.go => deprecated_test.go} (100%) diff --git a/.gitignore b/.gitignore index f3872df9e2..da02d92ac7 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,7 @@ tool/cmd/kitex/kitex base1.go dump.txt base2.go + +# Go workspace file +go.work +go.work.sum diff --git a/go.mod b/go.mod index 9576b5cad0..72b9a05c98 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cloudwego/dynamicgo v0.2.9 github.com/cloudwego/fastpb v0.0.4 github.com/cloudwego/frugal v0.1.15 - github.com/cloudwego/gopkg v0.1.1-0.20240805070331-9ef090c57b1f + github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8 github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index c1730cb068..8a5fe20fc2 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,8 @@ github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4m github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= github.com/cloudwego/frugal v0.1.15 h1:LC55UJKhQPMFVjDPbE+LJcF7etZjSx6uokG1tk0wPK0= github.com/cloudwego/frugal v0.1.15/go.mod h1:26kU1r18vA8vRg12c66XPDlfv1GQHDbE1RpusipXfcI= -github.com/cloudwego/gopkg v0.1.1-0.20240805070331-9ef090c57b1f h1:65h7Qmcnnw/90O1U56uuSicKztJn4Is6eK2KcktST2Y= -github.com/cloudwego/gopkg v0.1.1-0.20240805070331-9ef090c57b1f/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= +github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8 h1:kQPjddHw5Dufci/vfiRGMN3Uhx12XWqNpk1JdQ4Tjy0= +github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/iasm v0.0.9/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= diff --git a/pkg/generic/map_test/generic_init.go b/pkg/generic/map_test/generic_init.go index aae9490f27..2b92a303a3 100644 --- a/pkg/generic/map_test/generic_init.go +++ b/pkg/generic/map_test/generic_init.go @@ -36,6 +36,7 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/genericserver" + "github.com/cloudwego/kitex/transport" ) var reqMsg = map[string]interface{}{ @@ -54,7 +55,7 @@ var errResp = "Test Error" func newGenericClient(destService string, g generic.Generic, targetIPPort string) genericclient.Client { var opts []client.Option - opts = append(opts, client.WithHostPorts(targetIPPort)) + opts = append(opts, client.WithHostPorts(targetIPPort), client.WithTransportProtocol(transport.TTHeader)) genericCli, _ := genericclient.NewClient(destService, g, opts...) return genericCli } @@ -204,16 +205,16 @@ func serviceInfo() *serviceinfo.ServiceInfo { } func newMockTestArgs() interface{} { - return kt.ToApacheCodec(kt.NewMockTestArgs()) + return kt.NewMockTestArgs() } func newMockTestResult() interface{} { - return kt.ToApacheCodec(kt.NewMockTestResult()) + return kt.NewMockTestResult() } func testHandler(ctx context.Context, handler, arg, result interface{}) error { - realArg := kt.UnpackApacheCodec(arg).(*kt.MockTestArgs) - realResult := kt.UnpackApacheCodec(result).(*kt.MockTestResult) + realArg := arg.(*kt.MockTestArgs) + realResult := result.(*kt.MockTestResult) success, err := handler.(kt.Mock).Test(ctx, realArg.Req) if err != nil { return err diff --git a/pkg/protocol/bthrift/README.md b/pkg/protocol/bthrift/README.md new file mode 100644 index 0000000000..d90609fab2 --- /dev/null +++ b/pkg/protocol/bthrift/README.md @@ -0,0 +1,24 @@ +# bthrift + +`bthrift` is no longer used, but the legacy generated code may still rely on it. For newly added code, should use `github.com/cloudwego/gopkg/protocol/thrift` instead. + +We're planning to get rid of `github.com/apache/thrift`, here are steps we have done: +1. Removed unnecessary dependencies of apache from kitex +2. Moved all apache dependencies to `bthrift/apache`, mainly types, interfaces and consts + - We may use type alias at the beginning for better compatibility + - `bthrift/apache`calls `apache.RegisterNewTBinaryProtocol` in `gopkg` for step 4 +3. For internal dependencies of apache, use `gopkg` +4. For existing framework code working with apache thrift: + - Use `gopkg/protocol/thrift/apache` +5. For Kitex tool: + - Use `gopkg/protocol/thrift` for fastcodec + - replace `github.com/apache/thrift` with `bthrift/apache` + - by using `thrift_import_path` parameter of thriftgo + +The final step we planned to do in version v0.12.0: +* Add go.mod for `bthrift` +* Remove the last `github.com/apache/thrift` dependencies + * `ThriftMessageCodec` of `pkg/utils` + * `MessageReader` and `MessageWriter` interfaces in `pkg/remote/codec/thrift` + * `BinaryProtocol` type in `pkg/remote/codec/thrift` + * basic codec tests in `pkg/remote/codec/thrift` diff --git a/pkg/protocol/bthrift/apache/apache.go b/pkg/protocol/bthrift/apache/apache.go index ca7ee5d96c..4537fa1b57 100644 --- a/pkg/protocol/bthrift/apache/apache.go +++ b/pkg/protocol/bthrift/apache/apache.go @@ -14,24 +14,14 @@ * limitations under the License. */ -// Package apache contains codes originally from https://github.com/apache/thrift. -// -// we're planning to get rid of the pkg, here are steps we're going to work on: -// 1. Remove unnecessary dependencies of apache from kitex -// 2. Move all apache dependencies to this pkg, mainly types, interfaces and consts -// - We may use type alias at the beginning for better compatibility -// - Mark interfaces as `Deprecated` since we no longer use it in the future, and we have better implementation. -// 3. For internal dependencies of apache, new alternative implementation will be in: -// - pkg/protocol/bthrift -> low level encoding or decoding bytes -// - pkg/remote/codec/thrift -> high level interfaces -// 4. Change necessary dependencies to this file, including code generator -// (After a period of time) -// 5. Remove apache support of code generator (mainly interfaces) -// 6. Remove type alias and move definition to this file. This may causes compatible issues which are expected. -// - legacy code generator should use legacy version of kitex, then should not have compatibility issue. -// (After a period of time) -// 7. Remove interfaces like thrift.TProtocol from this file -// 8. Done -// -// Now we're working on step 1 - 4. package apache + +import ( + "github.com/apache/thrift/lib/go/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/apache" +) + +func init() { + // it makes github.com/cloudwego/gopkg/protocol/thrift/apache works + _ = apache.RegisterNewTBinaryProtocol(thrift.NewTBinaryProtocol) +} diff --git a/pkg/protocol/bthrift/apache/application_exception.go b/pkg/protocol/bthrift/apache/application_exception.go deleted file mode 100644 index ac02e24cef..0000000000 --- a/pkg/protocol/bthrift/apache/application_exception.go +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/application_exception.go - -const ( - UNKNOWN_APPLICATION_EXCEPTION = 0 - UNKNOWN_METHOD = 1 - INVALID_MESSAGE_TYPE_EXCEPTION = 2 - WRONG_METHOD_NAME = 3 - BAD_SEQUENCE_ID = 4 - MISSING_RESULT = 5 - INTERNAL_ERROR = 6 - PROTOCOL_ERROR = 7 - INVALID_TRANSFORM = 8 - INVALID_PROTOCOL = 9 - UNSUPPORTED_CLIENT_TYPE = 10 -) - -type TApplicationException = thrift.TApplicationException - -var NewTApplicationException = thrift.NewTApplicationException diff --git a/pkg/protocol/bthrift/apache/binary_protocol.go b/pkg/protocol/bthrift/apache/binary_protocol.go deleted file mode 100644 index 2a2a4538b2..0000000000 --- a/pkg/protocol/bthrift/apache/binary_protocol.go +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/binary_protocol.go - -var NewTBinaryProtocol = thrift.NewTBinaryProtocol diff --git a/pkg/protocol/bthrift/apache/memory_buffer.go b/pkg/protocol/bthrift/apache/memory_buffer.go deleted file mode 100644 index 10a0af751f..0000000000 --- a/pkg/protocol/bthrift/apache/memory_buffer.go +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import ( - "bytes" - "context" -) - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/memory_buffer.go - -// Memory buffer-based implementation of the TTransport interface. -type TMemoryBuffer struct { - *bytes.Buffer - size int -} - -func NewTMemoryBufferLen(size int) *TMemoryBuffer { - buf := make([]byte, 0, size) - return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size} -} - -func (p *TMemoryBuffer) IsOpen() bool { - return true -} - -func (p *TMemoryBuffer) Open() error { - return nil -} - -func (p *TMemoryBuffer) Close() error { - p.Buffer.Reset() - return nil -} - -// Flushing a memory buffer is a no-op -func (p *TMemoryBuffer) Flush(ctx context.Context) error { - return nil -} - -func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) { - return uint64(p.Buffer.Len()) -} diff --git a/pkg/protocol/bthrift/apache/messagetype.go b/pkg/protocol/bthrift/apache/messagetype.go deleted file mode 100644 index 1885144aee..0000000000 --- a/pkg/protocol/bthrift/apache/messagetype.go +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/messagetype.go - -// Message type constants in the Thrift protocol. -type TMessageType = thrift.TMessageType - -const ( - INVALID_TMESSAGE_TYPE TMessageType = 0 - CALL TMessageType = 1 - REPLY TMessageType = 2 - EXCEPTION TMessageType = 3 - ONEWAY TMessageType = 4 -) diff --git a/pkg/protocol/bthrift/apache/protocol.go b/pkg/protocol/bthrift/apache/protocol.go deleted file mode 100644 index 9d0a991d96..0000000000 --- a/pkg/protocol/bthrift/apache/protocol.go +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol.go - -const ( - VERSION_MASK = 0xffff0000 - VERSION_1 = 0x80010000 -) - -type TProtocol = thrift.TProtocol - -// The maximum recursive depth the skip() function will traverse -const DEFAULT_RECURSION_DEPTH = 64 - -var SkipDefaultDepth = thrift.SkipDefaultDepth diff --git a/pkg/protocol/bthrift/apache/protocol_exception.go b/pkg/protocol/bthrift/apache/protocol_exception.go deleted file mode 100644 index 7b020797f5..0000000000 --- a/pkg/protocol/bthrift/apache/protocol_exception.go +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol_exception.go - -var NewTProtocolExceptionWithType = thrift.NewTProtocolExceptionWithType - -const ( - UNKNOWN_PROTOCOL_EXCEPTION = 0 - INVALID_DATA = 1 - NEGATIVE_SIZE = 2 - SIZE_LIMIT = 3 - BAD_VERSION = 4 - NOT_IMPLEMENTED = 5 - DEPTH_LIMIT = 6 -) diff --git a/pkg/protocol/bthrift/apache/serializer.go b/pkg/protocol/bthrift/apache/serializer.go deleted file mode 100644 index c255250301..0000000000 --- a/pkg/protocol/bthrift/apache/serializer.go +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/serializer.go - -type TStruct interface { - Write(p TProtocol) error - Read(p TProtocol) error -} diff --git a/pkg/protocol/bthrift/apache/thrift.go b/pkg/protocol/bthrift/apache/thrift.go new file mode 100644 index 0000000000..f60a407dc4 --- /dev/null +++ b/pkg/protocol/bthrift/apache/thrift.go @@ -0,0 +1,84 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import "github.com/apache/thrift/lib/go/thrift" + +// this file contains types and funcs mainly used by generated code +// feel free to add or remove as long as it aligns with thriftgo + +type TStruct = thrift.TStruct + +type TProtocol = thrift.TProtocol + +type TTransport = thrift.TTransport + +const ( // from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol.go + VERSION_MASK = 0xffff0000 + VERSION_1 = 0x80010000 +) + +var SkipDefaultDepth = thrift.SkipDefaultDepth + +type TException = thrift.TException + +var ( + PrependError = thrift.PrependError + NewTProtocolExceptionWithType = thrift.NewTProtocolExceptionWithType +) + +const ( // from github.com/apache/thrift@v0.13.0/lib/go/thrift/protocol_exception.go + UNKNOWN_PROTOCOL_EXCEPTION = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 +) + +type TMessageType = thrift.TMessageType + +const ( // from github.com/apache/thrift@v0.13.0/lib/go/thrift/messagetype.go + INVALID_TMESSAGE_TYPE TMessageType = 0 + CALL TMessageType = 1 + REPLY TMessageType = 2 + EXCEPTION TMessageType = 3 + ONEWAY TMessageType = 4 +) + +type TType = thrift.TType + +const ( // from github.com/apache/thrift@v0.13.0/lib/go/thrift/type.go + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 +) diff --git a/pkg/protocol/bthrift/apache/transport.go b/pkg/protocol/bthrift/apache/transport.go deleted file mode 100644 index 25a752ae52..0000000000 --- a/pkg/protocol/bthrift/apache/transport.go +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/transport.go - -type TTransport = thrift.TTransport diff --git a/pkg/protocol/bthrift/apache/type.go b/pkg/protocol/bthrift/apache/type.go deleted file mode 100644 index 42533b085e..0000000000 --- a/pkg/protocol/bthrift/apache/type.go +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/type.go - -type TType = thrift.TType - -const ( - STOP = 0 - VOID = 1 - BOOL = 2 - BYTE = 3 - I08 = 3 - DOUBLE = 4 - I16 = 6 - I32 = 8 - I64 = 10 - STRING = 11 - UTF7 = 11 - STRUCT = 12 - MAP = 13 - SET = 14 - LIST = 15 - UTF8 = 16 - UTF16 = 17 -) diff --git a/pkg/protocol/bthrift/binary.go b/pkg/protocol/bthrift/binary.go index 8db034e1dd..4f053d1378 100644 --- a/pkg/protocol/bthrift/binary.go +++ b/pkg/protocol/bthrift/binary.go @@ -19,20 +19,20 @@ package bthrift import ( "encoding/binary" - "errors" "fmt" "math" - "github.com/cloudwego/kitex/pkg/mem" + "github.com/bytedance/gopkg/lang/span" + gthrift "github.com/cloudwego/gopkg/protocol/thrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) var ( // Binary protocol for bthrift. Binary binaryProtocol _ BTProtocol = binaryProtocol{} - spanCache = mem.NewSpanCache(1024 * 1024) + spanCache = span.NewSpanCache(1024 * 1024) spanCacheEnable bool = false ) @@ -263,49 +263,52 @@ func (binaryProtocol) BinaryLengthNocopy(value []byte) int { return l + len(value) } +var ( + errBadVersion = gthrift.NewProtocolException(gthrift.BAD_VERSION, "Bad version in ReadMessageBegin") + errMissingVersion = gthrift.NewProtocolException(gthrift.BAD_VERSION, "Missing version in ReadMessageBegin") + + errInvalidDataLength = gthrift.NewProtocolException(gthrift.INVALID_DATA, "Invalid data length") +) + func (binaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID thrift.TMessageType, seqid int32, length int, err error) { size, l, e := Binary.ReadI32(buf) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } if size > 0 { - err = perrors.NewProtocolErrorWithType(perrors.BadVersion, "Missing version in ReadMessageBegin") + err = errMissingVersion return } typeID = thrift.TMessageType(size & 0x0ff) version := int64(size) & thrift.VERSION_MASK if version != thrift.VERSION_1 { - err = perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in ReadMessageBegin") + err = errBadVersion return } name, l, e = Binary.ReadString(buf[length:]) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } seqid, l, e = Binary.ReadI32(buf[length:]) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } return } -func (binaryProtocol) ReadMessageEnd(buf []byte) (int, error) { - return 0, nil -} +func (binaryProtocol) ReadMessageEnd(_ []byte) (int, error) { return 0, nil } -func (binaryProtocol) ReadStructBegin(buf []byte) (name string, length int, err error) { +func (binaryProtocol) ReadStructBegin(_ []byte) (name string, length int, err error) { return } -func (binaryProtocol) ReadStructEnd(buf []byte) (int, error) { - return 0, nil -} +func (binaryProtocol) ReadStructEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadFieldBegin(buf []byte) (name string, typeID thrift.TType, id int16, length int, err error) { t, l, e := Binary.ReadByte(buf) @@ -322,59 +325,55 @@ func (binaryProtocol) ReadFieldBegin(buf []byte) (name string, typeID thrift.TTy return } -func (binaryProtocol) ReadFieldEnd(buf []byte) (int, error) { - return 0, nil -} +func (binaryProtocol) ReadFieldEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadMapBegin(buf []byte) (keyType, valueType thrift.TType, size, length int, err error) { k, l, e := Binary.ReadByte(buf) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } keyType = thrift.TType(k) v, l, e := Binary.ReadByte(buf[length:]) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } valueType = thrift.TType(v) size32, l, e := Binary.ReadI32(buf[length:]) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } if size32 < 0 { - err = perrors.InvalidDataLength + err = errInvalidDataLength return } size = int(size32) return } -func (binaryProtocol) ReadMapEnd(buf []byte) (int, error) { - return 0, nil -} +func (binaryProtocol) ReadMapEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadListBegin(buf []byte) (elemType thrift.TType, size, length int, err error) { b, l, e := Binary.ReadByte(buf) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } elemType = thrift.TType(b) size32, l, e := Binary.ReadI32(buf[length:]) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } if size32 < 0 { - err = perrors.InvalidDataLength + err = errInvalidDataLength return } size = int(size32) @@ -382,35 +381,31 @@ func (binaryProtocol) ReadListBegin(buf []byte) (elemType thrift.TType, size, le return } -func (binaryProtocol) ReadListEnd(buf []byte) (int, error) { - return 0, nil -} +func (binaryProtocol) ReadListEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadSetBegin(buf []byte) (elemType thrift.TType, size, length int, err error) { b, l, e := Binary.ReadByte(buf) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } elemType = thrift.TType(b) size32, l, e := Binary.ReadI32(buf[length:]) length += l if e != nil { - err = perrors.NewProtocolError(e) + err = e return } if size32 < 0 { - err = perrors.InvalidDataLength + err = errInvalidDataLength return } size = int(size32) return } -func (binaryProtocol) ReadSetEnd(buf []byte) (int, error) { - return 0, nil -} +func (binaryProtocol) ReadSetEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadBool(buf []byte) (value bool, length int, err error) { b, l, e := Binary.ReadByte(buf) @@ -421,45 +416,58 @@ func (binaryProtocol) ReadBool(buf []byte) (value bool, length int, err error) { return v, l, e } +var errReadByte = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadByte] len(buf) < 1") + func (binaryProtocol) ReadByte(buf []byte) (value int8, length int, err error) { if len(buf) < 1 { - return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadByte] buf length less than 1") + return value, length, errReadByte } return int8(buf[0]), 1, err } +var errReadI16 = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadI16] len(buf) < 2") + func (binaryProtocol) ReadI16(buf []byte) (value int16, length int, err error) { if len(buf) < 2 { - return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadI16] buf length less than 2") + return value, length, errReadI16 } value = int16(binary.BigEndian.Uint16(buf)) - return value, 2, err + return value, 2, nil } +var errReadI32 = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadI32] len(buf) < 4") + func (binaryProtocol) ReadI32(buf []byte) (value int32, length int, err error) { if len(buf) < 4 { - return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadI32] buf length less than 4") + return value, length, errReadI32 } value = int32(binary.BigEndian.Uint32(buf)) - return value, 4, err + return value, 4, nil } +var errReadI64 = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadI64] len(buf) < 8") + func (binaryProtocol) ReadI64(buf []byte) (value int64, length int, err error) { if len(buf) < 8 { - return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadI64] buf length less than 8") + return value, length, errReadI64 } value = int64(binary.BigEndian.Uint64(buf)) - return value, 8, err + return value, 8, nil } +var errReadDouble = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadDouble] len(buf) < 8") + func (binaryProtocol) ReadDouble(buf []byte) (value float64, length int, err error) { if len(buf) < 8 { - return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadDouble] buf length less than 8") + return value, length, errReadDouble } value = math.Float64frombits(binary.BigEndian.Uint64(buf)) - return value, 8, err + return value, 8, nil } +var errReadString = gthrift.NewProtocolException( + gthrift.INVALID_DATA, "[ReadString] the string size greater than buf length") + func (binaryProtocol) ReadString(buf []byte) (value string, length int, err error) { size, l, e := Binary.ReadI32(buf) length += l @@ -468,7 +476,7 @@ func (binaryProtocol) ReadString(buf []byte) (value string, length int, err erro return } if size < 0 || int(size) > len(buf) { - return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadString] the string size greater than buf length") + return value, length, errReadString } if spanCacheEnable { data := spanCache.Copy(buf[length : length+int(size)]) @@ -480,6 +488,9 @@ func (binaryProtocol) ReadString(buf []byte) (value string, length int, err erro return } +var errReadBinary = gthrift.NewProtocolException( + gthrift.INVALID_DATA, "[ReadBinary] the binary size greater than buf length") + func (binaryProtocol) ReadBinary(buf []byte) (value []byte, length int, err error) { _size, l, e := Binary.ReadI32(buf) length += l @@ -489,7 +500,7 @@ func (binaryProtocol) ReadBinary(buf []byte) (value []byte, length int, err erro } size := int(_size) if size < 0 || size > len(buf) { - return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadBinary] the binary size greater than buf length") + return value, length, errReadBinary } if spanCacheEnable { value = spanCache.Copy(buf[length : length+size]) @@ -508,13 +519,16 @@ func (binaryProtocol) Skip(buf []byte, fieldType thrift.TType) (length int, err // SkipDefaultDepth skips over the next data element from the provided input TProtocol object. func SkipDefaultDepth(buf []byte, prot BTProtocol, typeID thrift.TType) (int, error) { - return Skip(buf, prot, typeID, thrift.DEFAULT_RECURSION_DEPTH) + const defaultRecursionDepth = 64 // same as thrift.DEFAULT_RECURSION_DEPTH + return Skip(buf, prot, typeID, defaultRecursionDepth) } +var errSkipDepthLimit = gthrift.NewProtocolException(gthrift.DEPTH_LIMIT, "depth limit exceeded") + // Skip skips over the next data element from the provided input TProtocol object. func Skip(buf []byte, self BTProtocol, fieldType thrift.TType, maxDepth int) (length int, err error) { if maxDepth <= 0 { - return 0, thrift.NewTProtocolExceptionWithType(thrift.DEPTH_LIMIT, errors.New("depth limit exceeded")) + return 0, errSkipDepthLimit } var l int @@ -647,6 +661,7 @@ func Skip(buf []byte, self BTProtocol, fieldType thrift.TType, maxDepth int) (le } return default: - return 0, thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("unknown data type %d", fieldType)) + return 0, gthrift.NewProtocolException( + gthrift.INVALID_DATA, fmt.Sprintf("unknown data type %d", fieldType)) } } diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index a0754bcd55..ba86bd6a53 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -21,8 +21,8 @@ import ( "fmt" "testing" - "github.com/cloudwego/kitex/internal/test" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/protocol/bthrift/internal/test" ) // TestWriteMessageEnd test binary WriteMessageEnd function diff --git a/internal/mocks/thrift/utils.go b/pkg/protocol/bthrift/compat.go similarity index 95% rename from internal/mocks/thrift/utils.go rename to pkg/protocol/bthrift/compat.go index 5d080bc513..0e9cb58494 100644 --- a/internal/mocks/thrift/utils.go +++ b/pkg/protocol/bthrift/compat.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package thrift +package bthrift import ( "errors" @@ -57,12 +57,12 @@ func (p ApacheCodecAdapter) Read(tp athrift.TProtocol) error { // ToApacheCodec converts a thrift.FastCodec to athrift.TStruct func ToApacheCodec(p thrift.FastCodec) athrift.TStruct { - return ApacheCodecAdapter{p: p} + return &ApacheCodecAdapter{p} } // UnpackApacheCodec unpacks the value returned by `ToApacheCodec` func UnpackApacheCodec(v interface{}) interface{} { - a, ok := v.(ApacheCodecAdapter) + a, ok := v.(*ApacheCodecAdapter) if ok { return a.p } diff --git a/pkg/protocol/bthrift/internal/test/README.md b/pkg/protocol/bthrift/internal/test/README.md new file mode 100644 index 0000000000..4d000c03b6 --- /dev/null +++ b/pkg/protocol/bthrift/internal/test/README.md @@ -0,0 +1,6 @@ +# test + +`test` is copied from `kitex/internal/test` for preparing adding `go.mod` under `bthrift` in the future. +`bthrift` will not import `kitex` other packages, vice versa. + +for more details, see README.md of `bthrift` diff --git a/pkg/protocol/bthrift/internal/test/assert.go b/pkg/protocol/bthrift/internal/test/assert.go new file mode 100644 index 0000000000..71f3fe6d52 --- /dev/null +++ b/pkg/protocol/bthrift/internal/test/assert.go @@ -0,0 +1,92 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import "reflect" + +// testingTB is a subset of common methods between *testing.T and *testing.B. +type testingTB interface { + Fatal(args ...interface{}) + Fatalf(format string, args ...interface{}) + Helper() +} + +// Assert asserts cond is true, otherwise fails the test. +func Assert(t testingTB, cond bool, val ...interface{}) { + if !cond { + t.Helper() + if len(val) > 0 { + val = append([]interface{}{"assertion failed: "}, val...) + t.Fatal(val...) + } else { + t.Fatal("assertion failed") + } + } +} + +// Assertf asserts cond is true, otherwise fails the test. +func Assertf(t testingTB, cond bool, format string, val ...interface{}) { + if !cond { + t.Helper() + t.Fatalf(format, val...) + } +} + +// DeepEqual asserts a and b are deep equal, otherwise fails the test. +func DeepEqual(t testingTB, a, b interface{}) { + if !reflect.DeepEqual(a, b) { + t.Helper() + t.Fatalf("assertion failed: %v != %v", a, b) + } +} + +// Panic asserts fn should panic and recover it, otherwise fails the test. +func Panic(t testingTB, fn func()) { + hasPanic := false + func() { + defer func() { + if err := recover(); err != nil { + hasPanic = true + } + }() + fn() + }() + if !hasPanic { + t.Helper() + t.Fatal("assertion failed: did not panic") + } +} + +// PanicAt asserts fn should panic and recover it, otherwise fails the test. The expect function can be provided to do further examination of the error. +func PanicAt(t testingTB, fn func(), expect func(err interface{}) bool) { + var err interface{} + func() { + defer func() { + err = recover() + }() + fn() + }() + if err == nil { + t.Helper() + t.Fatal("assertion failed: did not panic") + return + } + if expect != nil && !expect(err) { + t.Helper() + t.Fatal("assertion failed: panic but not expected") + } +} diff --git a/pkg/protocol/bthrift/internal/test/assert_test.go b/pkg/protocol/bthrift/internal/test/assert_test.go new file mode 100644 index 0000000000..f60b3f1664 --- /dev/null +++ b/pkg/protocol/bthrift/internal/test/assert_test.go @@ -0,0 +1,105 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "fmt" + "testing" +) + +type mockTesting struct { + t *testing.T + + expect0 string + expect1 string + + helper bool +} + +func (m *mockTesting) Reset() { + m.expect0 = "" + m.expect1 = "" + m.helper = false +} + +func (m *mockTesting) ExpectFatal(args ...interface{}) { + m.expect0 = fmt.Sprint(args...) +} + +func (m *mockTesting) ExpectFatalf(format string, args ...interface{}) { + m.expect1 = fmt.Sprintf(format, args...) +} + +func (m *mockTesting) Fatal(args ...interface{}) { + t := m.t + t.Helper() + if !m.helper { + t.Fatal("need to call Helper before calling Fatal") + } + if s := fmt.Sprint(args...); s != m.expect0 { + t.Fatalf("got %q expect %q", s, m.expect0) + } +} + +func (m *mockTesting) Fatalf(format string, args ...interface{}) { + t := m.t + t.Helper() + if !m.helper { + t.Fatal("need to call Helper before calling Fatalf") + } + if s := fmt.Sprintf(format, args...); s != m.expect1 { + t.Fatalf("got %q expect %q", s, m.expect1) + } +} + +func (m *mockTesting) Helper() { m.helper = true } + +func TestAssert(t *testing.T) { + m := &mockTesting{t: t} + + m.Reset() + m.ExpectFatal("assertion failed") + Assert(m, false) + + m.Reset() + m.ExpectFatal("assertion failed: hello") + Assert(m, false, "hello") + + m.Reset() + m.ExpectFatalf("assert: %s", "hello") + Assertf(m, false, "assert: %s", "hello") + + m.Reset() + m.ExpectFatalf("assertion failed: 1 != 2") + DeepEqual(m, 1, 2) + + m.Reset() + m.ExpectFatal("") + Panic(m, func() { panic("hello") }) + + m.Reset() + m.ExpectFatal("assertion failed: did not panic") + Panic(m, func() {}) + + m.Reset() + m.ExpectFatal("assertion failed: did not panic") + PanicAt(m, func() {}, func(err interface{}) bool { return true }) + + m.Reset() + m.ExpectFatal("assertion failed: panic but not expected") + PanicAt(m, func() { panic("hello") }, func(err interface{}) bool { return false }) +} diff --git a/pkg/protocol/bthrift/unknown_test.go b/pkg/protocol/bthrift/unknown_test.go index 2325748a0c..032165d381 100644 --- a/pkg/protocol/bthrift/unknown_test.go +++ b/pkg/protocol/bthrift/unknown_test.go @@ -20,7 +20,7 @@ import ( "reflect" "testing" - "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/protocol/bthrift/internal/test" ) func TestUnknownFieldTypeConvert(t *testing.T) { diff --git a/pkg/remote/codec/thrift/binary_protocol.go b/pkg/remote/codec/thrift/deprecated.go similarity index 79% rename from pkg/remote/codec/thrift/binary_protocol.go rename to pkg/remote/codec/thrift/deprecated.go index b2775b1c2f..f491c8899c 100644 --- a/pkg/remote/codec/thrift/binary_protocol.go +++ b/pkg/remote/codec/thrift/deprecated.go @@ -22,37 +22,52 @@ import ( "math" "sync" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) -// must be strict read & strict write -var ( - bpPool sync.Pool - _ thrift.TProtocol = (*BinaryProtocol)(nil) -) +// MessageReader read from athrift.TProtocol +// Deprecated: use github.com/apache/thrift/lib/go/thrift.TStruct +type MessageReader interface { + Read(iprot athrift.TProtocol) error +} -func init() { - bpPool.New = newBP +// MessageWriter write to athrift.TProtocol +// Deprecated: use github.com/apache/thrift/lib/go/thrift.TStruct +type MessageWriter interface { + Write(oprot athrift.TProtocol) error } -func newBP() interface{} { - return &BinaryProtocol{} +// UnmarshalThriftException decode thrift exception from tProt +// TODO: this func should be removed in the future. it's exposed accidentally. +// Deprecated: Use `SkipDecoder` + `ApplicationException` of `cloudwego/gopkg/protocol/thrift` instead. +func UnmarshalThriftException(tProt athrift.TProtocol) error { + return unmarshalThriftException(tProt.Transport()) } -// NewBinaryProtocol ... -func NewBinaryProtocol(t remote.ByteBuffer) *BinaryProtocol { - bp := bpPool.Get().(*BinaryProtocol) - bp.trans = t - return bp +var bpPool = sync.Pool{ + New: func() interface{} { + return &BinaryProtocol{} + }, } // BinaryProtocol ... +// Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol type BinaryProtocol struct { trans remote.ByteBuffer } +var _ athrift.TProtocol = (*BinaryProtocol)(nil) + +// NewBinaryProtocol ... +// Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol +func NewBinaryProtocol(t remote.ByteBuffer) *BinaryProtocol { + bp := bpPool.Get().(*BinaryProtocol) + bp.trans = t + return bp +} + // Recycle ... func (p *BinaryProtocol) Recycle() { p.trans = nil @@ -64,8 +79,8 @@ func (p *BinaryProtocol) Recycle() { */ // WriteMessageBegin ... -func (p *BinaryProtocol) WriteMessageBegin(name string, typeID thrift.TMessageType, seqID int32) error { - version := uint32(thrift.VERSION_1) | uint32(typeID) +func (p *BinaryProtocol) WriteMessageBegin(name string, typeID athrift.TMessageType, seqID int32) error { + version := uint32(athrift.VERSION_1) | uint32(typeID) e := p.WriteI32(int32(version)) if e != nil { return e @@ -94,7 +109,7 @@ func (p *BinaryProtocol) WriteStructEnd() error { } // WriteFieldBegin ... -func (p *BinaryProtocol) WriteFieldBegin(name string, typeID thrift.TType, id int16) error { +func (p *BinaryProtocol) WriteFieldBegin(name string, typeID athrift.TType, id int16) error { e := p.WriteByte(int8(typeID)) if e != nil { return e @@ -110,12 +125,12 @@ func (p *BinaryProtocol) WriteFieldEnd() error { // WriteFieldStop ... func (p *BinaryProtocol) WriteFieldStop() error { - e := p.WriteByte(thrift.STOP) + e := p.WriteByte(athrift.STOP) return e } // WriteMapBegin ... -func (p *BinaryProtocol) WriteMapBegin(keyType, valueType thrift.TType, size int) error { +func (p *BinaryProtocol) WriteMapBegin(keyType, valueType athrift.TType, size int) error { e := p.WriteByte(int8(keyType)) if e != nil { return e @@ -134,7 +149,7 @@ func (p *BinaryProtocol) WriteMapEnd() error { } // WriteListBegin ... -func (p *BinaryProtocol) WriteListBegin(elemType thrift.TType, size int) error { +func (p *BinaryProtocol) WriteListBegin(elemType athrift.TType, size int) error { e := p.WriteByte(int8(elemType)) if e != nil { return e @@ -149,7 +164,7 @@ func (p *BinaryProtocol) WriteListEnd() error { } // WriteSetBegin ... -func (p *BinaryProtocol) WriteSetBegin(elemType thrift.TType, size int) error { +func (p *BinaryProtocol) WriteSetBegin(elemType athrift.TType, size int) error { e := p.WriteByte(int8(elemType)) if e != nil { return e @@ -251,7 +266,7 @@ func (p *BinaryProtocol) malloc(size int) ([]byte, error) { */ // ReadMessageBegin ... -func (p *BinaryProtocol) ReadMessageBegin() (name string, typeID thrift.TMessageType, seqID int32, err error) { +func (p *BinaryProtocol) ReadMessageBegin() (name string, typeID athrift.TMessageType, seqID int32, err error) { size, e := p.ReadI32() if e != nil { return "", typeID, 0, perrors.NewProtocolError(e) @@ -259,9 +274,9 @@ func (p *BinaryProtocol) ReadMessageBegin() (name string, typeID thrift.TMessage if size > 0 { return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Missing version in ReadMessageBegin") } - typeID = thrift.TMessageType(size & 0x0ff) - version := int64(int64(size) & thrift.VERSION_MASK) - if version != thrift.VERSION_1 { + typeID = athrift.TMessageType(size & 0x0ff) + version := int64(int64(size) & athrift.VERSION_MASK) + if version != athrift.VERSION_1 { return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in ReadMessageBegin") } name, e = p.ReadString() @@ -291,13 +306,13 @@ func (p *BinaryProtocol) ReadStructEnd() error { } // ReadFieldBegin ... -func (p *BinaryProtocol) ReadFieldBegin() (name string, typeID thrift.TType, id int16, err error) { +func (p *BinaryProtocol) ReadFieldBegin() (name string, typeID athrift.TType, id int16, err error) { t, err := p.ReadByte() - typeID = thrift.TType(t) + typeID = athrift.TType(t) if err != nil { return name, typeID, id, err } - if t != thrift.STOP { + if t != athrift.STOP { id, err = p.ReadI16() } return name, typeID, id, err @@ -309,19 +324,19 @@ func (p *BinaryProtocol) ReadFieldEnd() error { } // ReadMapBegin ... -func (p *BinaryProtocol) ReadMapBegin() (kType, vType thrift.TType, size int, err error) { +func (p *BinaryProtocol) ReadMapBegin() (kType, vType athrift.TType, size int, err error) { k, e := p.ReadByte() if e != nil { err = perrors.NewProtocolError(e) return } - kType = thrift.TType(k) + kType = athrift.TType(k) v, e := p.ReadByte() if e != nil { err = perrors.NewProtocolError(e) return } - vType = thrift.TType(v) + vType = athrift.TType(v) size32, e := p.ReadI32() if e != nil { err = perrors.NewProtocolError(e) @@ -341,13 +356,13 @@ func (p *BinaryProtocol) ReadMapEnd() error { } // ReadListBegin ... -func (p *BinaryProtocol) ReadListBegin() (elemType thrift.TType, size int, err error) { +func (p *BinaryProtocol) ReadListBegin() (elemType athrift.TType, size int, err error) { b, e := p.ReadByte() if e != nil { err = perrors.NewProtocolError(e) return } - elemType = thrift.TType(b) + elemType = athrift.TType(b) size32, e := p.ReadI32() if e != nil { err = perrors.NewProtocolError(e) @@ -368,13 +383,13 @@ func (p *BinaryProtocol) ReadListEnd() error { } // ReadSetBegin ... -func (p *BinaryProtocol) ReadSetBegin() (elemType thrift.TType, size int, err error) { +func (p *BinaryProtocol) ReadSetBegin() (elemType athrift.TType, size int, err error) { b, e := p.ReadByte() if e != nil { err = perrors.NewProtocolError(e) return } - elemType = thrift.TType(b) + elemType = athrift.TType(b) size32, e := p.ReadI32() if e != nil { err = perrors.NewProtocolError(e) @@ -491,13 +506,13 @@ func (p *BinaryProtocol) Flush(ctx context.Context) (err error) { } // Skip ... -func (p *BinaryProtocol) Skip(fieldType thrift.TType) (err error) { - return thrift.SkipDefaultDepth(p, fieldType) +func (p *BinaryProtocol) Skip(fieldType athrift.TType) (err error) { + return athrift.SkipDefaultDepth(p, fieldType) } // ttransportByteBuffer ... // for exposing remote.ByteBuffer via p.Transport(), -// mainly for testing purpose, see internal/mocks/thrift/utils.go +// mainly for testing purpose, see internal/mocks/athrift/utils.go type ttransportByteBuffer struct { remote.ByteBuffer } @@ -509,7 +524,7 @@ func (ttransportByteBuffer) Open() error { panic("not func (p ttransportByteBuffer) RemainingBytes() uint64 { return uint64(p.ReadableLen()) } // Transport ... -func (p *BinaryProtocol) Transport() thrift.TTransport { +func (p *BinaryProtocol) Transport() athrift.TTransport { return ttransportByteBuffer{p.trans} } diff --git a/pkg/remote/codec/thrift/binary_protocol_test.go b/pkg/remote/codec/thrift/deprecated_test.go similarity index 100% rename from pkg/remote/codec/thrift/binary_protocol_test.go rename to pkg/remote/codec/thrift/deprecated_test.go diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index 829767cd3b..e7a9fed13f 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -23,8 +23,8 @@ import ( "io" "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/apache" - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" @@ -199,22 +199,21 @@ func encodeGenericThrift(out remote.ByteBuffer, ctx context.Context, method stri return nil } -// encodeBasicThrift encode with the old athrift way (slow) +// encodeBasicThrift encode with the old apache thrift way (slow) func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}) error { if err := verifyMarshalBasicThriftDataType(data); err != nil { return err } - tProt := NewBinaryProtocol(out) - if err := tProt.WriteMessageBegin(method, athrift.TMessageType(msgType), seqID); err != nil { - return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageBegin failed: %s", err.Error())) - } - if err := marshalBasicThriftData(tProt, data); err != nil { + + b, err := out.Malloc(thrift.Binary.MessageBeginLength(method)) + if err != nil { return err } - if err := tProt.WriteMessageEnd(); err != nil { - return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageEnd failed: %s", err.Error())) + _ = thrift.Binary.WriteMessageBegin(b, method, thrift.TMessageType(msgType), seqID) + + if err := apache.ThriftWrite(apache.NewDefaultTransport(out), data); err != nil { + return err } - tProt.Recycle() return nil } @@ -222,8 +221,10 @@ func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { // TODO(xiaost): Refactor the code after v0.11.0 is released. Unifying checking and fallback logic. - tProt := NewBinaryProtocol(in) - methodName, msgType, seqID, err := tProt.ReadMessageBegin() + br := thrift.NewBinaryReader(in) + defer br.Release() + + methodName, msgType, seqID, err := br.ReadMessageBegin() if err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal, ReadMessageBegin failed: %s", err.Error())) } @@ -233,7 +234,7 @@ func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in r // exception message if message.MessageType() == remote.Exception { - return UnmarshalThriftException(tProt) + return unmarshalThriftException(in) } if err = validateMessageBeforeDecode(message, seqID, methodName); err != nil { @@ -257,16 +258,12 @@ func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in r err = remote.NewTransError(remote.ProtocolError, err) } } else { - err = c.unmarshalThriftData(tProt, data, dataLen) + err = c.unmarshalThriftData(in, data, dataLen) } rpcinfo.Record(ctx, ri, stats.WaitReadFinish, err) if err != nil { return err } - if err = tProt.ReadMessageEnd(); err != nil { - return remote.NewTransError(remote.ProtocolError, err) - } - tProt.Recycle() return err } @@ -293,16 +290,6 @@ func (c thriftCodec) Name() string { return serviceinfo.Thrift.String() } -// MessageWriter write to athrift.TProtocol -type MessageWriter interface { - Write(oprot athrift.TProtocol) error -} - -// MessageReader read from athrift.TProtocol -type MessageReader interface { - Read(oprot athrift.TProtocol) error -} - type genericWriter interface { // used by pkg/generic Write(ctx context.Context, method string, w io.Writer) error } diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index db0ccc3503..b6243107bd 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -17,13 +17,15 @@ package thrift import ( + "bytes" "context" "fmt" + "io" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/apache" - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) @@ -68,47 +70,24 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ return nil, err } - // TODO(xiaost): Deprecate the code by using cloudwebgo/gopkg in v0.12.0 // fallback to old thrift way (slow) - transport := athrift.NewTMemoryBufferLen(marshalThriftBufferSize) - tProt := athrift.NewTBinaryProtocol(transport, true, true) - if err := marshalBasicThriftData(tProt, data); err != nil { + buf := bytes.NewBuffer(make([]byte, 0, marshalThriftBufferSize)) + if err := apache.ThriftWrite(apache.NewBufferTransport(buf), data); err != nil { return nil, err } - return transport.Bytes(), nil + return buf.Bytes(), nil } // verifyMarshalBasicThriftDataType verifies whether data could be marshaled by old thrift way func verifyMarshalBasicThriftDataType(data interface{}) error { - switch data.(type) { - case MessageWriter: - default: + if err := apache.CheckThriftWrite(data); err != nil { return errEncodeMismatchMsgType } return nil } -// marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId) -// It uses the old thrift way which is much slower than FastCodec and Frugal -func marshalBasicThriftData(tProt athrift.TProtocol, data interface{}) error { - var err error - switch msg := data.(type) { - case MessageWriter: - err = msg.Write(tProt) - default: - return errEncodeMismatchMsgType - } - if err != nil { - return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) - } - return nil -} - -// UnmarshalThriftException decode thrift exception from tProt -// TODO: this func should be removed in the future. it's exposed accidentally. -// Deprecated: Use `SkipDecoder` + `ApplicationException` of `cloudwego/gopkg/protocol/thrift` instead. -func UnmarshalThriftException(tProt athrift.TProtocol) error { - d := thrift.NewSkipDecoder(tProt.Transport()) +func unmarshalThriftException(in io.Reader) error { + d := thrift.NewSkipDecoder(in) defer d.Release() b, err := d.Next(thrift.STRUCT) if err != nil { @@ -129,12 +108,9 @@ func UnmarshalThriftData(ctx context.Context, codec remote.PayloadCodec, method if !ok { c = defaultCodec } - tProt := NewBinaryProtocol(remote.NewReaderBuffer(buf)) - err := c.unmarshalThriftData(tProt, data, len(buf)) - if err == nil { - tProt.Recycle() - } - return err + trans := remote.NewReaderBuffer(buf) + defer trans.Release(nil) + return c.unmarshalThriftData(trans, data, len(buf)) } func (c thriftCodec) fastMessageUnmarshalAvailable(data interface{}, payloadLen int) bool { @@ -145,10 +121,10 @@ func (c thriftCodec) fastMessageUnmarshalAvailable(data interface{}, payloadLen return ok } -func (c thriftCodec) fastUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error { +func (c thriftCodec) fastUnmarshal(trans remote.ByteBuffer, data interface{}, dataLen int) error { msg := data.(thrift.FastCodec) if dataLen > 0 { - buf, err := tProt.next(dataLen) + buf, err := trans.Next(dataLen) if err != nil { return remote.NewTransError(remote.ProtocolError, err) } @@ -158,7 +134,7 @@ func (c thriftCodec) fastUnmarshal(tProt *BinaryProtocol, data interface{}, data } return nil } - buf, err := getSkippedStructBuffer(tProt) + buf, err := getSkippedStructBuffer(trans) if err != nil { return err } @@ -171,15 +147,15 @@ func (c thriftCodec) fastUnmarshal(tProt *BinaryProtocol, data interface{}, data // unmarshalThriftData only decodes the data (after methodName, msgType and seqId) // method is only used for generic calls -func (c thriftCodec) unmarshalThriftData(tProt *BinaryProtocol, data interface{}, dataLen int) error { +func (c thriftCodec) unmarshalThriftData(trans remote.ByteBuffer, data interface{}, dataLen int) error { // decode with hyper unmarshal if c.IsSet(FrugalRead) && c.hyperMessageUnmarshalAvailable(data, dataLen) { - return c.hyperUnmarshal(tProt, data, dataLen) + return c.hyperUnmarshal(trans, data, dataLen) } // decode with FastRead if c.IsSet(FastRead) && c.fastMessageUnmarshalAvailable(data, dataLen) { - return c.fastUnmarshal(tProt, data, dataLen) + return c.fastUnmarshal(trans, data, dataLen) } if err := verifyUnmarshalBasicThriftDataType(data); err != nil { @@ -187,22 +163,22 @@ func (c thriftCodec) unmarshalThriftData(tProt *BinaryProtocol, data interface{} if c.CodecType != Basic { // try FrugalRead < - > FastRead fallback if c.fastMessageUnmarshalAvailable(data, dataLen) { - return c.fastUnmarshal(tProt, data, dataLen) + return c.fastUnmarshal(trans, data, dataLen) } if c.hyperMessageUnmarshalAvailable(data, dataLen) { // slim template? - return c.hyperUnmarshal(tProt, data, dataLen) + return c.hyperUnmarshal(trans, data, dataLen) } } return err } // fallback to old thrift way (slow) - return decodeBasicThriftData(tProt, data) + return decodeBasicThriftData(trans, data) } -func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error { +func (c thriftCodec) hyperUnmarshal(trans remote.ByteBuffer, data interface{}, dataLen int) error { if dataLen > 0 { - buf, err := tProt.next(dataLen) + buf, err := trans.Next(dataLen) if err != nil { return remote.NewTransError(remote.ProtocolError, err) } @@ -211,7 +187,7 @@ func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dat } return nil } - buf, err := getSkippedStructBuffer(tProt) + buf, err := getSkippedStructBuffer(trans) if err != nil { return err } @@ -224,31 +200,22 @@ func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dat // verifyUnmarshalBasicThriftDataType verifies whether data could be unmarshal by old thrift way func verifyUnmarshalBasicThriftDataType(data interface{}) error { - switch data.(type) { - case MessageReader: - default: + if err := apache.CheckThriftRead(data); err != nil { return errDecodeMismatchMsgType } return nil } // decodeBasicThriftData decode thrift body the old way (slow) -func decodeBasicThriftData(tProt athrift.TProtocol, data interface{}) error { - var err error - switch t := data.(type) { - case MessageReader: - err = t.Read(tProt) - default: - return errDecodeMismatchMsgType - } - if err != nil { - return remote.NewTransError(remote.ProtocolError, err) +func decodeBasicThriftData(trans remote.ByteBuffer, data interface{}) error { + if err := verifyUnmarshalBasicThriftDataType(data); err != nil { + return err } - return nil + return apache.ThriftRead(apache.NewDefaultTransport(trans), data) } -func getSkippedStructBuffer(tProt *BinaryProtocol) ([]byte, error) { - sd := thrift.NewSkipDecoder(tProt.trans) +func getSkippedStructBuffer(trans remote.ByteBuffer) ([]byte, error) { + sd := thrift.NewSkipDecoder(trans) buf, err := sd.Next(thrift.STRUCT) if err != nil { return nil, remote.NewTransError(remote.ProtocolError, err).AppendMessage("caught in SkipDecoder Next phase") diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index 9f23735cf6..ca275936b8 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -26,7 +26,7 @@ import ( mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" "github.com/cloudwego/kitex/pkg/remote" ) @@ -42,21 +42,6 @@ var ( } ) -func TestMarshalBasicThriftData(t *testing.T) { - t.Run("invalid-data", func(t *testing.T) { - err := marshalBasicThriftData(nil, 0) - test.Assert(t, err == errEncodeMismatchMsgType, err) - }) - t.Run("valid-data", func(t *testing.T) { - transport := athrift.NewTMemoryBufferLen(1024) - tProt := athrift.NewTBinaryProtocol(transport, true, true) - err := marshalBasicThriftData(tProt, mocks.ToApacheCodec(mockReq)) - test.Assert(t, err == nil, err) - result := transport.Bytes() - test.Assert(t, reflect.DeepEqual(result, mockReqThrift), result) - }) -} - func TestMarshalThriftData(t *testing.T) { t.Run("NoCodec(=FastCodec)", func(t *testing.T) { buf, err := MarshalThriftData(context.Background(), nil, mockReq) @@ -69,7 +54,7 @@ func TestMarshalThriftData(t *testing.T) { test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf) }) t.Run("BasicCodec", func(t *testing.T) { - buf, err := MarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), mocks.ToApacheCodec(mockReq)) + buf, err := MarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), bthrift.ToApacheCodec(mockReq)) test.Assert(t, err == nil, err) test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf) }) @@ -79,25 +64,26 @@ func TestMarshalThriftData(t *testing.T) { func Test_decodeBasicThriftData(t *testing.T) { t.Run("empty-input", func(t *testing.T) { req := &mocks.MockReq{} - tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{})) - err := decodeBasicThriftData(tProt, mocks.ToApacheCodec(req)) + trans := remote.NewReaderBuffer([]byte{}) + err := decodeBasicThriftData(trans, bthrift.ToApacheCodec(req)) test.Assert(t, err != nil, err) }) t.Run("invalid-input", func(t *testing.T) { req := &mocks.MockReq{} - tProt := NewBinaryProtocol(remote.NewReaderBuffer([]byte{0xff})) - err := decodeBasicThriftData(tProt, mocks.ToApacheCodec(req)) + trans := remote.NewReaderBuffer([]byte{0xff}) + err := decodeBasicThriftData(trans, bthrift.ToApacheCodec(req)) test.Assert(t, err != nil, err) }) t.Run("normal-input", func(t *testing.T) { req := &mocks.MockReq{} - tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) - err := decodeBasicThriftData(tProt, mocks.ToApacheCodec(req)) + trans := remote.NewReaderBuffer(mockReqThrift) + err := decodeBasicThriftData(trans, bthrift.ToApacheCodec(req)) checkDecodeResult(t, err, req) }) } func checkDecodeResult(t *testing.T, err error, req *mocks.MockReq) { + t.Helper() test.Assert(t, err == nil, err) test.Assert(t, req.Msg == mockReq.Msg, req.Msg, mockReq.Msg) test.Assert(t, len(req.StrMap) == 0, req.StrMap) @@ -117,7 +103,7 @@ func TestUnmarshalThriftData(t *testing.T) { }) t.Run("BasicCodec", func(t *testing.T) { req := &mocks.MockReq{} - err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), "mock", mockReqThrift, mocks.ToApacheCodec(req)) + err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), "mock", mockReqThrift, bthrift.ToApacheCodec(req)) checkDecodeResult(t, err, req) }) // FrugalCodec: in thrift_frugal_amd64_test.go: TestUnmarshalThriftDataFrugal @@ -127,10 +113,9 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { t.Run("FastCodec with SkipDecoder enabled", func(t *testing.T) { req := &mocks.MockReq{} codec := &thriftCodec{FastRead | EnableSkipDecoder} - tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) - defer tProt.Recycle() + trans := remote.NewReaderBuffer(mockReqThrift) // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(tProt, req, 0) + err := codec.unmarshalThriftData(trans, req, 0) checkDecodeResult(t, err, &mocks.MockReq{ Msg: req.Msg, StrList: req.StrList, @@ -151,10 +136,9 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { 15 /* list */, 0, 3 /* id=3 */, 6 /* item:I16 */, 0, 0, 0, 1 /* length=1 */, 0, 1, /* I16=1 */ 0, /* end of struct */ } - tProt := NewBinaryProtocol(remote.NewReaderBuffer(faultMockReqThrift)) - defer tProt.Recycle() + trans := remote.NewReaderBuffer(faultMockReqThrift) // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(tProt, req, 0) + err := codec.unmarshalThriftData(trans, req, 0) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), "caught in FastCodec using SkipDecoder Buffer")) }) @@ -173,7 +157,7 @@ func TestUnmarshalThriftException(t *testing.T) { err := UnmarshalThriftException(tProtRead) transErr, ok := err.(*remote.TransError) test.Assert(t, ok, err) - test.Assert(t, transErr.TypeID() == athrift.INVALID_PROTOCOL, transErr) + test.Assert(t, transErr.TypeID() == thrift.INVALID_PROTOCOL, transErr) test.Assert(t, transErr.Error() == errMessage, transErr) } @@ -182,8 +166,8 @@ func Test_getSkippedStructBuffer(t *testing.T) { faultThrift := []byte{ 11 /* string */, 0, 1 /* id=1 */, 0, 0, 0, 6 /* length=6 */, 104, 101, 108, 108, 111, /* "hello" */ } - tProt := NewBinaryProtocol(remote.NewReaderBuffer(faultThrift)) - _, err := getSkippedStructBuffer(tProt) + trans := remote.NewReaderBuffer(faultThrift) + _, err := getSkippedStructBuffer(trans) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), "caught in SkipDecoder Next phase")) } diff --git a/pkg/remote/codec/thrift/thrift_frugal_test.go b/pkg/remote/codec/thrift/thrift_frugal_test.go index 3985268171..0393b5e5be 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_test.go @@ -228,10 +228,9 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { t.Run("Frugal with SkipDecoder enabled", func(t *testing.T) { req := &MockFrugalTagReq{} codec := &thriftCodec{FrugalRead | EnableSkipDecoder} - tProt := NewBinaryProtocol(remote.NewReaderBuffer(mockReqThrift)) - defer tProt.Recycle() + trans := remote.NewReaderBuffer(mockReqThrift) // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(tProt, req, 0) + err := codec.unmarshalThriftData(trans, req, 0) checkDecodeResult(t, err, &mocks.MockReq{ Msg: req.Msg, StrList: req.StrList, @@ -252,10 +251,9 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { 15 /* list */, 0, 3 /* id=3 */, 6 /* item:I16 */, 0, 0, 0, 1 /* length=1 */, 0, 1, /* I16=1 */ 0, /* end of struct */ } - tProt := NewBinaryProtocol(remote.NewReaderBuffer(faultMockReqThrift)) - defer tProt.Recycle() + trans := remote.NewReaderBuffer(faultMockReqThrift) // specify dataLen with 0 so that skipDecoder works - err := codec.unmarshalThriftData(tProt, req, 0) + err := codec.unmarshalThriftData(trans, req, 0) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), "caught in Frugal using SkipDecoder Buffer")) }) diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index 6e7a18b5e1..12308545fa 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/kitex/internal/mocks" mt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" "github.com/cloudwego/kitex/pkg/remote" netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -209,8 +210,8 @@ func BenchmarkNormalParallel(b *testing.B) { test.Assert(b, err == nil, err) // compare Req Arg - sendReq := mt.UnpackApacheCodec(sendMsg.Data()).(*mt.MockTestArgs).Req - recvReq := mt.UnpackApacheCodec(recvMsg.Data()).(*mt.MockTestArgs).Req + sendReq := sendMsg.Data().(*mt.MockTestArgs).Req + recvReq := recvMsg.Data().(*mt.MockTestArgs).Req test.Assert(b, sendReq.Msg == recvReq.Msg) test.Assert(b, len(sendReq.StrList) == len(recvReq.StrList)) test.Assert(b, len(sendReq.StrMap) == len(recvReq.StrMap)) @@ -324,7 +325,7 @@ func TestSkipDecoder(t *testing.T) { func toApacheCodec(v bool, data thrift.FastCodec) interface{} { if v { - return mt.ToApacheCodec(data) + return bthrift.ToApacheCodec(data) } return data } @@ -349,8 +350,8 @@ func initRecvMsg(basic bool) remote.Message { } func compare(t *testing.T, sendMsg, recvMsg remote.Message) { - sendReq := mt.UnpackApacheCodec(sendMsg.Data()).(*mt.MockTestArgs).Req - recvReq := mt.UnpackApacheCodec(recvMsg.Data()).(*mt.MockTestArgs).Req + sendReq := bthrift.UnpackApacheCodec(sendMsg.Data()).(*mt.MockTestArgs).Req + recvReq := bthrift.UnpackApacheCodec(recvMsg.Data()).(*mt.MockTestArgs).Req test.Assert(t, sendReq.Msg == recvReq.Msg) test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList)) test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap)) diff --git a/pkg/utils/thrift.go b/pkg/utils/thrift.go index e8256337e3..806e805f8a 100644 --- a/pkg/utils/thrift.go +++ b/pkg/utils/thrift.go @@ -17,111 +17,84 @@ package utils import ( + "bytes" "errors" "fmt" "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/thrift/apache" athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) // ThriftMessageCodec is used to codec thrift messages. -type ThriftMessageCodec struct { - tb *athrift.TMemoryBuffer - tProt athrift.TProtocol -} +type ThriftMessageCodec struct{} // NewThriftMessageCodec creates a new ThriftMessageCodec. func NewThriftMessageCodec() *ThriftMessageCodec { - // TODO: use remote.ByteBuffer & remote/codec/thrift.BinaryProtocol - transport := athrift.NewTMemoryBufferLen(1024) - tProt := athrift.NewTBinaryProtocol(transport, true, true) - - return &ThriftMessageCodec{ - tb: transport, - tProt: tProt, - } + return &ThriftMessageCodec{} } // Encode do thrift message encode. // Notice! msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result // Notice! seqID will be reset in kitex if the buffer is used for generic call in client side, set seqID=0 is suggested // when you call this method as client. -func (t *ThriftMessageCodec) Encode(method string, msgType athrift.TMessageType, seqID int32, msg athrift.TStruct) (b []byte, err error) { +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift.MarshalFastMsg +func (t *ThriftMessageCodec) Encode(method string, msgType athrift.TMessageType, seqID int32, msg athrift.TStruct) ([]byte, error) { if method == "" { return nil, errors.New("empty methodName in thrift RPCEncode") } - t.tb.Reset() - if err = t.tProt.WriteMessageBegin(method, msgType, seqID); err != nil { - return - } - if err = msg.Write(t.tProt); err != nil { - return + b := make([]byte, thrift.Binary.MessageBeginLength(method)) + _ = thrift.Binary.WriteMessageBegin(b, method, thrift.TMessageType(msgType), seqID) + buf := &bytes.Buffer{} + buf.Write(b) + if err := apache.ThriftWrite(apache.NewBufferTransport(buf), msg); err != nil { + return nil, err } - if err = t.tProt.WriteMessageEnd(); err != nil { - return - } - b = append(b, t.tb.Bytes()...) - return + return buf.Bytes(), nil } // Decode do thrift message decode, notice: msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift.UnmarshalFastMsg func (t *ThriftMessageCodec) Decode(b []byte, msg athrift.TStruct) (method string, seqID int32, err error) { - t.tb.Reset() - if _, err = t.tb.Write(b); err != nil { - return - } - var msgType athrift.TMessageType - if method, msgType, seqID, err = t.tProt.ReadMessageBegin(); err != nil { + var l int + var msgType thrift.TMessageType + method, msgType, seqID, l, err = thrift.Binary.ReadMessageBegin(b) + if err != nil { return } - if msgType == athrift.EXCEPTION { - b = b[thrift.Binary.MessageBeginLength(method):] // for reusing fast read - ex := thrift.NewApplicationException(athrift.UNKNOWN_APPLICATION_EXCEPTION, "") - if _, err = ex.FastRead(b); err != nil { - return - } - if err = t.tProt.ReadMessageEnd(); err != nil { - return + b = b[l:] + if msgType == thrift.EXCEPTION { + ex := thrift.NewApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") + if _, err = ex.FastRead(b); err == nil { + err = ex // ApplicationException as err } - err = ex - return - } - if err = msg.Read(t.tProt); err != nil { return } - t.tProt.ReadMessageEnd() + err = apache.ThriftRead(apache.NewBufferTransport(bytes.NewBuffer(b)), msg) return } // Serialize serialize message into bytes. This is normal thrift serialize func. // Notice: Binary generic use Encode instead of Serialize. -func (t *ThriftMessageCodec) Serialize(msg athrift.TStruct) (b []byte, err error) { - t.tb.Reset() - - if err = msg.Write(t.tProt); err != nil { - return +func (t *ThriftMessageCodec) Serialize(msg athrift.TStruct) ([]byte, error) { + buf := &bytes.Buffer{} + if err := apache.ThriftWrite(apache.NewBufferTransport(buf), msg); err != nil { + return nil, err } - b = append(b, t.tb.Bytes()...) - return + return buf.Bytes(), nil } // Deserialize deserialize bytes into message. This is normal thrift deserialize func. // Notice: Binary generic use Decode instead of Deserialize. func (t *ThriftMessageCodec) Deserialize(msg athrift.TStruct, b []byte) (err error) { - t.tb.Reset() - if _, err = t.tb.Write(b); err != nil { - return - } - if err = msg.Read(t.tProt); err != nil { - return - } - return nil + buf := bytes.NewBuffer(b) + return apache.ThriftRead(apache.NewBufferTransport(buf), msg) } // MarshalError convert go error to thrift exception, and encode exception over buffered binary transport. func MarshalError(method string, err error) []byte { - ex := thrift.NewApplicationException(athrift.INTERNAL_ERROR, err.Error()) + ex := thrift.NewApplicationException(thrift.INTERNAL_ERROR, err.Error()) n := thrift.Binary.MessageBeginLength(method) n += ex.BLength() b := make([]byte, n) @@ -144,7 +117,7 @@ func UnmarshalError(b []byte) error { } // Read Ex body off := l - ex := thrift.NewApplicationException(athrift.INTERNAL_ERROR, "") + ex := thrift.NewApplicationException(thrift.INTERNAL_ERROR, "") if _, err := ex.FastRead(b[off:]); err != nil { return err } diff --git a/pkg/utils/thrift_test.go b/pkg/utils/thrift_test.go index 18d7632b9b..072fb1bd61 100644 --- a/pkg/utils/thrift_test.go +++ b/pkg/utils/thrift_test.go @@ -22,7 +22,8 @@ import ( mt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) func TestRPCCodec(t *testing.T) { @@ -39,12 +40,12 @@ func TestRPCCodec(t *testing.T) { args1.Req = req1 // encode - buf, err := rc.Encode("mockMethod", thrift.CALL, 100, mt.ToApacheCodec(args1)) + buf, err := rc.Encode("mockMethod", athrift.CALL, 100, bthrift.ToApacheCodec(args1)) test.Assert(t, err == nil, err) var argsDecode1 mt.MockTestArgs // decode - method, seqID, err := rc.Decode(buf, mt.ToApacheCodec(&argsDecode1)) + method, seqID, err := rc.Decode(buf, bthrift.ToApacheCodec(&argsDecode1)) test.Assert(t, err == nil) test.Assert(t, method == "mockMethod") @@ -65,12 +66,12 @@ func TestRPCCodec(t *testing.T) { args2 := mt.NewMockTestArgs() args2.Req = req2 // encode - buf, err = rc.Encode("mockMethod1", thrift.CALL, 101, mt.ToApacheCodec(args2)) + buf, err = rc.Encode("mockMethod1", athrift.CALL, 101, bthrift.ToApacheCodec(args2)) test.Assert(t, err == nil, err) // decode var argsDecode2 mt.MockTestArgs - method, seqID, err = rc.Decode(buf, mt.ToApacheCodec(&argsDecode2)) + method, seqID, err = rc.Decode(buf, bthrift.ToApacheCodec(&argsDecode2)) test.Assert(t, err == nil, err) test.Assert(t, method == "mockMethod1") @@ -95,11 +96,11 @@ func TestSerializer(t *testing.T) { args := mt.NewMockTestArgs() args.Req = req - b, err := rc.Serialize(mt.ToApacheCodec(args)) + b, err := rc.Serialize(bthrift.ToApacheCodec(args)) test.Assert(t, err == nil, err) var args2 mt.MockTestArgs - err = rc.Deserialize(mt.ToApacheCodec(&args2), b) + err = rc.Deserialize(bthrift.ToApacheCodec(&args2), b) test.Assert(t, err == nil, err) test.Assert(t, args2.Req.Msg == req.Msg) diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index b2983d0b3e..94c4fabb2f 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -325,6 +325,11 @@ func (a *Arguments) BuildCmd(out io.Writer) (*exec.Cmd, error) { } } a.ThriftOptions = append(a.ThriftOptions, "package_prefix="+a.PackagePrefix) + + // see README.md in `bthrift` + a.ThriftOptions = append(a.ThriftOptions, + "thrift_import_path=github.com/cloudwego/kitex/pkg/protocol/bthrift/apache") + gas := "go:" + strings.Join(a.ThriftOptions, ",") if a.Verbose { cmd.Args = append(cmd.Args, "-v") From 310ef37cda773c7b2fcef46ca91214aab70779e1 Mon Sep 17 00:00:00 2001 From: "qiheng.zhou" Date: Tue, 6 Aug 2024 23:39:36 +0800 Subject: [PATCH 31/70] optimize: add cachekey to discovery event --- client/middlewares.go | 8 +++++--- client/middlewares_test.go | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/client/middlewares.go b/client/middlewares.go index 33c8c671ad..534ef591ae 100644 --- a/client/middlewares.go +++ b/client/middlewares.go @@ -70,9 +70,11 @@ func discoveryEventHandler(name string, bus event.Bus, queue event.Queue) func(d Name: name, Time: now, Extra: map[string]interface{}{ - "Added": wrapInstances(d.Added), - "Updated": wrapInstances(d.Updated), - "Removed": wrapInstances(d.Removed), + "Cacheable": d.Result.Cacheable, + "CacheKey": d.Result.CacheKey, + "Added": wrapInstances(d.Added), + "Updated": wrapInstances(d.Updated), + "Removed": wrapInstances(d.Removed), }, }) } diff --git a/client/middlewares_test.go b/client/middlewares_test.go index ae0caba38a..bb24ce11d0 100644 --- a/client/middlewares_test.go +++ b/client/middlewares_test.go @@ -256,3 +256,26 @@ func BenchmarkResolverMWParallel(b *testing.B) { } }) } + +func TestDiscoveryEventHandler(t *testing.T) { + bus, queue := event.NewEventBus(), event.NewQueue(200) + h := discoveryEventHandler(discovery.ChangeEventName, bus, queue) + ins := []discovery.Instance{discovery.NewInstance("tcp", "addr", 10, nil)} + cacheKey := "testCacheKey" + c := &discovery.Change{ + Result: discovery.Result{ + Cacheable: true, + CacheKey: cacheKey, + }, + Added: ins, + } + h(c) + events := queue.Dump().([]*event.Event) + test.Assert(t, len(events) == 1) + extra, ok := events[0].Extra.(map[string]interface{}) + test.Assert(t, ok) + test.Assert(t, extra["Cacheable"] == true) + test.Assert(t, extra["CacheKey"] == cacheKey) + added := extra["Added"].([]*instInfo) + test.Assert(t, len(added) == 1) +} From b86c624d818a9166c02d77fc8c564a41533aa9da Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Thu, 8 Aug 2024 11:08:27 +0800 Subject: [PATCH 32/70] fix: fix GetServerConn interface assert for streamWithMiddleware (#1476) --- pkg/remote/trans/nphttp2/server_conn.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pkg/remote/trans/nphttp2/server_conn.go b/pkg/remote/trans/nphttp2/server_conn.go index 4fac524be5..f124ec1d30 100644 --- a/pkg/remote/trans/nphttp2/server_conn.go +++ b/pkg/remote/trans/nphttp2/server_conn.go @@ -59,11 +59,20 @@ func (c *serverConn) ReadFrame() (hdr, data []byte, err error) { return hdr, data, nil } +// GetServerConn gets the GRPC Connection from server stream. +// This function is only used in server handler for grpc unknown handler proxy: https://www.cloudwego.io/docs/kitex/tutorials/advanced-feature/grpcproxy/ +// And the input stream type should always be streamWithMiddleware. func GetServerConn(st streaming.Stream) (GRPCConn, error) { - serverStream, ok := st.(*stream) + mwStream, ok := st.(*streamWithMiddleware) + if !ok { + return nil, status.Errorf(codes.Internal, "failed to get streamWithMiddleware") + } + + serverStream, ok := mwStream.Stream.(*stream) + if !ok { // err! - return nil, status.Errorf(codes.Internal, "failed to get server conn from stream.") + return nil, status.Errorf(codes.Internal, "failed to get server conn from server stream.") } grpcServerConn, ok := serverStream.conn.(GRPCConn) if !ok { From eedb04267c5a3f95d39d9cb58c24d259cc4f477c Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Thu, 8 Aug 2024 15:25:37 +0800 Subject: [PATCH 33/70] fix(trans/netpoll): log when panic in onConnRead (#1486) it happens when a panic happens in grpc code, and then netpoll calls `onConnInactive` which causes deadlock in http2server `Close` --- pkg/remote/trans/netpoll/trans_server.go | 13 ++++++++++--- pkg/remote/trans/netpoll/trans_server_test.go | 17 ++++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pkg/remote/trans/netpoll/trans_server.go b/pkg/remote/trans/netpoll/trans_server.go index f6db948226..527022f6f9 100644 --- a/pkg/remote/trans/netpoll/trans_server.go +++ b/pkg/remote/trans/netpoll/trans_server.go @@ -138,7 +138,7 @@ func (ts *transServer) ConnCount() utils.AtomicInt { // 2. Doesn't need to init RPCInfo if it's not RPC request, such as heartbeat. func (ts *transServer) onConnActive(conn netpoll.Connection) context.Context { ctx := context.Background() - defer transRecover(ctx, conn, "OnActive") + defer transRecover(ctx, conn, "OnActive", false) conn.AddCloseCallback(func(connection netpoll.Connection) error { ts.onConnInactive(ctx, conn) return nil @@ -154,6 +154,10 @@ func (ts *transServer) onConnActive(conn netpoll.Connection) context.Context { } func (ts *transServer) onConnRead(ctx context.Context, conn netpoll.Connection) error { + // in case it's panicked, it may caused by framework, + // try to propagate the err and let it crash. + // we mainly use transRecover for logging + defer transRecover(ctx, conn, "onConnRead", true) err := ts.transHdlr.OnRead(ctx, conn) if err != nil { ts.onError(ctx, err, conn) @@ -166,7 +170,7 @@ func (ts *transServer) onConnRead(ctx context.Context, conn netpoll.Connection) } func (ts *transServer) onConnInactive(ctx context.Context, conn netpoll.Connection) { - defer transRecover(ctx, conn, "OnInactive") + defer transRecover(ctx, conn, "OnInactive", false) ts.connCount.Dec() ts.transHdlr.OnInactive(ctx, conn) } @@ -175,7 +179,7 @@ func (ts *transServer) onError(ctx context.Context, err error, conn netpoll.Conn ts.transHdlr.OnError(ctx, err, conn) } -func transRecover(ctx context.Context, conn netpoll.Connection, funcName string) { +func transRecover(ctx context.Context, conn netpoll.Connection, funcName string, propagatePanic bool) { panicErr := recover() if panicErr != nil { if conn != nil { @@ -183,5 +187,8 @@ func transRecover(ctx context.Context, conn netpoll.Connection, funcName string) } else { klog.CtxErrorf(ctx, "KITEX: panic happened in %s, error=%v\nstack=%s", funcName, panicErr, string(debug.Stack())) } + if propagatePanic { + panic(panicErr) + } } } diff --git a/pkg/remote/trans/netpoll/trans_server_test.go b/pkg/remote/trans/netpoll/trans_server_test.go index 24afb15691..c4d9ce06d7 100644 --- a/pkg/remote/trans/netpoll/trans_server_test.go +++ b/pkg/remote/trans/netpoll/trans_server_test.go @@ -190,7 +190,7 @@ func TestConnOnActiveAndOnInactivePanic(t *testing.T) { // TestOnConnRead test trans_server onConnRead success func TestConnOnRead(t *testing.T) { - // 1. prepare mock data + // prepare mock data var isClosed bool conn := &MockNetpollConn{ Conn: mocks.Conn{ @@ -203,6 +203,8 @@ func TestConnOnRead(t *testing.T) { }, }, } + + // case return err mockErr := errors.New("mock error") transSvr.transHdlr = &mocks.MockSvrTransHandler{ OnReadFunc: func(ctx context.Context, conn net.Conn) error { @@ -210,9 +212,18 @@ func TestConnOnRead(t *testing.T) { }, Opt: transSvr.opt, } - - // 2. test err := transSvr.onConnRead(context.Background(), conn) test.Assert(t, err == nil, err) test.Assert(t, isClosed) + + // case panic + transSvr.transHdlr = &mocks.MockSvrTransHandler{ + OnReadFunc: func(ctx context.Context, conn net.Conn) error { + panic("case panic") + }, + Opt: transSvr.opt, + } + test.Panic(t, func() { + _ = transSvr.onConnRead(context.Background(), conn) + }) } From f66674f6167e623b7e5ab10a08c71fdceaea8cf2 Mon Sep 17 00:00:00 2001 From: Jayant Date: Thu, 8 Aug 2024 15:59:24 +0800 Subject: [PATCH 34/70] test: replace judgement of mem stats of client finalizer by closed count check (#1469) --- client/client_test.go | 2 +- client/service_inline_test.go | 2 +- pkg/generic/binary_test/generic_test.go | 85 ----------------------- pkg/generic/json_test/generic_test.go | 92 ------------------------- pkg/generic/map_test/generic_init.go | 3 +- pkg/generic/map_test/generic_test.go | 64 +++++++---------- 6 files changed, 27 insertions(+), 221 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 64c0e3cfc8..b2039a97a7 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1027,7 +1027,7 @@ func TestClientFinalizer(t *testing.T) { runtime.ReadMemStats(&ms) secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects t.Logf("After second GC, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) - test.Assert(t, secondGCHeapAlloc < firstGCHeapAlloc/2 && secondGCHeapObjects < firstGCHeapObjects/2) + // test.Assert(t, secondGCHeapAlloc < firstGCHeapAlloc/2 && secondGCHeapObjects < firstGCHeapObjects/2) } func TestPanicInMiddleware(t *testing.T) { diff --git a/client/service_inline_test.go b/client/service_inline_test.go index 551f42b440..a0109247c6 100644 --- a/client/service_inline_test.go +++ b/client/service_inline_test.go @@ -334,5 +334,5 @@ func TestServiceInlineClientFinalizer(t *testing.T) { runtime.ReadMemStats(&ms) secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects t.Logf("After second GC, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) - test.Assert(t, secondGCHeapAlloc < firstGCHeapAlloc/2 && secondGCHeapObjects < firstGCHeapObjects/2) + // test.Assert(t, secondGCHeapAlloc < firstGCHeapAlloc/2 && secondGCHeapObjects < firstGCHeapObjects/2) } diff --git a/pkg/generic/binary_test/generic_test.go b/pkg/generic/binary_test/generic_test.go index da30c62a25..bab2b891b6 100644 --- a/pkg/generic/binary_test/generic_test.go +++ b/pkg/generic/binary_test/generic_test.go @@ -19,15 +19,12 @@ package test import ( "context" "net" - "runtime" - "runtime/debug" "strings" "testing" "time" "github.com/cloudwego/gopkg/protocol/thrift" - "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/callopt" "github.com/cloudwego/kitex/client/genericclient" kt "github.com/cloudwego/kitex/internal/mocks/thrift" @@ -191,85 +188,3 @@ func genBinaryReqBuf(method string) []byte { b = append(b, reqMsg...) return b } - -func TestBinaryThriftGenericClientClose(t *testing.T) { - debug.SetGCPercent(-1) - defer debug.SetGCPercent(100) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - - t.Logf("Before new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) - - cliCnt := 10000 - clis := make([]genericclient.Client, cliCnt) - for i := 0; i < cliCnt; i++ { - g := generic.BinaryThriftGeneric() - clis[i] = newGenericClient("destServiceName", g, addr, client.WithShortConnection()) - } - - runtime.ReadMemStats(&ms) - preHeapAlloc, preHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After new clients, allocation: %f Mb, Number of allocation: %d\n", preHeapAlloc, preHeapObjects) - - for _, cli := range clis { - _ = cli.Close() - } - runtime.GC() - runtime.ReadMemStats(&ms) - afterGCHeapAlloc, afterGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After close clients and GC be executed, allocation: %f Mb, Number of allocation: %d\n", afterGCHeapAlloc, afterGCHeapObjects) - test.Assert(t, afterGCHeapAlloc < preHeapAlloc && afterGCHeapObjects < preHeapObjects) - - // Trigger the finalizer of kclient be executed - time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed - runtime.GC() - runtime.ReadMemStats(&ms) - secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After second GC, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) - test.Assert(t, secondGCHeapAlloc/2 < afterGCHeapAlloc && secondGCHeapObjects/2 < afterGCHeapObjects) -} - -func TestBinaryThriftGenericClientFinalizer(t *testing.T) { - debug.SetGCPercent(-1) - defer debug.SetGCPercent(100) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - t.Logf("Before new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) - - cliCnt := 10000 - clis := make([]genericclient.Client, cliCnt) - for i := 0; i < cliCnt; i++ { - g := generic.BinaryThriftGeneric() - clis[i] = newGenericClient("destServiceName", g, addr, client.WithShortConnection()) - } - - runtime.ReadMemStats(&ms) - t.Logf("After new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) - - runtime.GC() - runtime.ReadMemStats(&ms) - firstGCHeapAlloc, firstGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After first GC, allocation: %f Mb, Number of allocation: %d\n", firstGCHeapAlloc, firstGCHeapObjects) - - // Trigger the finalizer of generic client be executed - time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed - runtime.GC() - runtime.ReadMemStats(&ms) - secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After second GC, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) - test.Assert(t, secondGCHeapAlloc < firstGCHeapAlloc && secondGCHeapObjects < firstGCHeapObjects) - - // Trigger the finalizer of kClient be executed - time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed - runtime.GC() - runtime.ReadMemStats(&ms) - thirdGCHeapAlloc, thirdGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After third GC, allocation: %f Mb, Number of allocation: %d\n", thirdGCHeapAlloc, thirdGCHeapObjects) - test.Assert(t, thirdGCHeapAlloc < secondGCHeapAlloc/2 && thirdGCHeapObjects < secondGCHeapObjects/2) -} - -func mb(byteSize uint64) float32 { - return float32(byteSize) / float32(1024*1024) -} diff --git a/pkg/generic/json_test/generic_test.go b/pkg/generic/json_test/generic_test.go index f955789773..dbe1804719 100644 --- a/pkg/generic/json_test/generic_test.go +++ b/pkg/generic/json_test/generic_test.go @@ -24,8 +24,6 @@ import ( "math" "net" "reflect" - "runtime" - "runtime/debug" "strings" "testing" "time" @@ -56,11 +54,9 @@ func TestRun(t *testing.T) { t.Run("TestThriftVoidMethodWithDynamicGo", testThriftVoidMethodWithDynamicGo) t.Run("TestThrift2NormalServer", testThrift2NormalServer) t.Run("TestThriftException", testThriftException) - t.Run("TestJSONThriftGenericClientClose", testJSONThriftGenericClientClose) t.Run("TestThriftRawBinaryEcho", testThriftRawBinaryEcho) t.Run("TestThriftBase64BinaryEcho", testThriftBase64BinaryEcho) t.Run("TestRegression", testRegression) - t.Run("TestJSONThriftGenericClientFinalizer", testJSONThriftGenericClientFinalizer) t.Run("TestParseModeWithDynamicGo", testParseModeWithDynamicGo) } @@ -622,90 +618,6 @@ func initMockServer(t *testing.T, handler kt.Mock, address string) server.Server return svr } -func testJSONThriftGenericClientClose(t *testing.T) { - debug.SetGCPercent(-1) - defer debug.SetGCPercent(100) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - - t.Logf("Before new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) - - clientCnt := 1000 - clis := make([]genericclient.Client, clientCnt) - for i := 0; i < clientCnt; i++ { - p, err := generic.NewThriftFileProvider("./idl/mock.thrift") - test.Assertf(t, err == nil, "generic NewThriftFileProvider failed, err=%v", err) - g, err := generic.JSONThriftGeneric(p) - test.Assertf(t, err == nil, "generic JSONThriftGeneric failed, err=%v", err) - clis[i] = newGenericClient(transport.TTHeader, "destServiceName", g, "127.0.0.1:8129") - } - - runtime.ReadMemStats(&ms) - preHeapAlloc, preHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After new clients, allocation: %f Mb, Number of allocation: %d\n", preHeapAlloc, preHeapObjects) - - for _, cli := range clis { - _ = cli.Close() - } - runtime.GC() - runtime.ReadMemStats(&ms) - afterGCHeapAlloc, afterGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After close clients and GC be executed, allocation: %f Mb, Number of allocation: %d\n", afterGCHeapAlloc, afterGCHeapObjects) - test.Assert(t, afterGCHeapAlloc < preHeapAlloc && afterGCHeapObjects < preHeapObjects) - - // Trigger the finalizer of kclient be executed - time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed - runtime.GC() - runtime.ReadMemStats(&ms) - secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After second GC, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) - test.Assert(t, secondGCHeapAlloc/2 < afterGCHeapAlloc && secondGCHeapObjects/2 < afterGCHeapObjects) -} - -func testJSONThriftGenericClientFinalizer(t *testing.T) { - debug.SetGCPercent(-1) - defer debug.SetGCPercent(100) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - t.Logf("Before new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) - - clientCnt := 1000 - clis := make([]genericclient.Client, clientCnt) - for i := 0; i < clientCnt; i++ { - p, err := generic.NewThriftFileProvider("./idl/mock.thrift") - test.Assert(t, err == nil, "generic NewThriftFileProvider failed, err=%v", err) - g, err := generic.JSONThriftGeneric(p) - test.Assert(t, err == nil, "generic JSONThriftGeneric failed, err=%v", err) - clis[i] = newGenericClient(transport.TTHeader, "destServiceName", g, "127.0.0.1:8130") - } - - runtime.ReadMemStats(&ms) - t.Logf("After new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) - - runtime.GC() - runtime.ReadMemStats(&ms) - firstGCHeapAlloc, firstGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After first GC, allocation: %f Mb, Number of allocation: %d\n", firstGCHeapAlloc, firstGCHeapObjects) - - // Trigger the finalizer of generic client be executed - time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed - runtime.GC() - runtime.ReadMemStats(&ms) - secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After second GC, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) - test.Assert(t, secondGCHeapAlloc < firstGCHeapAlloc && secondGCHeapObjects < firstGCHeapObjects) - - // Trigger the finalizer of kClient be executed - time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed - runtime.GC() - runtime.ReadMemStats(&ms) - thirddGCHeapAlloc, thirdGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After third GC, allocation: %f Mb, Number of allocation: %d\n", thirddGCHeapAlloc, thirdGCHeapObjects) - test.Assert(t, thirddGCHeapAlloc < secondGCHeapAlloc/2 && thirdGCHeapObjects < secondGCHeapObjects/2) -} - func testParseModeWithDynamicGo(t *testing.T) { addr := test.GetLocalAddress() thrift.SetDefaultParseMode(thrift.FirstServiceOnly) @@ -722,7 +634,3 @@ func testParseModeWithDynamicGo(t *testing.T) { svr.Stop() } - -func mb(byteSize uint64) float32 { - return float32(byteSize) / float32(1024*1024) -} diff --git a/pkg/generic/map_test/generic_init.go b/pkg/generic/map_test/generic_init.go index 2b92a303a3..253a6397f5 100644 --- a/pkg/generic/map_test/generic_init.go +++ b/pkg/generic/map_test/generic_init.go @@ -53,8 +53,7 @@ var reqMsg = map[string]interface{}{ var errResp = "Test Error" -func newGenericClient(destService string, g generic.Generic, targetIPPort string) genericclient.Client { - var opts []client.Option +func newGenericClient(destService string, g generic.Generic, targetIPPort string, opts ...client.Option) genericclient.Client { opts = append(opts, client.WithHostPorts(targetIPPort), client.WithTransportProtocol(transport.TTHeader)) genericCli, _ := genericclient.NewClient(destService, g, opts...) return genericCli diff --git a/pkg/generic/map_test/generic_test.go b/pkg/generic/map_test/generic_test.go index a76a93a2b3..fe761de8ea 100644 --- a/pkg/generic/map_test/generic_test.go +++ b/pkg/generic/map_test/generic_test.go @@ -24,9 +24,11 @@ import ( "runtime" "runtime/debug" "strings" + "sync/atomic" "testing" "time" + "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/callopt" "github.com/cloudwego/kitex/client/genericclient" kt "github.com/cloudwego/kitex/internal/mocks/thrift" @@ -433,45 +435,14 @@ func initMockServer(t *testing.T, handler kt.Mock, address string) server.Server return svr } -func TestMapThriftGenericClientClose(t *testing.T) { - debug.SetGCPercent(-1) - defer debug.SetGCPercent(100) - - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - - t.Logf("Before new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) - - clientCnt := 1000 - clis := make([]genericclient.Client, clientCnt) - for i := 0; i < clientCnt; i++ { - p, err := generic.NewThriftFileProvider("./idl/mock.thrift") - test.Assert(t, err == nil, "generic NewThriftFileProvider failed, err=%v", err) - g, err := generic.MapThriftGeneric(p) - test.Assert(t, err == nil, "generic MapThriftGeneric failed, err=%v", err) - clis[i] = newGenericClient("destServiceName", g, "127.0.0.1:9020") - } - - runtime.ReadMemStats(&ms) - preHeapAlloc, preHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After new clients, allocation: %f Mb, Number of allocation: %d\n", preHeapAlloc, preHeapObjects) - - for _, cli := range clis { - _ = cli.Close() - } - runtime.GC() - runtime.ReadMemStats(&ms) - afterGCHeapAlloc, afterGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After close clients and GC be executed, allocation: %f Mb, Number of allocation: %d\n", afterGCHeapAlloc, afterGCHeapObjects) - test.Assert(t, afterGCHeapAlloc < preHeapAlloc && afterGCHeapObjects < preHeapObjects) +type testGeneric struct { + cb func() + generic.Generic +} - // Trigger the finalizer of kclient be executed - time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed - runtime.GC() - runtime.ReadMemStats(&ms) - secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects - t.Logf("After second GC, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) - test.Assert(t, secondGCHeapAlloc/2 < afterGCHeapAlloc && secondGCHeapObjects/2 < afterGCHeapObjects) +func (g *testGeneric) Close() error { + g.cb() + return g.Generic.Close() } func TestMapThriftGenericClientFinalizer(t *testing.T) { @@ -483,13 +454,24 @@ func TestMapThriftGenericClientFinalizer(t *testing.T) { t.Logf("Before new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) clientCnt := 1000 + var genericCloseCnt int32 + var kitexClientCloseCnt int32 clis := make([]genericclient.Client, clientCnt) for i := 0; i < clientCnt; i++ { p, err := generic.NewThriftFileProvider("./idl/mock.thrift") test.Assert(t, err == nil, "generic NewThriftFileProvider failed, err=%v", err) g, err := generic.MapThriftGeneric(p) test.Assert(t, err == nil, "generic MapThriftGeneric failed, err=%v", err) - clis[i] = newGenericClient("destServiceName", g, "127.0.0.1:9021") + g = &testGeneric{ + cb: func() { + atomic.AddInt32(&genericCloseCnt, 1) + }, + Generic: g, + } + clis[i] = newGenericClient("destServiceName", g, "127.0.0.1:9021", client.WithCloseCallbacks(func() error { + atomic.AddInt32(&kitexClientCloseCnt, 1) + return nil + })) } runtime.ReadMemStats(&ms) @@ -502,6 +484,7 @@ func TestMapThriftGenericClientFinalizer(t *testing.T) { // Trigger the finalizer of generic client be executed time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed + test.Assert(t, atomic.LoadInt32(&genericCloseCnt) == int32(clientCnt)) runtime.GC() runtime.ReadMemStats(&ms) secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects @@ -510,11 +493,12 @@ func TestMapThriftGenericClientFinalizer(t *testing.T) { // Trigger the finalizer of kClient be executed time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed + test.Assert(t, atomic.LoadInt32(&kitexClientCloseCnt) == int32(clientCnt)) runtime.GC() runtime.ReadMemStats(&ms) thirdGCHeapAlloc, thirdGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects t.Logf("After third GC, allocation: %f Mb, Number of allocation: %d\n", thirdGCHeapAlloc, thirdGCHeapObjects) - test.Assert(t, thirdGCHeapAlloc < secondGCHeapAlloc/2 && thirdGCHeapObjects < secondGCHeapObjects/2) + // test.Assert(t, thirdGCHeapAlloc < secondGCHeapAlloc/2 && thirdGCHeapObjects < secondGCHeapObjects/2) } func mb(byteSize uint64) float32 { From 014804ab1d9ed6475b468b190e50681dd479ae01 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Thu, 8 Aug 2024 16:52:03 +0800 Subject: [PATCH 35/70] feat(tool): embed thriftgo into kitex tool (#1479) --- go.mod | 2 +- go.sum | 4 +- tool/cmd/kitex/args/args.go | 11 ++- tool/cmd/kitex/args/version_requirements.go | 3 +- tool/cmd/kitex/main.go | 74 ++++++------------ tool/cmd/kitex/sdk/kitex_sdk.go | 75 +++++++++++++------ tool/cmd/kitex/utils/utils.go | 74 ++++++++++++++++++ tool/internal_pkg/generator/generator.go | 1 + tool/internal_pkg/generator/generator_test.go | 2 +- 9 files changed, 164 insertions(+), 82 deletions(-) create mode 100644 tool/cmd/kitex/utils/utils.go diff --git a/go.mod b/go.mod index 72b9a05c98..7085425afd 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 - github.com/cloudwego/thriftgo v0.3.15 + github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083 github.com/golang/mock v1.6.0 github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 github.com/jhump/protoreflect v1.8.2 diff --git a/go.sum b/go.sum index 8a5fe20fc2..d4a0963dcc 100644 --- a/go.sum +++ b/go.sum @@ -49,8 +49,8 @@ github.com/cloudwego/netpoll v0.6.3/go.mod h1:kaqvfZ70qd4T2WtIIpCOi5Cxyob8viEpzL github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= github.com/cloudwego/runtimex v0.1.0/go.mod h1:23vL/HGV0W8nSCHbe084AgEBdDV4rvXenEUMnUNvUd8= github.com/cloudwego/thriftgo v0.3.6/go.mod h1:29ukiySoAMd0vXMYIduAY9dph/7dmChvOS11YLotFb8= -github.com/cloudwego/thriftgo v0.3.15 h1:yB/DDGjeSjliyidMVBjKhGl9RgE4M8iVIz5dKpAIyUs= -github.com/cloudwego/thriftgo v0.3.15/go.mod h1:R4a+4aVDI0V9YCTfpNgmvbkq/9ThKgF7Om8Z0I36698= +github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083 h1:KiEGBvsyAyUrFrpEi/e77K0SWTLK8FMHhSQ5c9kFJic= +github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083/go.mod h1:R4a+4aVDI0V9YCTfpNgmvbkq/9ThKgF7Om8Z0I36698= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index 94c4fabb2f..4e81845c05 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -126,6 +126,9 @@ func (a *Arguments) buildFlags(version string) *flag.FlagSet { "Specify a protocol for codec") f.BoolVar(&a.NoDependencyCheck, "no-dependency-check", false, "Skip dependency checking.") + f.BoolVar(&a.Rapid, "rapid", false, + "Use embedded thriftgo.") + a.RecordCmd = os.Args a.Version = version a.ThriftOptions = append(a.ThriftOptions, @@ -139,9 +142,6 @@ func (a *Arguments) buildFlags(version string) *flag.FlagSet { "no_processor", ) - for _, e := range a.extends { - e.Apply(f) - } f.Usage = func() { fmt.Fprintf(os.Stderr, `Version %s Usage: %s [flags] IDL @@ -152,6 +152,11 @@ Flags: `, a.Version, os.Args[0], cmdExample) f.PrintDefaults() } + + for _, e := range a.extends { + e.Apply(f) + } + return f } diff --git a/tool/cmd/kitex/args/version_requirements.go b/tool/cmd/kitex/args/version_requirements.go index aa432f3a51..72ef19ce26 100644 --- a/tool/cmd/kitex/args/version_requirements.go +++ b/tool/cmd/kitex/args/version_requirements.go @@ -14,5 +14,4 @@ package args -// todo thriftgo sdk 功能 v0.3.13 发布打 tag -var requiredThriftGoVersion = "v0.3.6" +var requiredThriftGoVersion = "v0.3.15" diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 290f0ef1ae..ef6aa53b8e 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -17,12 +17,14 @@ package main import ( "bytes" "flag" - "io/ioutil" "os" - "os/exec" "path/filepath" "strings" + "github.com/cloudwego/kitex/tool/cmd/kitex/utils" + + "github.com/cloudwego/kitex/tool/cmd/kitex/sdk" + "github.com/cloudwego/kitex" kargs "github.com/cloudwego/kitex/tool/cmd/kitex/args" "github.com/cloudwego/kitex/tool/cmd/kitex/versions" @@ -79,8 +81,11 @@ func main() { // run as kitex err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) if err != nil { - log.Warn(err.Error()) - os.Exit(2) + if err.Error() != "flag: help requested" { + log.Warn(err.Error()) + os.Exit(2) + } + os.Exit(0) } if !args.NoDependencyCheck { // check dependency compatibility between kitex cmd tool and dependency in go.mod @@ -95,12 +100,18 @@ func main() { log.Warn(err) os.Exit(1) } - err = kargs.ValidateCMD(cmd.Path, args.IDLType) - if err != nil { - log.Warn(err) - os.Exit(1) + + if args.IDLType == "thrift" && args.Rapid { + err = sdk.InvokeThriftgoBySDK(curpath, cmd) + } else { + err = kargs.ValidateCMD(cmd.Path, args.IDLType) + if err != nil { + log.Warn(err) + os.Exit(1) + } + err = cmd.Run() } - err = cmd.Run() + if err != nil { if args.Use != "" { out := strings.TrimSpace(out.String()) @@ -108,50 +119,9 @@ func main() { goto NormalExit } } + log.Warn(err) os.Exit(1) } NormalExit: - if args.IDLType == "thrift" { - cmd := "go mod edit -replace github.com/apache/thrift=github.com/apache/thrift@v0.13.0" - argv := strings.Split(cmd, " ") - err := exec.Command(argv[0], argv[1:]...).Run() - - res := "Done" - if err != nil { - res = err.Error() - } - log.Warn("Adding apache/thrift@v0.13.0 to go.mod for generated code ..........", res) - } - - // remove kitex.yaml generated from v0.4.4 which is renamed as kitex_info.yaml - if args.ServiceName != "" { - DeleteKitexYaml() - } - - // If hessian option is java_extension, replace *java.Object to java.Object - if thriftgo.EnableJavaExtension(args.Config) { - if err = thriftgo.Hessian2PatchByReplace(args.Config, ""); err != nil { - log.Warn("replace java object fail, you can fix it then regenerate", err) - } - } -} - -func DeleteKitexYaml() { - // try to read kitex.yaml - data, err := ioutil.ReadFile("kitex.yaml") - if err != nil { - if !os.IsNotExist(err) { - log.Warn("kitex.yaml, which is used to record tool info, is deprecated, it's renamed as kitex_info.yaml, you can delete it or ignore it.") - } - return - } - // if kitex.yaml exists, check content and delete it. - if strings.HasPrefix(string(data), "kitexinfo:") { - err = os.Remove("kitex.yaml") - if err != nil { - log.Warn("kitex.yaml, which is used to record tool info, is deprecated, it's renamed as kitex_info.yaml, you can delete it or ignore it.") - } else { - log.Warn("kitex.yaml, which is used to record tool info, is deprecated, it's renamed as kitex_info.yaml, so it's automatically deleted now.") - } - } + utils.OnKitexToolNormalExit(args) } diff --git a/tool/cmd/kitex/sdk/kitex_sdk.go b/tool/cmd/kitex/sdk/kitex_sdk.go index 6ba0633a6e..1cd6e3d445 100644 --- a/tool/cmd/kitex/sdk/kitex_sdk.go +++ b/tool/cmd/kitex/sdk/kitex_sdk.go @@ -18,8 +18,11 @@ import ( "bytes" "flag" "fmt" + "os/exec" "strings" + "github.com/cloudwego/kitex/tool/internal_pkg/log" + "github.com/cloudwego/thriftgo/plugin" "github.com/cloudwego/thriftgo/sdk" @@ -78,32 +81,36 @@ func GetKiteXSDKPlugin(pwd string, rawKiteXArgs []string) (*KiteXSDKPlugin, erro kitexPlugin := &KiteXSDKPlugin{} - thriftgoParams := []string{} - findKitex := false - kitexParams := []string{} - - for i, arg := range cmd.Args { - if i == 0 { - continue - } - if arg == "-p" { - findKitex = true - continue - } - if findKitex { - kitexParams = strings.Split(arg, ",") - findKitex = false - } else { - thriftgoParams = append(thriftgoParams, arg) - } + kitexPlugin.ThriftgoParams, kitexPlugin.KitexParams, err = ParseKitexCmd(cmd) + if err != nil { + return nil, err } + kitexPlugin.Pwd = pwd - kitexPlugin.ThriftgoParams = thriftgoParams - kitexPlugin.KitexParams = kitexParams + return kitexPlugin, nil +} + +// InvokeThriftgoBySDK is for kitex tool main.go +func InvokeThriftgoBySDK(pwd string, cmd *exec.Cmd) (err error) { + kitexPlugin := &KiteXSDKPlugin{} + + kitexPlugin.ThriftgoParams, kitexPlugin.KitexParams, err = ParseKitexCmd(cmd) + if err != nil { + return err + } kitexPlugin.Pwd = pwd - return kitexPlugin, nil + s := []plugin.SDKPlugin{kitexPlugin} + + err = sdk.RunThriftgoAsSDK(pwd, s, kitexPlugin.GetThriftgoParameters()...) + // when execute thriftgo as function, log will be unexpectedly replaced (for old code by design), so we have to change it back. + log.SetDefaultLogger(log.Logger{ + Println: fmt.Fprintln, + Printf: fmt.Fprintf, + }) + + return err } type KiteXSDKPlugin struct { @@ -127,3 +134,29 @@ func (k *KiteXSDKPlugin) GetPluginParameters() []string { func (k *KiteXSDKPlugin) GetThriftgoParameters() []string { return k.ThriftgoParams } + +func ParseKitexCmd(cmd *exec.Cmd) (thriftgoParams, kitexParams []string, err error) { + cmdArgs := cmd.Args + // thriftgo -r -o kitex_gen -g go:xxx -p kitex=xxxx -p otherplugin xxx.thrift + // ignore first argument, and remove -p kitex=xxxx + + thriftgoParams = []string{} + kitexParams = []string{} + if len(cmdArgs) < 1 { + return nil, nil, fmt.Errorf("cmd args too short: %s", cmdArgs) + } + + for i := 1; i < len(cmdArgs); i++ { + arg := cmdArgs[i] + if arg == "-p" && i+1 < len(cmdArgs) { + pluginArgs := cmdArgs[i+1] + if strings.HasPrefix(pluginArgs, "kitex") { + kitexParams = strings.Split(pluginArgs, ",") + i++ + continue + } + } + thriftgoParams = append(thriftgoParams, arg) + } + return thriftgoParams, kitexParams, nil +} diff --git a/tool/cmd/kitex/utils/utils.go b/tool/cmd/kitex/utils/utils.go new file mode 100644 index 0000000000..cfe1399a46 --- /dev/null +++ b/tool/cmd/kitex/utils/utils.go @@ -0,0 +1,74 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "io/ioutil" + "os" + "os/exec" + "strings" + + kargs "github.com/cloudwego/kitex/tool/cmd/kitex/args" + "github.com/cloudwego/kitex/tool/internal_pkg/log" + "github.com/cloudwego/kitex/tool/internal_pkg/pluginmode/thriftgo" +) + +func OnKitexToolNormalExit(args kargs.Arguments) { + if args.IDLType == "thrift" { + cmd := "go mod edit -replace github.com/apache/thrift=github.com/apache/thrift@v0.13.0" + argv := strings.Split(cmd, " ") + err := exec.Command(argv[0], argv[1:]...).Run() + + res := "Done" + if err != nil { + res = err.Error() + } + log.Warn("Adding apache/thrift@v0.13.0 to go.mod for generated code ..........", res) + } + + log.Warn("Code Generation is Done!") + + // remove kitex.yaml generated from v0.4.4 which is renamed as kitex_info.yaml + if args.ServiceName != "" { + DeleteKitexYaml() + } + + // If hessian option is java_extension, replace *java.Object to java.Object + if thriftgo.EnableJavaExtension(args.Config) { + if err := thriftgo.Hessian2PatchByReplace(args.Config, ""); err != nil { + log.Warn("replace java object fail, you can fix it then regenerate", err) + } + } +} + +func DeleteKitexYaml() { + // try to read kitex.yaml + data, err := ioutil.ReadFile("kitex.yaml") + if err != nil { + if !os.IsNotExist(err) { + log.Warn("kitex.yaml, which is used to record tool info, is deprecated, it's renamed as kitex_info.yaml, you can delete it or ignore it.") + } + return + } + // if kitex.yaml exists, check content and delete it. + if strings.HasPrefix(string(data), "kitexinfo:") { + err = os.Remove("kitex.yaml") + if err != nil { + log.Warn("kitex.yaml, which is used to record tool info, is deprecated, it's renamed as kitex_info.yaml, you can delete it or ignore it.") + } else { + log.Warn("kitex.yaml, which is used to record tool info, is deprecated, it's renamed as kitex_info.yaml, so it's automatically deleted now.") + } + } +} diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 5f792beee1..7563bbba43 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -140,6 +140,7 @@ type Config struct { HandlerReturnKeepResp bool NoDependencyCheck bool + Rapid bool } // Pack packs the Config into a slice of "key=val" strings. diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index cccee89e11..43f9b1ac42 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -69,7 +69,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false"}, }, } for _, tt := range tests { From 9249bb5b2ddf5eb21ac0c071124f3f378a29b9d8 Mon Sep 17 00:00:00 2001 From: Yi Duan Date: Mon, 12 Aug 2024 11:28:44 +0800 Subject: [PATCH 36/70] build: adapt to go1.23rc2 (#1468) --- go.mod | 9 +- go.sum | 100 +++--------------- pkg/generic/map_test/generic_test.go | 1 - pkg/remote/codec/thrift/thrift_frugal.go | 7 -- pkg/remote/codec/thrift/thrift_frugal_test.go | 7 -- pkg/remote/codec/thrift/thrift_others.go | 46 -------- pkg/utils/strings_test.go | 8 +- 7 files changed, 21 insertions(+), 157 deletions(-) delete mode 100644 pkg/remote/codec/thrift/thrift_others.go diff --git a/go.mod b/go.mod index 7085425afd..38bca68ffa 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.17 require ( github.com/apache/thrift v0.13.0 github.com/bytedance/gopkg v0.1.0 - github.com/bytedance/sonic v1.11.8 + github.com/bytedance/sonic v1.12.0 github.com/cloudwego/configmanager v0.2.2 - github.com/cloudwego/dynamicgo v0.2.9 + github.com/cloudwego/dynamicgo v0.3.0 github.com/cloudwego/fastpb v0.0.4 - github.com/cloudwego/frugal v0.1.15 + github.com/cloudwego/frugal v0.2.0 github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8 github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 @@ -30,7 +30,7 @@ require ( ) require ( - github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect @@ -42,7 +42,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/oleiade/lane v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/testify v1.9.0 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index d4a0963dcc..124831aaf1 100644 --- a/go.sum +++ b/go.sum @@ -1,27 +1,18 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= -git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= -github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= -github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= -github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic v1.11.8 h1:Zw/j1KfiS+OYTi9lyB3bb0CFxPJVkM17k1wyDG32LRA= -github.com/bytedance/sonic v1.11.8/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic v1.12.0 h1:YGPgxF9xzaCNvd/ZKdQ28yRovhfMFZQjuk6fKBzZ3ls= +github.com/bytedance/sonic v1.12.0/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= +github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= @@ -31,15 +22,15 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/configmanager v0.2.2 h1:sVrJB8gWYTlPV2OS3wcgJSO9F2/9Zbkmcm1Z7jempOU= github.com/cloudwego/configmanager v0.2.2/go.mod h1:ppiyU+5TPLonE8qMVi/pFQk2eL3Q4P7d4hbiNJn6jwI= -github.com/cloudwego/dynamicgo v0.2.9 h1:MHGyGmdFT8iMOsM5S9iutjZB0csu2LupsTTHyi6a8pY= -github.com/cloudwego/dynamicgo v0.2.9/go.mod h1:F3jlbPmlNzhcuDMXwZoBJ7rJKpg2iE+TnIy9pSJiGzs= +github.com/cloudwego/dynamicgo v0.3.0 h1:2/jOD3cMn8YVWGmVybrn74YulmhxW8d4BPyy9pja5eo= +github.com/cloudwego/dynamicgo v0.3.0/go.mod h1:vPHEegW2xqjuDE8NAui+2D93RivFv18eWsyD9VRtORM= github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4mg= github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= -github.com/cloudwego/frugal v0.1.15 h1:LC55UJKhQPMFVjDPbE+LJcF7etZjSx6uokG1tk0wPK0= -github.com/cloudwego/frugal v0.1.15/go.mod h1:26kU1r18vA8vRg12c66XPDlfv1GQHDbE1RpusipXfcI= +github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= +github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= +github.com/cloudwego/gopkg v0.1.0/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8 h1:kQPjddHw5Dufci/vfiRGMN3Uhx12XWqNpk1JdQ4Tjy0= github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= -github.com/cloudwego/iasm v0.0.9/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= @@ -52,6 +43,7 @@ github.com/cloudwego/thriftgo v0.3.6/go.mod h1:29ukiySoAMd0vXMYIduAY9dph/7dmChvO github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083 h1:KiEGBvsyAyUrFrpEi/e77K0SWTLK8FMHhSQ5c9kFJic= github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083/go.mod h1:R4a+4aVDI0V9YCTfpNgmvbkq/9ThKgF7Om8Z0I36698= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -65,20 +57,7 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= -github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= -github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= -github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= -github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= -github.com/go-fonts/liberation v0.2.0/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= -github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY= -github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= -github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U= -github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk= -github.com/go-pdf/fpdf v0.5.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= -github.com/go-pdf/fpdf v0.6.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -100,9 +79,8 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= -github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 h1:mpL/HvfIgIejhVwAfxBQkwEjlhP5o0O9RAeTAjpwzxc= github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3/go.mod h1:gSuNB+gJaOiQKLEZ+q+PK9Mq3SOzhRcw2GsGS/FhYDk= @@ -116,8 +94,6 @@ github.com/jhump/protoreflect v1.8.2 h1:k2xE7wcUomeqwY0LDCYA16y4WWfyTcMx5mKhk0d4 github.com/jhump/protoreflect v1.8.2/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= -github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -127,8 +103,9 @@ github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgSh github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 h1:uiS4zKYKJVj5F3ID+5iylfKPsEQmBEOucSD9Vgmn0i0= @@ -136,24 +113,14 @@ github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5/go.mod h1:I8AX+yW//L github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso= -github.com/oleiade/lane v1.0.1 h1:hXofkn7GEOubzTwNpeL9MaNy8WxolCYb9cInAIeqShU= -github.com/oleiade/lane v1.0.1/go.mod h1:IyTkraa4maLfjq/GmHR+Dxb4kCMtEGeb+qmhlrQ5Mk4= -github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2dXMnm1mY= -github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= -github.com/phpdave11/gofpdi v1.0.13/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w= -github.com/ruudk/golang-pdf417 v0.0.0-20201230142125-a7e3863a1245/go.mod h1:pQAZKsJ8yyVxGRWYNEm9oFB8ieLgKFnamEyDmSA0BRk= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -176,7 +143,6 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.2.0 h1:W1sUEHXiJTfjaFJ5SLo0N6lZn+0eO5gWD1MFeTGqQEY= @@ -187,37 +153,16 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE= -golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20200119044424-58c23975cae1/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20210607152325-775e3b0c77b9/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= -golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= -golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= -golang.org/x/image v0.0.0-20220302094943-723b81ca9867/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -232,7 +177,6 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -252,20 +196,16 @@ golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210304124612-50617c2ba197/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -283,21 +223,17 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -307,7 +243,6 @@ golang.org/x/tools v0.0.0-20200717024301-6ddee64345a6/go.mod h1:njjCfa9FT2d7l9Bc golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= @@ -316,14 +251,6 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= -gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= -gonum.org/v1/gonum v0.9.3/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0= -gonum.org/v1/gonum v0.12.0/go.mod h1:73TDxJfAAHeA8Mk9mf8NlIppyhQNo5GLTcYeqgo2lvY= -gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= -gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= -gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY= -gonum.org/v1/plot v0.10.1/go.mod h1:VZW5OlhkL1mysU9vaqNHnsy86inf6Ot+jB3r+BczCEo= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -363,6 +290,5 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/pkg/generic/map_test/generic_test.go b/pkg/generic/map_test/generic_test.go index fe761de8ea..3ad0c45baf 100644 --- a/pkg/generic/map_test/generic_test.go +++ b/pkg/generic/map_test/generic_test.go @@ -448,7 +448,6 @@ func (g *testGeneric) Close() error { func TestMapThriftGenericClientFinalizer(t *testing.T) { debug.SetGCPercent(-1) defer debug.SetGCPercent(100) - var ms runtime.MemStats runtime.ReadMemStats(&ms) t.Logf("Before new clients, allocation: %f Mb, Number of allocation: %d\n", mb(ms.HeapAlloc), ms.HeapObjects) diff --git a/pkg/remote/codec/thrift/thrift_frugal.go b/pkg/remote/codec/thrift/thrift_frugal.go index 13be190133..aeae5c19ad 100644 --- a/pkg/remote/codec/thrift/thrift_frugal.go +++ b/pkg/remote/codec/thrift/thrift_frugal.go @@ -1,10 +1,3 @@ -//go:build (amd64 || arm64) && !windows && go1.16 && !go1.23 && !disablefrugal -// +build amd64 arm64 -// +build !windows -// +build go1.16 -// +build !go1.23 -// +build !disablefrugal - /* * Copyright 2021 CloudWeGo Authors * diff --git a/pkg/remote/codec/thrift/thrift_frugal_test.go b/pkg/remote/codec/thrift/thrift_frugal_test.go index 0393b5e5be..ba2f83f999 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_test.go @@ -1,10 +1,3 @@ -//go:build (amd64 || arm64) && !windows && go1.16 && !go1.23 && !disablefrugal -// +build amd64 arm64 -// +build !windows -// +build go1.16 -// +build !go1.23 -// +build !disablefrugal - /* * Copyright 2021 CloudWeGo Authors * diff --git a/pkg/remote/codec/thrift/thrift_others.go b/pkg/remote/codec/thrift/thrift_others.go deleted file mode 100644 index 3895cd46f2..0000000000 --- a/pkg/remote/codec/thrift/thrift_others.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build (!amd64 && !arm64) || windows || !go1.16 || go1.23 || disablefrugal -// +build !amd64,!arm64 windows !go1.16 go1.23 disablefrugal - -/* - * Copyright 2021 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "github.com/cloudwego/kitex/pkg/remote" -) - -// hyperMarshalAvailable indicates that if high priority message codec is available. -func hyperMarshalAvailable(data interface{}) bool { - return false -} - -// hyperMessageUnmarshalAvailable indicates that if high priority message codec is available. -func (c thriftCodec) hyperMessageUnmarshalAvailable(data interface{}, payloadLen int) bool { - return false -} - -func (c thriftCodec) hyperMarshal(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, data interface{}) error { - panic("unreachable code") -} - -func (c thriftCodec) hyperMarshalBody(data interface{}) (buf []byte, err error) { - panic("unreachable code") -} - -func (c thriftCodec) hyperMessageUnmarshal(buf []byte, data interface{}) error { - panic("unreachable code") -} diff --git a/pkg/utils/strings_test.go b/pkg/utils/strings_test.go index f253d33529..b67aa50042 100644 --- a/pkg/utils/strings_test.go +++ b/pkg/utils/strings_test.go @@ -26,13 +26,13 @@ import ( func TestStringBuilder(t *testing.T) { sb := &StringBuilder{} sb.Grow(4) - test.Assert(t, sb.Cap() == 4) + test.Assert(t, sb.Cap() >= 4) test.Assert(t, sb.Len() == 0) sb.WriteString("1") sb.WriteByte('2') sb.WriteRune(rune('3')) sb.Write([]byte("4")) - test.Assert(t, sb.Cap() == 4) + test.Assert(t, sb.Cap() >= 4) test.Assert(t, sb.Len() == 4) test.Assert(t, sb.String() == "1234") sb.Reset() @@ -42,13 +42,13 @@ func TestStringBuilder(t *testing.T) { sb.WithLocked(func(sb *strings.Builder) error { sb.Grow(4) - test.Assert(t, sb.Cap() == 4) + test.Assert(t, sb.Cap() >= 4) test.Assert(t, sb.Len() == 0) sb.WriteString("1") sb.WriteByte('2') sb.WriteRune(rune('3')) sb.Write([]byte("4")) - test.Assert(t, sb.Cap() == 4) + test.Assert(t, sb.Cap() >= 4) test.Assert(t, sb.Len() == 4) test.Assert(t, sb.String() == "1234") sb.Reset() From d84bca28c8a48c96c73f21cecd1f88120592d978 Mon Sep 17 00:00:00 2001 From: Jayant Date: Mon, 12 Aug 2024 20:42:21 +0800 Subject: [PATCH 37/70] fix(codec): wrap trans error for apache thrift read error (#1489) --- pkg/remote/codec/thrift/thrift_data.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index b6243107bd..743c713a9e 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -208,10 +208,14 @@ func verifyUnmarshalBasicThriftDataType(data interface{}) error { // decodeBasicThriftData decode thrift body the old way (slow) func decodeBasicThriftData(trans remote.ByteBuffer, data interface{}) error { - if err := verifyUnmarshalBasicThriftDataType(data); err != nil { + var err error + if err = verifyUnmarshalBasicThriftDataType(data); err != nil { return err } - return apache.ThriftRead(apache.NewDefaultTransport(trans), data) + if err = apache.ThriftRead(apache.NewDefaultTransport(trans), data); err != nil { + return remote.NewTransError(remote.ProtocolError, err) + } + return nil } func getSkippedStructBuffer(trans remote.ByteBuffer) ([]byte, error) { From 1f8ccf43caad6cd0c124f6f9cfcb2f6a839d2253 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:54:22 +0800 Subject: [PATCH 38/70] fix(generic): fix a generic serviceInfo compatible issue (#1487) --- client/genericclient/client.go | 5 +- pkg/generic/generic_service.go | 47 ++++++++++++------- pkg/generic/generic_service_test.go | 2 +- pkg/generic/map_test/generic_test.go | 28 +++++++++++ .../trans/nphttp2/grpc/transport_test.go | 26 +++++----- .../trans/nphttp2/status/status_test.go | 4 +- server/genericserver/server.go | 2 +- server/genericserver/server_test.go | 2 +- 8 files changed, 79 insertions(+), 37 deletions(-) diff --git a/client/genericclient/client.go b/client/genericclient/client.go index c40d69ebaa..e032371df7 100644 --- a/client/genericclient/client.go +++ b/client/genericclient/client.go @@ -31,12 +31,15 @@ var _ Client = &genericServiceClient{} // NewClient create a generic client func NewClient(destService string, g generic.Generic, opts ...client.Option) (Client, error) { - svcInfo := generic.ServiceInfoWithCodec(g) + svcInfo := generic.ServiceInfoWithGeneric(g) return NewClientWithServiceInfo(destService, g, svcInfo, opts...) } // NewClientWithServiceInfo create a generic client with serviceInfo func NewClientWithServiceInfo(destService string, g generic.Generic, svcInfo *serviceinfo.ServiceInfo, opts ...client.Option) (Client, error) { + if isDeprecated, ok := svcInfo.Extra[generic.DeprecatedGenericServiceInfoAPIKey].(bool); ok && isDeprecated { + svcInfo.Methods, svcInfo.ServiceName = generic.GetMethodInfo(g.MessageReaderWriter(), g.IDLServiceName()) + } var options []client.Option options = append(options, client.WithGeneric(g)) options = append(options, client.WithDestService(destService)) diff --git a/pkg/generic/generic_service.go b/pkg/generic/generic_service.go index 74d94264b9..dd6535550a 100644 --- a/pkg/generic/generic_service.go +++ b/pkg/generic/generic_service.go @@ -29,29 +29,49 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) +// TODO(marina.sakai): remove in v0.12.0 +const DeprecatedGenericServiceInfoAPIKey = "deprecated_generic_service_info_api" + // Service generic service interface type Service interface { // GenericCall handle the generic call GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) } -// ServiceInfoWithCodec create a generic ServiceInfo with CodecInfo -func ServiceInfoWithCodec(g Generic) *serviceinfo.ServiceInfo { - return newServiceInfo(g.PayloadCodecType(), g.MessageReaderWriter(), g.IDLServiceName()) +// ServiceInfoWithGeneric create a generic ServiceInfo +func ServiceInfoWithGeneric(g Generic) *serviceinfo.ServiceInfo { + return newServiceInfo(g.PayloadCodecType(), g.MessageReaderWriter(), g.IDLServiceName(), true) } -// Deprecated: it's not used by kitex anymore. +// Deprecated: Replaced by ServiceInfoWithGeneric, this method will be removed in v0.12.0 // ServiceInfo create a generic ServiceInfo +// TODO(marina.sakai): remove in v0.12.0 func ServiceInfo(pcType serviceinfo.PayloadCodec) *serviceinfo.ServiceInfo { - return newServiceInfo(pcType, nil, "") + return newServiceInfo(pcType, nil, "", false) } -func newServiceInfo(pcType serviceinfo.PayloadCodec, messageReaderWriter interface{}, serviceName string) *serviceinfo.ServiceInfo { +func newServiceInfo(pcType serviceinfo.PayloadCodec, messageReaderWriter interface{}, serviceName string, withGeneric bool) *serviceinfo.ServiceInfo { handlerType := (*Service)(nil) - var methods map[string]serviceinfo.MethodInfo - var svcName string + methods, svcName := GetMethodInfo(messageReaderWriter, serviceName) + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: svcName, + HandlerType: handlerType, + Methods: methods, + PayloadCodec: pcType, + Extra: make(map[string]interface{}), + } + svcInfo.Extra["generic"] = true + // TODO(marina.sakai): remove in v0.12.0 + if !withGeneric { + svcInfo.Extra[DeprecatedGenericServiceInfoAPIKey] = true + } + return svcInfo +} + +// GetMethodInfo is only used in kitex, please DON'T USE IT. This method may be removed in the future +func GetMethodInfo(messageReaderWriter interface{}, serviceName string) (methods map[string]serviceinfo.MethodInfo, svcName string) { if messageReaderWriter == nil { // note: binary generic cannot be used with multi-service feature svcName = serviceinfo.GenericService @@ -69,16 +89,7 @@ func newServiceInfo(pcType serviceinfo.PayloadCodec, messageReaderWriter interfa ), } } - - svcInfo := &serviceinfo.ServiceInfo{ - ServiceName: svcName, - HandlerType: handlerType, - Methods: methods, - PayloadCodec: pcType, - Extra: make(map[string]interface{}), - } - svcInfo.Extra["generic"] = true - return svcInfo + return } func callHandler(ctx context.Context, handler, arg, result interface{}) error { diff --git a/pkg/generic/generic_service_test.go b/pkg/generic/generic_service_test.go index 4b1f30ed9c..cc9bdb69d6 100644 --- a/pkg/generic/generic_service_test.go +++ b/pkg/generic/generic_service_test.go @@ -127,7 +127,7 @@ func TestServiceInfo(t *testing.T) { test.Assert(t, err == nil) g, err := JSONThriftGeneric(p) test.Assert(t, err == nil) - s = ServiceInfoWithCodec(g) + s = ServiceInfoWithGeneric(g) test.Assert(t, s.ServiceName == "Mock") } diff --git a/pkg/generic/map_test/generic_test.go b/pkg/generic/map_test/generic_test.go index 3ad0c45baf..d361cfac9d 100644 --- a/pkg/generic/map_test/generic_test.go +++ b/pkg/generic/map_test/generic_test.go @@ -19,6 +19,7 @@ package test import ( "context" "encoding/base64" + "fmt" "net" "reflect" "runtime" @@ -35,7 +36,9 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/generic/descriptor" + "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/server" + "github.com/cloudwego/kitex/transport" ) func TestThrift(t *testing.T) { @@ -384,6 +387,31 @@ func TestThrift2NormalServer(t *testing.T) { svr.Stop() } +// TODO(marina.sakai): remove this test when we remove the API generic.ServiceInfo() in v0.12.0 +func TestCompatible(t *testing.T) { + addr := test.GetLocalAddress() + svr := initThriftServer(t, addr, new(GenericServiceWithBase64Binary), true, false) + svcInfo := generic.ServiceInfo(serviceinfo.Thrift) + p, err := generic.NewThriftFileProvider("./idl/example.thrift") + test.Assert(t, err == nil) + g, err := generic.MapThriftGeneric(p) + test.Assert(t, err == nil) + var opts []client.Option + opts = append(opts, client.WithHostPorts(addr), client.WithTransportProtocol(transport.TTHeader)) + cli, err := genericclient.NewClientWithServiceInfo("destServiceName", g, svcInfo, opts...) + test.Assert(t, err == nil) + + reqMsg = map[string]interface{}{"BinaryMsg": []byte("hello")} + resp, err := cli.GenericCall(context.Background(), "ExampleMethod", reqMsg, callopt.WithRPCTimeout(100*time.Second)) + test.Assert(t, err == nil, err) + gr, ok := resp.(map[string]interface{}) + fmt.Println(gr) + test.Assert(t, ok) + test.Assert(t, reflect.DeepEqual(gr["BinaryMsg"], "hello")) + + svr.Stop() +} + func initThriftMockClient(t *testing.T, address string) genericclient.Client { p, err := generic.NewThriftFileProvider("./idl/mock.thrift") test.Assert(t, err == nil) diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 636598bc76..19adc3b24f 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -546,7 +546,7 @@ func TestInflightStreamClosing(t *testing.T) { <-timeout.C } case <-timeout.C: - t.Fatalf("Test timed out, expected a status error.") + t.Fatalf("%s", "Test timed out, expected a status error.") } } @@ -707,7 +707,7 @@ func TestLargeMessageWithDelayRead(t *testing.T) { select { case <-ready: case <-ctx.Done(): - t.Fatalf("Client timed out waiting for server handler to be initialized.") + t.Fatalf("%s", "Client timed out waiting for server handler to be initialized.") } server.mu.Lock() serviceHandler := server.h @@ -759,7 +759,7 @@ func TestLargeMessageWithDelayRead(t *testing.T) { select { case <-serviceHandler.notify: case <-ctx.Done(): - t.Fatalf("Client timed out") + t.Fatalf("%s", "Client timed out") } if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { t.Fatalf("s.Read(_) = _, %v, want _, ", err) @@ -896,7 +896,7 @@ func TestMaxStreams(t *testing.T) { for { select { case <-timer.C: - t.Fatalf("Test timeout: client didn't receive server settings.") + t.Fatalf("%s", "Test timeout: client didn't receive server settings.") default: } ctx, cancel := context.WithDeadline(pctx, time.Now().Add(time.Second)) @@ -928,7 +928,7 @@ func TestMaxStreams(t *testing.T) { } select { case <-done: - t.Fatalf("Test failed: didn't expect new stream to be created just yet.") + t.Fatalf("%s", "Test failed: didn't expect new stream to be created just yet.") default: } // Close the first stream created so that the new stream can finally be created. @@ -999,7 +999,7 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled) } case <-time.After(3 * time.Second): - t.Fatalf("Failed to cancel the context of the sever side stream.") + t.Fatalf("%s", "Failed to cancel the context of the sever side stream.") } server.stop() } @@ -1244,7 +1244,7 @@ func TestServerWithMisbehavedClient(t *testing.T) { for { select { case <-timer.C: - t.Fatalf("Test timed out.") + t.Fatalf("%s", "Test timed out.") case <-success: return default: @@ -1489,14 +1489,14 @@ func TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { defer cancel() s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) if err != nil { - t.Fatalf("failed to create the stream") + t.Fatalf("%s", "failed to create the stream") } timer := time.NewTimer(time.Second) defer timer.Stop() select { case <-s.headerChan: case <-timer.C: - t.Errorf("s.headerChan: got open, want closed") + t.Errorf("%s", "s.headerChan: got open, want closed") } } @@ -1711,7 +1711,7 @@ func waitWhileTrue(t *testing.T, condition func() (bool, error)) { if wait { select { case <-timer.C: - t.Fatalf(err.Error()) + t.Fatalf("%s", err.Error()) default: time.Sleep(50 * time.Millisecond) continue @@ -1980,7 +1980,7 @@ func TestHeaderTblSize(t *testing.T) { continue } if i == 1000 { - t.Fatalf("unable to create any server transport after 10s") + t.Fatalf("%s", "unable to create any server transport after 10s") } for st := range server.conns { @@ -2007,7 +2007,7 @@ func TestHeaderTblSize(t *testing.T) { break } if i == 1000 { - t.Fatalf("expected len(limits) = 1 within 10s, got != 1") + t.Fatalf("%s", "expected len(limits) = 1 within 10s, got != 1") } ct.controlBuf.put(&outgoingSettings{ @@ -2030,7 +2030,7 @@ func TestHeaderTblSize(t *testing.T) { break } if i == 1000 { - t.Fatalf("expected len(limits) = 2 within 10s, got != 2") + t.Fatalf("%s", "expected len(limits) = 2 within 10s, got != 2") } } diff --git a/pkg/remote/trans/nphttp2/status/status_test.go b/pkg/remote/trans/nphttp2/status/status_test.go index 69ef976862..08cd55d82f 100644 --- a/pkg/remote/trans/nphttp2/status/status_test.go +++ b/pkg/remote/trans/nphttp2/status/status_test.go @@ -30,7 +30,7 @@ import ( func TestStatus(t *testing.T) { // test ok status statusMsg := "test" - statusOk := Newf(codes.OK, statusMsg) + statusOk := Newf(codes.OK, "%s", statusMsg) test.Assert(t, statusOk.Code() == codes.OK) test.Assert(t, statusOk.Message() == statusMsg) test.Assert(t, statusOk.Err() == nil) @@ -50,7 +50,7 @@ func TestStatus(t *testing.T) { test.Assert(t, emptyDetail == nil) // test error status - notFoundErr := Errorf(codes.NotFound, statusMsg) + notFoundErr := Errorf(codes.NotFound, "%s", statusMsg) statusErr, ok := FromError(notFoundErr) test.Assert(t, ok) test.Assert(t, statusErr.Code() == codes.NotFound) diff --git a/server/genericserver/server.go b/server/genericserver/server.go index 7823bf3c3e..bd16be08cb 100644 --- a/server/genericserver/server.go +++ b/server/genericserver/server.go @@ -25,7 +25,7 @@ import ( // NewServer creates a generic server with the given handler and options. func NewServer(handler generic.Service, g generic.Generic, opts ...server.Option) server.Server { - svcInfo := generic.ServiceInfoWithCodec(g) + svcInfo := generic.ServiceInfoWithGeneric(g) return NewServerWithServiceInfo(handler, g, svcInfo, opts...) } diff --git a/server/genericserver/server_test.go b/server/genericserver/server_test.go index 8c6b10fcea..6296429315 100644 --- a/server/genericserver/server_test.go +++ b/server/genericserver/server_test.go @@ -55,7 +55,7 @@ func TestNewServerWithServiceInfo(t *testing.T) { }) test.PanicAt(t, func() { - NewServerWithServiceInfo(nil, g, generic.ServiceInfoWithCodec(g)) + NewServerWithServiceInfo(nil, g, generic.ServiceInfoWithGeneric(g)) }, func(err interface{}) bool { if errMsg, ok := err.(error); ok { return strings.Contains(errMsg.Error(), "handler is nil.") From c9a8e4fbdc86499f5acc1e5912bbb3bd3ef3a8f4 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Thu, 15 Aug 2024 12:41:31 +0400 Subject: [PATCH 39/70] refactor: optimized apache codec without reflection (#1490) Co-authored-by: YangruiEmma --- go.mod | 2 +- go.sum | 4 +-- pkg/generic/thrift/base.go | 10 -------- pkg/protocol/bthrift/README.md | 7 +++--- pkg/protocol/bthrift/apache/apache.go | 34 +++++++++++++++++++++++++- pkg/remote/codec/thrift/thrift_data.go | 4 +-- 6 files changed, 42 insertions(+), 19 deletions(-) diff --git a/go.mod b/go.mod index 38bca68ffa..37a975733d 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cloudwego/dynamicgo v0.3.0 github.com/cloudwego/fastpb v0.0.4 github.com/cloudwego/frugal v0.2.0 - github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8 + github.com/cloudwego/gopkg v0.1.1-0.20240812141034-843ef58f1234 github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index 124831aaf1..0634fd32dd 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6Fs github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= github.com/cloudwego/gopkg v0.1.0/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= -github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8 h1:kQPjddHw5Dufci/vfiRGMN3Uhx12XWqNpk1JdQ4Tjy0= -github.com/cloudwego/gopkg v0.1.1-0.20240806070559-b36f09467ae8/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= +github.com/cloudwego/gopkg v0.1.1-0.20240812141034-843ef58f1234 h1:ID8y5ks8EetjB7Qqqgfd9FhPIpt+PSUsQxHeZA64+7Q= +github.com/cloudwego/gopkg v0.1.1-0.20240812141034-843ef58f1234/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= diff --git a/pkg/generic/thrift/base.go b/pkg/generic/thrift/base.go index ae930c3307..34dfb80ad0 100644 --- a/pkg/generic/thrift/base.go +++ b/pkg/generic/thrift/base.go @@ -18,16 +18,6 @@ package thrift import "github.com/cloudwego/gopkg/protocol/thrift/base" -// TrafficEnv ... -// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base -type TrafficEnv = base.TrafficEnv - -// NewTrafficEnv ... -// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base -func NewTrafficEnv() *TrafficEnv { - return base.NewTrafficEnv() -} - // Base ... // Deprecated: use github.com/cloudwego/gopkg/protocol/thrift/base type Base = base.Base diff --git a/pkg/protocol/bthrift/README.md b/pkg/protocol/bthrift/README.md index d90609fab2..73e8a9ec75 100644 --- a/pkg/protocol/bthrift/README.md +++ b/pkg/protocol/bthrift/README.md @@ -6,7 +6,7 @@ We're planning to get rid of `github.com/apache/thrift`, here are steps we have 1. Removed unnecessary dependencies of apache from kitex 2. Moved all apache dependencies to `bthrift/apache`, mainly types, interfaces and consts - We may use type alias at the beginning for better compatibility - - `bthrift/apache`calls `apache.RegisterNewTBinaryProtocol` in `gopkg` for step 4 + - `bthrift/apache`calls `apache.RegisterXXX` in `gopkg` for step 4 3. For internal dependencies of apache, use `gopkg` 4. For existing framework code working with apache thrift: - Use `gopkg/protocol/thrift/apache` @@ -15,10 +15,11 @@ We're planning to get rid of `github.com/apache/thrift`, here are steps we have - replace `github.com/apache/thrift` with `bthrift/apache` - by using `thrift_import_path` parameter of thriftgo -The final step we planned to do in version v0.12.0: +The final step we planned to do in kitex version v0.12.0: * Add go.mod for `bthrift` -* Remove the last `github.com/apache/thrift` dependencies +* Remove the last `github.com/apache/thrift` dependencies (Remove the import 'github.com/cloudwego/kitex/pkg/protocol/bthrift' and 'github.com/cloudwego/kitex/pkg/protocol/bthrift/apache') * `ThriftMessageCodec` of `pkg/utils` * `MessageReader` and `MessageWriter` interfaces in `pkg/remote/codec/thrift` * `BinaryProtocol` type in `pkg/remote/codec/thrift` * basic codec tests in `pkg/remote/codec/thrift` + * kitex DOESN'T DEPEND bthrift, bthrift will only be dependent in the generated code(if has apache thrift code) diff --git a/pkg/protocol/bthrift/apache/apache.go b/pkg/protocol/bthrift/apache/apache.go index 4537fa1b57..f4c26e9c77 100644 --- a/pkg/protocol/bthrift/apache/apache.go +++ b/pkg/protocol/bthrift/apache/apache.go @@ -17,11 +17,43 @@ package apache import ( + "errors" + "github.com/apache/thrift/lib/go/thrift" "github.com/cloudwego/gopkg/protocol/thrift/apache" ) func init() { // it makes github.com/cloudwego/gopkg/protocol/thrift/apache works - _ = apache.RegisterNewTBinaryProtocol(thrift.NewTBinaryProtocol) + apache.RegisterCheckTStruct(checkTStruct) + apache.RegisterThriftRead(callThriftRead) + apache.RegisterThriftWrite(callThriftWrite) +} + +var errNotThriftTStruct = errors.New("not thrift.TStruct") + +func checkTStruct(v interface{}) error { + _, ok := v.(thrift.TStruct) + if !ok { + return errNotThriftTStruct + } + return nil +} + +func callThriftRead(t apache.TTransport, v interface{}) error { + p, ok := v.(thrift.TStruct) + if !ok { + return errNotThriftTStruct + } + in := thrift.NewTBinaryProtocol(t, true, true) + return p.Read(in) +} + +func callThriftWrite(t apache.TTransport, v interface{}) error { + p, ok := v.(thrift.TStruct) + if !ok { + return errNotThriftTStruct + } + out := thrift.NewTBinaryProtocol(t, true, true) + return p.Write(out) } diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index 743c713a9e..e2732f7b5e 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -80,7 +80,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ // verifyMarshalBasicThriftDataType verifies whether data could be marshaled by old thrift way func verifyMarshalBasicThriftDataType(data interface{}) error { - if err := apache.CheckThriftWrite(data); err != nil { + if err := apache.CheckTStruct(data); err != nil { return errEncodeMismatchMsgType } return nil @@ -200,7 +200,7 @@ func (c thriftCodec) hyperUnmarshal(trans remote.ByteBuffer, data interface{}, d // verifyUnmarshalBasicThriftDataType verifies whether data could be unmarshal by old thrift way func verifyUnmarshalBasicThriftDataType(data interface{}) error { - if err := apache.CheckThriftRead(data); err != nil { + if err := apache.CheckTStruct(data); err != nil { return errDecodeMismatchMsgType } return nil From e4c77162ed619566312751c5ab82c50a7bd13505 Mon Sep 17 00:00:00 2001 From: Guangming Luo Date: Fri, 16 Aug 2024 09:57:18 +0800 Subject: [PATCH 40/70] chore: update CI and go.mod to support 1.18-1.23 (#1493) --- .github/workflows/tests.yml | 8 ++-- go.mod | 8 ++-- go.sum | 42 +++---------------- internal/generic/proto/json_test.go | 4 +- pkg/generic/httppb_test/generic_test.go | 4 +- pkg/generic/httpthrift_codec.go | 3 +- pkg/remote/bound/limiter_inbound_test.go | 16 +++---- .../codec/protobuf/encoding/gzip/gzip.go | 5 +-- .../trans/netpoll/http_client_handler.go | 9 ++-- pkg/utils/yaml.go | 4 +- tool/cmd/kitex/utils/utils.go | 3 +- .../internal_pkg/generator/custom_template.go | 5 +-- tool/internal_pkg/generator/template.go | 10 ++--- .../pluginmode/thriftgo/convertor.go | 5 +-- .../pluginmode/thriftgo/hessian2.go | 13 +++--- .../pluginmode/thriftgo/patcher.go | 6 +-- tool/internal_pkg/util/dump.go | 8 ++-- tool/internal_pkg/util/util.go | 10 ++--- 18 files changed, 63 insertions(+), 100 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9f2d820fbb..481f69f6af 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.22" + go-version: stable - name: Benchmark # we only use this CI to verify bench code works # setting benchtime=100ms is saving our time... @@ -36,7 +36,7 @@ jobs: compatibility-test: strategy: matrix: - go: [ "1.17", "1.18", "1.19", "1.20", "1.21", "1.22" ] + go: [ "1.18", "1.19", "1.20", "1.21", "1.22", "1.23" ] os: [ X64, ARM64 ] runs-on: ${{ matrix.os }} steps: @@ -56,7 +56,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.22" + go-version: stable - name: Prepare run: | go install github.com/cloudwego/thriftgo@main @@ -84,6 +84,6 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.22" + go-version: stable - name: Windows compatibility test run: go test -run=^$ ./... diff --git a/go.mod b/go.mod index 37a975733d..96f719dad6 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,11 @@ module github.com/cloudwego/kitex -go 1.17 +go 1.18 require ( github.com/apache/thrift v0.13.0 github.com/bytedance/gopkg v0.1.0 - github.com/bytedance/sonic v1.12.0 + github.com/bytedance/sonic v1.12.1 github.com/cloudwego/configmanager v0.2.2 github.com/cloudwego/dynamicgo v0.3.0 github.com/cloudwego/fastpb v0.0.4 @@ -16,10 +16,10 @@ require ( github.com/cloudwego/runtimex v0.1.0 github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083 github.com/golang/mock v1.6.0 - github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 + github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 github.com/jhump/protoreflect v1.8.2 github.com/json-iterator/go v1.1.12 - github.com/tidwall/gjson v1.9.3 + github.com/tidwall/gjson v1.17.3 golang.org/x/net v0.17.0 golang.org/x/sync v0.1.0 golang.org/x/sys v0.13.0 diff --git a/go.sum b/go.sum index 0634fd32dd..6f023c2247 100644 --- a/go.sum +++ b/go.sum @@ -4,19 +4,15 @@ github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= -github.com/bytedance/gopkg v0.0.0-20240711085056-a03554c296f8/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic v1.12.0 h1:YGPgxF9xzaCNvd/ZKdQ28yRovhfMFZQjuk6fKBzZ3ls= -github.com/bytedance/sonic v1.12.0/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= +github.com/bytedance/sonic v1.12.1 h1:jWl5Qz1fy7X1ioY74WqO0KjAMtAGQs4sYnjiEBiyX24= +github.com/bytedance/sonic v1.12.1/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= -github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= -github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= @@ -28,7 +24,6 @@ github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4m github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= -github.com/cloudwego/gopkg v0.1.0/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/gopkg v0.1.1-0.20240812141034-843ef58f1234 h1:ID8y5ks8EetjB7Qqqgfd9FhPIpt+PSUsQxHeZA64+7Q= github.com/cloudwego/gopkg v0.1.1-0.20240812141034-843ef58f1234/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= @@ -39,16 +34,13 @@ github.com/cloudwego/netpoll v0.6.3 h1:t+ndlwBFjQZimUj3ul31DwI45t18eOr2pcK3juZZm github.com/cloudwego/netpoll v0.6.3/go.mod h1:kaqvfZ70qd4T2WtIIpCOi5Cxyob8viEpzLhCrTrz3HM= github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= github.com/cloudwego/runtimex v0.1.0/go.mod h1:23vL/HGV0W8nSCHbe084AgEBdDV4rvXenEUMnUNvUd8= -github.com/cloudwego/thriftgo v0.3.6/go.mod h1:29ukiySoAMd0vXMYIduAY9dph/7dmChvOS11YLotFb8= github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083 h1:KiEGBvsyAyUrFrpEi/e77K0SWTLK8FMHhSQ5c9kFJic= github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083/go.mod h1:R4a+4aVDI0V9YCTfpNgmvbkq/9ThKgF7Om8Z0I36698= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -57,7 +49,6 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -82,19 +73,17 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 h1:mpL/HvfIgIejhVwAfxBQkwEjlhP5o0O9RAeTAjpwzxc= -github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3/go.mod h1:gSuNB+gJaOiQKLEZ+q+PK9Mq3SOzhRcw2GsGS/FhYDk= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/iancoleman/strcase v0.2.0 h1:05I4QRnGpI0m37iZQRuskXh+w77mr6Z41lwQzuHLwW0= github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= -github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/jhump/protoreflect v1.8.2 h1:k2xE7wcUomeqwY0LDCYA16y4WWfyTcMx5mKhk0d4ua0= github.com/jhump/protoreflect v1.8.2/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= @@ -105,7 +94,6 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 h1:uiS4zKYKJVj5F3ID+5iylfKPsEQmBEOucSD9Vgmn0i0= @@ -120,19 +108,16 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/tidwall/gjson v1.9.3 h1:hqzS9wAHMO+KVBBkLxYdkEeeFHuqr95GfClRLKlgK0E= -github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= @@ -152,7 +137,6 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -164,7 +148,6 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -179,8 +162,6 @@ golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLd golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -206,27 +187,19 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -238,9 +211,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200717024301-6ddee64345a6/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= @@ -276,7 +247,6 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/generic/proto/json_test.go b/internal/generic/proto/json_test.go index 1bec69eb12..0f0fcbbe2b 100644 --- a/internal/generic/proto/json_test.go +++ b/internal/generic/proto/json_test.go @@ -19,7 +19,7 @@ package proto import ( "context" "encoding/json" - "io/ioutil" + "os" "reflect" "testing" @@ -138,7 +138,7 @@ func getExampleReq() string { // read ProtoBuf's data in binary format from exampleProtoPath func readExampleReqProtoBufData() []byte { - out, err := ioutil.ReadFile(example2ProtoPath) + out, err := os.ReadFile(example2ProtoPath) if err != nil { panic(err) } diff --git a/pkg/generic/httppb_test/generic_test.go b/pkg/generic/httppb_test/generic_test.go index 099352137f..08f6a356e3 100644 --- a/pkg/generic/httppb_test/generic_test.go +++ b/pkg/generic/httppb_test/generic_test.go @@ -19,7 +19,7 @@ package test import ( "bytes" "context" - "io/ioutil" + "io" "net" "net/http" "os" @@ -47,7 +47,7 @@ func initThriftClientByIDL(t *testing.T, addr, idl, pbIdl string) genericclient. test.Assert(t, err == nil) pbf, err := os.Open(pbIdl) test.Assert(t, err == nil) - pbContent, err := ioutil.ReadAll(pbf) + pbContent, err := io.ReadAll(pbf) test.Assert(t, err == nil) pbf.Close() pbp, err := generic.NewPbContentProvider(pbIdl, map[string]string{pbIdl: string(pbContent)}) diff --git a/pkg/generic/httpthrift_codec.go b/pkg/generic/httpthrift_codec.go index 85ca2d4128..3486a41cf5 100644 --- a/pkg/generic/httpthrift_codec.go +++ b/pkg/generic/httpthrift_codec.go @@ -20,7 +20,6 @@ import ( "context" "errors" "io" - "io/ioutil" "net/http" "sync/atomic" @@ -188,7 +187,7 @@ func FromHTTPRequest(req *http.Request) (*HTTPRequest, error) { // body == nil if from Get request return customReq, nil } - if customReq.RawBody, err = ioutil.ReadAll(b); err != nil { + if customReq.RawBody, err = io.ReadAll(b); err != nil { return nil, err } return customReq, nil diff --git a/pkg/remote/bound/limiter_inbound_test.go b/pkg/remote/bound/limiter_inbound_test.go index 71eea396d8..86347d50ba 100644 --- a/pkg/remote/bound/limiter_inbound_test.go +++ b/pkg/remote/bound/limiter_inbound_test.go @@ -85,12 +85,12 @@ func TestLimiterOnActive(t *testing.T) { handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, nil, false) ctx, err := handler.OnActive(ctx, invoke.NewMessage(nil, nil)) test.Assert(t, ctx != nil) - test.Assert(t, errors.Is(kerrors.ErrConnOverLimit, err)) + test.Assert(t, errors.Is(err, kerrors.ErrConnOverLimit)) muxHandler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, nil, true) ctx, err = muxHandler.OnActive(ctx, invoke.NewMessage(nil, nil)) test.Assert(t, ctx != nil) - test.Assert(t, errors.Is(kerrors.ErrConnOverLimit, err)) + test.Assert(t, errors.Is(err, kerrors.ErrConnOverLimit)) }) t.Run("Test OnActive with limit acquire false and non-nil reporter", func(t *testing.T) { @@ -109,13 +109,13 @@ func TestLimiterOnActive(t *testing.T) { ctx, err := handler.OnActive(ctx, invoke.NewMessage(nil, nil)) test.Assert(t, ctx != nil) test.Assert(t, err != nil) - test.Assert(t, errors.Is(kerrors.ErrConnOverLimit, err)) + test.Assert(t, errors.Is(err, kerrors.ErrConnOverLimit)) muxHandler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, true) ctx, err = muxHandler.OnActive(ctx, invoke.NewMessage(nil, nil)) test.Assert(t, ctx != nil) test.Assert(t, err != nil) - test.Assert(t, errors.Is(kerrors.ErrConnOverLimit, err)) + test.Assert(t, errors.Is(err, kerrors.ErrConnOverLimit)) }) } @@ -154,7 +154,7 @@ func TestLimiterOnRead(t *testing.T) { ctx, err := handler.OnRead(ctx, invoke.NewMessage(nil, nil)) test.Assert(t, ctx != nil) - test.Assert(t, errors.Is(kerrors.ErrQPSOverLimit, err)) + test.Assert(t, errors.Is(err, kerrors.ErrQPSOverLimit)) }) t.Run("Test OnRead with limit acquire false and non-nil reporter", func(t *testing.T) { @@ -174,7 +174,7 @@ func TestLimiterOnRead(t *testing.T) { test.Assert(t, ctx != nil) test.Assert(t, err != nil) - test.Assert(t, errors.Is(kerrors.ErrQPSOverLimit, err)) + test.Assert(t, errors.Is(err, kerrors.ErrQPSOverLimit)) }) } @@ -237,7 +237,7 @@ func TestLimiterOnMessage(t *testing.T) { ctx, err := handler.OnMessage(ctx, req, remote.NewMessage(nil, nil, nil, remote.Reply, remote.Client)) test.Assert(t, ctx != nil) - test.Assert(t, errors.Is(kerrors.ErrQPSOverLimit, err)) + test.Assert(t, errors.Is(err, kerrors.ErrQPSOverLimit)) }) t.Run("Test OnMessage with limit acquire false and non-nil reporter", func(t *testing.T) { @@ -258,6 +258,6 @@ func TestLimiterOnMessage(t *testing.T) { test.Assert(t, ctx != nil) test.Assert(t, err != nil) - test.Assert(t, errors.Is(kerrors.ErrQPSOverLimit, err)) + test.Assert(t, errors.Is(err, kerrors.ErrQPSOverLimit)) }) } diff --git a/pkg/remote/codec/protobuf/encoding/gzip/gzip.go b/pkg/remote/codec/protobuf/encoding/gzip/gzip.go index 1a63a37d52..08e58ce3c6 100644 --- a/pkg/remote/codec/protobuf/encoding/gzip/gzip.go +++ b/pkg/remote/codec/protobuf/encoding/gzip/gzip.go @@ -32,7 +32,6 @@ import ( "encoding/binary" "fmt" "io" - "io/ioutil" "sync" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" @@ -44,7 +43,7 @@ const Name = "gzip" func init() { c := &compressor{} c.poolCompressor.New = func() interface{} { - return &writer{Writer: gzip.NewWriter(ioutil.Discard), pool: &c.poolCompressor} + return &writer{Writer: gzip.NewWriter(io.Discard), pool: &c.poolCompressor} } encoding.RegisterCompressor(c) } @@ -65,7 +64,7 @@ func SetLevel(level int) error { } c := encoding.GetCompressor(Name).(*compressor) c.poolCompressor.New = func() interface{} { - w, err := gzip.NewWriterLevel(ioutil.Discard, level) + w, err := gzip.NewWriterLevel(io.Discard, level) if err != nil { panic(err) } diff --git a/pkg/remote/trans/netpoll/http_client_handler.go b/pkg/remote/trans/netpoll/http_client_handler.go index f4d03c12a4..63bde718bf 100644 --- a/pkg/remote/trans/netpoll/http_client_handler.go +++ b/pkg/remote/trans/netpoll/http_client_handler.go @@ -21,7 +21,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" "net" "net/http" "path" @@ -139,10 +139,9 @@ func (t *httpCliTransHandler) OnInactive(ctx context.Context, conn net.Conn) { // OnError implements the remote.ClientTransHandler interface. // This is called when panic happens. func (t *httpCliTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { - if pe, ok := err.(*kerrors.DetailedError); ok { + var pe *kerrors.DetailedError + if errors.As(err, &pe) { klog.CtxErrorf(ctx, "KITEX: send http request error, remote=%s, error=%s\nstack=%s", conn.RemoteAddr(), err.Error(), pe.Stack()) - } else { - klog.CtxErrorf(ctx, "KITEX: send http request error, remote=%s, error=%s", conn.RemoteAddr(), err.Error()) } } @@ -278,7 +277,7 @@ func getBodyBufReader(buf remote.ByteBuffer) (remote.ByteBuffer, error) { if hr.StatusCode != http.StatusOK { return nil, fmt.Errorf("http response not OK, StatusCode: %d", hr.StatusCode) } - b, err := ioutil.ReadAll(hr.Body) + b, err := io.ReadAll(hr.Body) hr.Body.Close() if err != nil { return nil, fmt.Errorf("read http response body error:%w", err) diff --git a/pkg/utils/yaml.go b/pkg/utils/yaml.go index 61946128b1..16f8b80deb 100644 --- a/pkg/utils/yaml.go +++ b/pkg/utils/yaml.go @@ -17,7 +17,7 @@ package utils import ( - "io/ioutil" + "io" "os" "time" @@ -45,7 +45,7 @@ func ReadYamlConfigFile(yamlFile string) (*YamlConfig, error) { } defer fd.Close() - b, err := ioutil.ReadAll(fd) + b, err := io.ReadAll(fd) if err != nil { return nil, err } diff --git a/tool/cmd/kitex/utils/utils.go b/tool/cmd/kitex/utils/utils.go index cfe1399a46..2c763a1205 100644 --- a/tool/cmd/kitex/utils/utils.go +++ b/tool/cmd/kitex/utils/utils.go @@ -15,7 +15,6 @@ package utils import ( - "io/ioutil" "os" "os/exec" "strings" @@ -55,7 +54,7 @@ func OnKitexToolNormalExit(args kargs.Arguments) { func DeleteKitexYaml() { // try to read kitex.yaml - data, err := ioutil.ReadFile("kitex.yaml") + data, err := os.ReadFile("kitex.yaml") if err != nil { if !os.IsNotExist(err) { log.Warn("kitex.yaml, which is used to record tool info, is deprecated, it's renamed as kitex_info.yaml, you can delete it or ignore it.") diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index e8a74dbe64..3f364589db 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -16,7 +16,6 @@ package generator import ( "fmt" - "io/ioutil" "os" "path" "path/filepath" @@ -214,13 +213,13 @@ func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, } func readTemplates(dir string) ([]*Template, error) { - files, _ := ioutil.ReadDir(dir) + files, _ := os.ReadDir(dir) var ts []*Template for _, f := range files { // filter dir and non-yaml files if f.Name() != ExtensionFilename && !f.IsDir() && (strings.HasSuffix(f.Name(), "yaml") || strings.HasSuffix(f.Name(), "yml")) { p := filepath.Join(dir, f.Name()) - tplData, err := ioutil.ReadFile(p) + tplData, err := os.ReadFile(p) if err != nil { return nil, fmt.Errorf("read layout config from %s failed, err: %v", p, err.Error()) } diff --git a/tool/internal_pkg/generator/template.go b/tool/internal_pkg/generator/template.go index db347903d2..7d43961923 100644 --- a/tool/internal_pkg/generator/template.go +++ b/tool/internal_pkg/generator/template.go @@ -17,7 +17,7 @@ package generator import ( "encoding/json" "fmt" - "io/ioutil" + "os" "gopkg.in/yaml.v3" ) @@ -63,7 +63,7 @@ func (p *TemplateExtension) FromJSONFile(filename string) error { if p == nil { return nil } - data, err := ioutil.ReadFile(filename) + data, err := os.ReadFile(filename) if err != nil { return err } @@ -76,7 +76,7 @@ func (p *TemplateExtension) ToJSONFile(filename string) error { if err != nil { return err } - return ioutil.WriteFile(filename, data, 0o644) + return os.WriteFile(filename, data, 0o644) } // FromYAMLFile unmarshals a TemplateExtension with YAML format from the given file. @@ -84,7 +84,7 @@ func (p *TemplateExtension) FromYAMLFile(filename string) error { if p == nil { return nil } - data, err := ioutil.ReadFile(filename) + data, err := os.ReadFile(filename) if err != nil { return err } @@ -96,7 +96,7 @@ func (p *TemplateExtension) ToYAMLFile(filename string) error { if err != nil { return err } - return ioutil.WriteFile(filename, data, 0o644) + return os.WriteFile(filename, data, 0o644) } func (p *TemplateExtension) Merge(other *TemplateExtension) { diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor.go b/tool/internal_pkg/pluginmode/thriftgo/convertor.go index 475a5fad09..7b51e67f83 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/convertor.go +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor.go @@ -18,7 +18,6 @@ import ( "fmt" "go/format" "io" - "io/ioutil" "os" "path/filepath" "regexp" @@ -56,7 +55,7 @@ func (c *converter) init(req *plugin.Request) error { return fmt.Errorf("expect language to be 'go'. Encountered '%s'", req.Language) } - // resotre the arguments for kitex + // restore the arguments for kitex if err := c.Config.Unpack(req.PluginParameters); err != nil { return err } @@ -510,7 +509,7 @@ func (c *converter) persist(res *plugin.Response) error { if err := os.MkdirAll(path, 0o755); err != nil && !os.IsExist(err) { return fmt.Errorf("failed to create path '%s': %w", path, err) } - if err := ioutil.WriteFile(full, content, 0o644); err != nil { + if err := os.WriteFile(full, content, 0o644); err != nil { return fmt.Errorf("failed to write file '%s': %w", full, err) } } diff --git a/tool/internal_pkg/pluginmode/thriftgo/hessian2.go b/tool/internal_pkg/pluginmode/thriftgo/hessian2.go index b6005c0024..1eaffa7b98 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/hessian2.go +++ b/tool/internal_pkg/pluginmode/thriftgo/hessian2.go @@ -17,7 +17,6 @@ package thriftgo import ( "fmt" "io" - "io/ioutil" "os" "path/filepath" "regexp" @@ -137,7 +136,7 @@ func patchIDLRefConfig(cfg *generator.Config) error { // loadIDLRefConfig load idl-ref config from file object func loadIDLRefConfig(fileName string, reader io.Reader) (*config.RawConfig, error) { - data, err := ioutil.ReadAll(reader) + data, err := io.ReadAll(reader) if err != nil { return nil, fmt.Errorf("read %s file failed: %s", fileName, err.Error()) } @@ -166,7 +165,7 @@ var ( func Hessian2PatchByReplace(args generator.Config, subDirPath string) error { output := args.OutputPath newPath := util.JoinPath(output, args.GenPath, subDirPath) - fs, err := ioutil.ReadDir(newPath) + fs, err := os.ReadDir(newPath) if err != nil { return err } @@ -179,7 +178,7 @@ func Hessian2PatchByReplace(args generator.Config, subDirPath string) error { return err } } else if strings.HasSuffix(f.Name(), ".go") { - data, err := ioutil.ReadFile(fileName) + data, err := os.ReadFile(fileName) if err != nil { return err } @@ -187,7 +186,7 @@ func Hessian2PatchByReplace(args generator.Config, subDirPath string) error { data = replaceJavaObject(data) data = replaceJavaException(data) data = replaceJavaExceptionEmptyVerification(data) - if err = ioutil.WriteFile(fileName, data, 0o644); err != nil { + if err = os.WriteFile(fileName, data, 0o644); err != nil { return err } } @@ -199,13 +198,13 @@ func Hessian2PatchByReplace(args generator.Config, subDirPath string) error { } handlerName := util.JoinPath(output, "handler.go") - handler, err := ioutil.ReadFile(handlerName) + handler, err := os.ReadFile(handlerName) if err != nil { return err } handler = replaceJavaObject(handler) - return ioutil.WriteFile(handlerName, handler, 0o644) + return os.WriteFile(handlerName, handler, 0o644) } func replaceJavaObject(content []byte) []byte { diff --git a/tool/internal_pkg/pluginmode/thriftgo/patcher.go b/tool/internal_pkg/pluginmode/thriftgo/patcher.go index ea9e263b9d..9b83d99529 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/patcher.go +++ b/tool/internal_pkg/pluginmode/thriftgo/patcher.go @@ -16,7 +16,7 @@ package thriftgo import ( "fmt" - "io/ioutil" + "os" "path/filepath" "reflect" "runtime" @@ -326,7 +326,7 @@ func (p *patcher) patch(req *plugin.Request) (patches []*plugin.Generated, err e // fd.WriteString("content: " + content + "\nend\n") if p.copyIDL { - content, err := ioutil.ReadFile(ast.Filename) + content, err := os.ReadFile(ast.Filename) if err != nil { return nil, fmt.Errorf("read %q: %w", ast.Filename, err) } @@ -406,7 +406,7 @@ func (p *patcher) extractLocalLibs(imports []util.Import) []util.Import { // DoRecord records current cmd into kitex-all.sh func doRecord(recordCmd []string) string { - bytes, err := ioutil.ReadFile(getBashPath()) + bytes, err := os.ReadFile(getBashPath()) content := string(bytes) if err != nil { content = "#! /usr/bin/env bash\n" diff --git a/tool/internal_pkg/util/dump.go b/tool/internal_pkg/util/dump.go index 08ff22175f..cda5d78433 100644 --- a/tool/internal_pkg/util/dump.go +++ b/tool/internal_pkg/util/dump.go @@ -15,7 +15,7 @@ package util import ( - "io/ioutil" + "io" "os" ) @@ -25,12 +25,12 @@ import ( // This feature makes it easier to debug kitex tool. func ReadInput() ([]byte, error) { if dumpFileName := os.Getenv("KITEX_TOOL_STDIN_LOAD_FILE"); dumpFileName != "" { - return ioutil.ReadFile(dumpFileName) + return os.ReadFile(dumpFileName) } - data, err := ioutil.ReadAll(os.Stdin) + data, err := io.ReadAll(os.Stdin) if err == nil { if dumpFileName := os.Getenv("KITEX_TOOL_STDIN_DUMP_FILE"); dumpFileName != "" { - ioutil.WriteFile(dumpFileName, data, 0o644) + os.WriteFile(dumpFileName, data, 0o644) } } return data, err diff --git a/tool/internal_pkg/util/util.go b/tool/internal_pkg/util/util.go index e75255dec7..196f5b29bc 100644 --- a/tool/internal_pkg/util/util.go +++ b/tool/internal_pkg/util/util.go @@ -18,7 +18,7 @@ import ( "fmt" "go/build" "go/format" - "io/ioutil" + "io" "net/http" "os" "os/exec" @@ -125,7 +125,7 @@ func UpperFirst(s string) string { return string(rs) } -// NotPtr converts an pointer type into non-pointer type. +// NotPtr converts a pointer type into non-pointer type. func NotPtr(s string) string { return strings.ReplaceAll(s, "*", "") } @@ -135,7 +135,7 @@ func NotPtr(s string) string { func SearchGoMod(cwd string) (moduleName, path string, found bool) { for { path = filepath.Join(cwd, "go.mod") - data, err := ioutil.ReadFile(path) + data, err := os.ReadFile(path) if err == nil { re := regexp.MustCompile(`^\s*module\s+(\S+)\s*`) for _, line := range strings.Split(string(data), "\n") { @@ -261,11 +261,11 @@ func DownloadFile(remotePath, localPath string) error { return fmt.Errorf("failed to download file, http status: %s", resp.Status) } - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return err } - err = ioutil.WriteFile(localPath, body, 0o644) + err = os.WriteFile(localPath, body, 0o644) if err != nil { return err } From 6609f3b2ae18d5d09c6f39371ba30d9dd1712b50 Mon Sep 17 00:00:00 2001 From: "qiheng.zhou" Date: Thu, 15 Aug 2024 09:57:25 +0800 Subject: [PATCH 41/70] feat: add GetCallee to kitexutil --- pkg/utils/kitexutil/kitexutil.go | 12 ++++++++++++ pkg/utils/kitexutil/kitexutil_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/pkg/utils/kitexutil/kitexutil.go b/pkg/utils/kitexutil/kitexutil.go index 92d77d7347..b8fba7e3a5 100644 --- a/pkg/utils/kitexutil/kitexutil.go +++ b/pkg/utils/kitexutil/kitexutil.go @@ -37,6 +37,18 @@ func GetCaller(ctx context.Context) (string, bool) { return ri.From().ServiceName(), true } +// GetCallee is used to get the Service Name of the callee. +// Return false if failed to get the information. +func GetCallee(ctx context.Context) (string, bool) { + defer func() { recover() }() + + ri := rpcinfo.GetRPCInfo(ctx) + if ri == nil { + return "", false + } + return ri.To().ServiceName(), true +} + // GetMethod is used to get the current RPC Method name. // Return false if failed to get the information. func GetMethod(ctx context.Context) (string, bool) { diff --git a/pkg/utils/kitexutil/kitexutil_test.go b/pkg/utils/kitexutil/kitexutil_test.go index 6766aafcad..58b9b858d0 100644 --- a/pkg/utils/kitexutil/kitexutil_test.go +++ b/pkg/utils/kitexutil/kitexutil_test.go @@ -80,6 +80,33 @@ func TestGetCaller(t *testing.T) { } } +func TestGetCallee(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want string + want1 bool + }{ + {name: "Success", args: args{testCtx}, want: callee, want1: true}, + {name: "Failure", args: args{context.Background()}, want: "", want1: false}, + {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := GetCallee(tt.args.ctx) + if got != tt.want { + t.Errorf("GetCallee() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("GetCallee() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + func TestGetCallerAddr(t *testing.T) { type args struct { ctx context.Context From a248631a4c1a5504e9296db2c219a991df2b2a8d Mon Sep 17 00:00:00 2001 From: "qiheng.zhou" Date: Fri, 16 Aug 2024 14:45:11 +0800 Subject: [PATCH 42/70] test: ignore SA1019 --- pkg/generic/map_test/generic_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/generic/map_test/generic_test.go b/pkg/generic/map_test/generic_test.go index d361cfac9d..7b4b1cb332 100644 --- a/pkg/generic/map_test/generic_test.go +++ b/pkg/generic/map_test/generic_test.go @@ -391,6 +391,7 @@ func TestThrift2NormalServer(t *testing.T) { func TestCompatible(t *testing.T) { addr := test.GetLocalAddress() svr := initThriftServer(t, addr, new(GenericServiceWithBase64Binary), true, false) + //lint:ignore SA1019 we will remove this later svcInfo := generic.ServiceInfo(serviceinfo.Thrift) p, err := generic.NewThriftFileProvider("./idl/example.thrift") test.Assert(t, err == nil) From efb8f953e8e111c613e14988d056e78078c4faaf Mon Sep 17 00:00:00 2001 From: YangruiEmma Date: Fri, 16 Aug 2024 17:08:35 +0800 Subject: [PATCH 43/70] perf(thrift): use kitex BinaryProtocol replace apache BinaryProtocol for apache thrift codec (#1495) --- go.mod | 2 +- go.sum | 4 +- pkg/protocol/bthrift/apache/apache.go | 19 +- .../bthrift/apache/binary_protocol.go | 562 ++++++++++++++++++ pkg/remote/codec/thrift/deprecated.go | 498 +--------------- pkg/remote/codec/thrift/thrift_data.go | 4 +- pkg/utils/thrift.go | 8 +- pkg/utils/thrift_test.go | 2 +- 8 files changed, 590 insertions(+), 509 deletions(-) create mode 100644 pkg/protocol/bthrift/apache/binary_protocol.go diff --git a/go.mod b/go.mod index 96f719dad6..77a5258b4b 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cloudwego/dynamicgo v0.3.0 github.com/cloudwego/fastpb v0.0.4 github.com/cloudwego/frugal v0.2.0 - github.com/cloudwego/gopkg v0.1.1-0.20240812141034-843ef58f1234 + github.com/cloudwego/gopkg v0.1.1-0.20240816085453-9fbe8155005d github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index 6f023c2247..df310cee6e 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4m github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= -github.com/cloudwego/gopkg v0.1.1-0.20240812141034-843ef58f1234 h1:ID8y5ks8EetjB7Qqqgfd9FhPIpt+PSUsQxHeZA64+7Q= -github.com/cloudwego/gopkg v0.1.1-0.20240812141034-843ef58f1234/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= +github.com/cloudwego/gopkg v0.1.1-0.20240816085453-9fbe8155005d h1:QBV/89XA0Mwlk6LQgLIDIf1vDMWSn9O2Xx1lJX7PRGI= +github.com/cloudwego/gopkg v0.1.1-0.20240816085453-9fbe8155005d/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= diff --git a/pkg/protocol/bthrift/apache/apache.go b/pkg/protocol/bthrift/apache/apache.go index f4c26e9c77..d227f95858 100644 --- a/pkg/protocol/bthrift/apache/apache.go +++ b/pkg/protocol/bthrift/apache/apache.go @@ -18,6 +18,7 @@ package apache import ( "errors" + "io" "github.com/apache/thrift/lib/go/thrift" "github.com/cloudwego/gopkg/protocol/thrift/apache" @@ -40,20 +41,30 @@ func checkTStruct(v interface{}) error { return nil } -func callThriftRead(t apache.TTransport, v interface{}) error { +func callThriftRead(r io.ReadWriter, v interface{}) error { p, ok := v.(thrift.TStruct) if !ok { return errNotThriftTStruct } - in := thrift.NewTBinaryProtocol(t, true, true) + t, ok := r.(byteBuffer) + if ok { + in := NewBinaryProtocol(t) + return p.Read(in) + } + in := thrift.NewTBinaryProtocol(apache.NewDefaultTransport(r), true, true) return p.Read(in) } -func callThriftWrite(t apache.TTransport, v interface{}) error { +func callThriftWrite(w io.ReadWriter, v interface{}) error { p, ok := v.(thrift.TStruct) if !ok { return errNotThriftTStruct } - out := thrift.NewTBinaryProtocol(t, true, true) + t, ok := w.(byteBuffer) + if ok { + out := NewBinaryProtocol(t) + return p.Write(out) + } + out := thrift.NewTBinaryProtocol(apache.NewDefaultTransport(w), true, true) return p.Write(out) } diff --git a/pkg/protocol/bthrift/apache/binary_protocol.go b/pkg/protocol/bthrift/apache/binary_protocol.go new file mode 100644 index 0000000000..8a150c0dc4 --- /dev/null +++ b/pkg/protocol/bthrift/apache/binary_protocol.go @@ -0,0 +1,562 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import ( + "context" + "encoding/binary" + "io" + "math" + "sync" + + "github.com/cloudwego/kitex/pkg/remote/codec/perrors" +) + +/* + BinaryProtocol implementation was moved from cloudwego/kitex/pkg/remote/codec/thrift/binary_protocol.go +*/ + +var ( + _ TProtocol = (*BinaryProtocol)(nil) + + bpPool = sync.Pool{ + New: func() interface{} { + return &BinaryProtocol{} + }, + } +) + +// byteBuffer is sub interfaces of remote.ByteBuffer +// the repeated definition here is to avoid dependency on remote packages +type byteBuffer interface { + io.ReadWriter + + // WriteString is a more efficient way to write string, using the unsafe method to convert the string to []byte. + WriteString(s string) (n int, err error) + + // WriteBinary writes the []byte directly. Callers must guarantee that the []byte doesn't change. + WriteBinary(b []byte) (n int, err error) + + // Malloc n bytes sequentially in the writer buffer. + Malloc(n int) (buf []byte, err error) + + // Next reads the next n bytes sequentially and returns the original buffer. + Next(n int) (p []byte, err error) + + // ReadString is a more efficient way to read string than Next. + ReadString(n int) (s string, err error) + + // ReadBinary like ReadString. + // Returns a copy of original buffer. + ReadBinary(n int) (p []byte, err error) + + // ReadableLen returns the total length of readable buffer. + // Return: -1 means unreadable. + ReadableLen() (n int) + + // Flush writes any malloc data to the underlying io.Writer. + // The malloced buffer must be set correctly. + Flush() (err error) +} + +// BinaryProtocol was moved from cloudwego/kitex/pkg/remote/codec/thrift +// Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol +type BinaryProtocol struct { + trans byteBuffer +} + +// NewBinaryProtocol ... +// Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol +func NewBinaryProtocol(t byteBuffer) *BinaryProtocol { + bp := bpPool.Get().(*BinaryProtocol) + bp.trans = t + return bp +} + +// Recycle ... +func (p *BinaryProtocol) Recycle() { + p.trans = nil + bpPool.Put(p) +} + +/** + * Writing Methods + */ + +// WriteMessageBegin ... +func (p *BinaryProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error { + version := uint32(VERSION_1) | uint32(typeID) + e := p.WriteI32(int32(version)) + if e != nil { + return e + } + e = p.WriteString(name) + if e != nil { + return e + } + e = p.WriteI32(seqID) + return e +} + +// WriteMessageEnd ... +func (p *BinaryProtocol) WriteMessageEnd() error { + return nil +} + +// WriteStructBegin ... +func (p *BinaryProtocol) WriteStructBegin(name string) error { + return nil +} + +// WriteStructEnd ... +func (p *BinaryProtocol) WriteStructEnd() error { + return nil +} + +// WriteFieldBegin ... +func (p *BinaryProtocol) WriteFieldBegin(name string, typeID TType, id int16) error { + e := p.WriteByte(int8(typeID)) + if e != nil { + return e + } + e = p.WriteI16(id) + return e +} + +// WriteFieldEnd ... +func (p *BinaryProtocol) WriteFieldEnd() error { + return nil +} + +// WriteFieldStop ... +func (p *BinaryProtocol) WriteFieldStop() error { + e := p.WriteByte(STOP) + return e +} + +// WriteMapBegin ... +func (p *BinaryProtocol) WriteMapBegin(keyType, valueType TType, size int) error { + e := p.WriteByte(int8(keyType)) + if e != nil { + return e + } + e = p.WriteByte(int8(valueType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +// WriteMapEnd ... +func (p *BinaryProtocol) WriteMapEnd() error { + return nil +} + +// WriteListBegin ... +func (p *BinaryProtocol) WriteListBegin(elemType TType, size int) error { + e := p.WriteByte(int8(elemType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +// WriteListEnd ... +func (p *BinaryProtocol) WriteListEnd() error { + return nil +} + +// WriteSetBegin ... +func (p *BinaryProtocol) WriteSetBegin(elemType TType, size int) error { + e := p.WriteByte(int8(elemType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +// WriteSetEnd ... +func (p *BinaryProtocol) WriteSetEnd() error { + return nil +} + +// WriteBool ... +func (p *BinaryProtocol) WriteBool(value bool) error { + if value { + return p.WriteByte(1) + } + return p.WriteByte(0) +} + +// WriteByte ... +func (p *BinaryProtocol) WriteByte(value int8) error { + v, err := p.malloc(1) + if err != nil { + return err + } + v[0] = byte(value) + return err +} + +// WriteI16 ... +func (p *BinaryProtocol) WriteI16(value int16) error { + v, err := p.malloc(2) + if err != nil { + return err + } + binary.BigEndian.PutUint16(v, uint16(value)) + return err +} + +// WriteI32 ... +func (p *BinaryProtocol) WriteI32(value int32) error { + v, err := p.malloc(4) + if err != nil { + return err + } + binary.BigEndian.PutUint32(v, uint32(value)) + return err +} + +// WriteI64 ... +func (p *BinaryProtocol) WriteI64(value int64) error { + v, err := p.malloc(8) + if err != nil { + return err + } + binary.BigEndian.PutUint64(v, uint64(value)) + return err +} + +// WriteDouble ... +func (p *BinaryProtocol) WriteDouble(value float64) error { + return p.WriteI64(int64(math.Float64bits(value))) +} + +// WriteString ... +func (p *BinaryProtocol) WriteString(value string) error { + len := len(value) + e := p.WriteI32(int32(len)) + if e != nil { + return e + } + _, e = p.trans.WriteString(value) + return e +} + +// WriteBinary ... +func (p *BinaryProtocol) WriteBinary(value []byte) error { + e := p.WriteI32(int32(len(value))) + if e != nil { + return e + } + _, e = p.trans.WriteBinary(value) + return e +} + +// malloc ... +func (p *BinaryProtocol) malloc(size int) ([]byte, error) { + buf, err := p.trans.Malloc(size) + if err != nil { + return buf, perrors.NewProtocolError(err) + } + return buf, nil +} + +/** + * Reading methods + */ + +// ReadMessageBegin ... +func (p *BinaryProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) { + size, e := p.ReadI32() + if e != nil { + return "", typeID, 0, perrors.NewProtocolError(e) + } + if size > 0 { + return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Missing version in ReadMessageBegin") + } + typeID = TMessageType(size & 0x0ff) + version := int64(int64(size) & VERSION_MASK) + if version != VERSION_1 { + return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in ReadMessageBegin") + } + name, e = p.ReadString() + if e != nil { + return name, typeID, seqID, perrors.NewProtocolError(e) + } + seqID, e = p.ReadI32() + if e != nil { + return name, typeID, seqID, perrors.NewProtocolError(e) + } + return name, typeID, seqID, nil +} + +// ReadMessageEnd ... +func (p *BinaryProtocol) ReadMessageEnd() error { + return nil +} + +// ReadStructBegin ... +func (p *BinaryProtocol) ReadStructBegin() (name string, err error) { + return +} + +// ReadStructEnd ... +func (p *BinaryProtocol) ReadStructEnd() error { + return nil +} + +// ReadFieldBegin ... +func (p *BinaryProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) { + t, err := p.ReadByte() + typeID = TType(t) + if err != nil { + return name, typeID, id, err + } + if t != STOP { + id, err = p.ReadI16() + } + return name, typeID, id, err +} + +// ReadFieldEnd ... +func (p *BinaryProtocol) ReadFieldEnd() error { + return nil +} + +// ReadMapBegin ... +func (p *BinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) { + k, e := p.ReadByte() + if e != nil { + err = perrors.NewProtocolError(e) + return + } + kType = TType(k) + v, e := p.ReadByte() + if e != nil { + err = perrors.NewProtocolError(e) + return + } + vType = TType(v) + size32, e := p.ReadI32() + if e != nil { + err = perrors.NewProtocolError(e) + return + } + if size32 < 0 { + err = perrors.InvalidDataLength + return + } + size = int(size32) + return kType, vType, size, nil +} + +// ReadMapEnd ... +func (p *BinaryProtocol) ReadMapEnd() error { + return nil +} + +// ReadListBegin ... +func (p *BinaryProtocol) ReadListBegin() (elemType TType, size int, err error) { + b, e := p.ReadByte() + if e != nil { + err = perrors.NewProtocolError(e) + return + } + elemType = TType(b) + size32, e := p.ReadI32() + if e != nil { + err = perrors.NewProtocolError(e) + return + } + if size32 < 0 { + err = perrors.InvalidDataLength + return + } + size = int(size32) + + return +} + +// ReadListEnd ... +func (p *BinaryProtocol) ReadListEnd() error { + return nil +} + +// ReadSetBegin ... +func (p *BinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) { + b, e := p.ReadByte() + if e != nil { + err = perrors.NewProtocolError(e) + return + } + elemType = TType(b) + size32, e := p.ReadI32() + if e != nil { + err = perrors.NewProtocolError(e) + return + } + if size32 < 0 { + err = perrors.InvalidDataLength + return + } + size = int(size32) + return elemType, size, nil +} + +// ReadSetEnd ... +func (p *BinaryProtocol) ReadSetEnd() error { + return nil +} + +// ReadBool ... +func (p *BinaryProtocol) ReadBool() (bool, error) { + b, e := p.ReadByte() + v := true + if b != 1 { + v = false + } + return v, e +} + +// ReadByte ... +func (p *BinaryProtocol) ReadByte() (value int8, err error) { + buf, err := p.next(1) + if err != nil { + return value, err + } + return int8(buf[0]), err +} + +// ReadI16 ... +func (p *BinaryProtocol) ReadI16() (value int16, err error) { + buf, err := p.next(2) + if err != nil { + return value, err + } + value = int16(binary.BigEndian.Uint16(buf)) + return value, err +} + +// ReadI32 ... +func (p *BinaryProtocol) ReadI32() (value int32, err error) { + buf, err := p.next(4) + if err != nil { + return value, err + } + value = int32(binary.BigEndian.Uint32(buf)) + return value, err +} + +// ReadI64 ... +func (p *BinaryProtocol) ReadI64() (value int64, err error) { + buf, err := p.next(8) + if err != nil { + return value, err + } + value = int64(binary.BigEndian.Uint64(buf)) + return value, err +} + +// ReadDouble ... +func (p *BinaryProtocol) ReadDouble() (value float64, err error) { + buf, err := p.next(8) + if err != nil { + return value, err + } + value = math.Float64frombits(binary.BigEndian.Uint64(buf)) + return value, err +} + +// ReadString ... +func (p *BinaryProtocol) ReadString() (value string, err error) { + size, e := p.ReadI32() + if e != nil { + return "", e + } + if size < 0 { + err = perrors.InvalidDataLength + return + } + value, err = p.trans.ReadString(int(size)) + if err != nil { + return value, perrors.NewProtocolError(err) + } + return value, nil +} + +// ReadBinary ... +func (p *BinaryProtocol) ReadBinary() ([]byte, error) { + size, e := p.ReadI32() + if e != nil { + return nil, e + } + if size < 0 { + return nil, perrors.InvalidDataLength + } + return p.trans.ReadBinary(int(size)) +} + +// Flush ... +func (p *BinaryProtocol) Flush(ctx context.Context) (err error) { + err = p.trans.Flush() + if err != nil { + return perrors.NewProtocolError(err) + } + return nil +} + +// Skip ... +func (p *BinaryProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +// Transport ... +func (p *BinaryProtocol) Transport() TTransport { + return ttransportByteBuffer{p.trans} +} + +// ByteBuffer ... +func (p *BinaryProtocol) ByteBuffer() byteBuffer { + return p.trans +} + +// next ... +func (p *BinaryProtocol) next(size int) ([]byte, error) { + buf, err := p.trans.Next(size) + if err != nil { + return buf, perrors.NewProtocolError(err) + } + return buf, nil +} + +// ttransportByteBuffer ... +// for exposing remote.ByteBuffer via p.Transport(), +// mainly for testing purpose, see internal/mocks/athrift/utils.go +type ttransportByteBuffer struct { + byteBuffer +} + +func (ttransportByteBuffer) Close() error { panic("not implemented") } +func (ttransportByteBuffer) Flush(ctx context.Context) (err error) { panic("not implemented") } +func (ttransportByteBuffer) IsOpen() bool { panic("not implemented") } +func (ttransportByteBuffer) Open() error { panic("not implemented") } +func (p ttransportByteBuffer) RemainingBytes() uint64 { return uint64(p.ReadableLen()) } diff --git a/pkg/remote/codec/thrift/deprecated.go b/pkg/remote/codec/thrift/deprecated.go index f491c8899c..b6f0029d9f 100644 --- a/pkg/remote/codec/thrift/deprecated.go +++ b/pkg/remote/codec/thrift/deprecated.go @@ -17,14 +17,8 @@ package thrift import ( - "context" - "encoding/binary" - "math" - "sync" - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) // MessageReader read from athrift.TProtocol @@ -46,498 +40,12 @@ func UnmarshalThriftException(tProt athrift.TProtocol) error { return unmarshalThriftException(tProt.Transport()) } -var bpPool = sync.Pool{ - New: func() interface{} { - return &BinaryProtocol{} - }, -} - // BinaryProtocol ... // Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol -type BinaryProtocol struct { - trans remote.ByteBuffer -} - -var _ athrift.TProtocol = (*BinaryProtocol)(nil) +type BinaryProtocol = athrift.BinaryProtocol // NewBinaryProtocol ... // Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol -func NewBinaryProtocol(t remote.ByteBuffer) *BinaryProtocol { - bp := bpPool.Get().(*BinaryProtocol) - bp.trans = t - return bp -} - -// Recycle ... -func (p *BinaryProtocol) Recycle() { - p.trans = nil - bpPool.Put(p) -} - -/** - * Writing Methods - */ - -// WriteMessageBegin ... -func (p *BinaryProtocol) WriteMessageBegin(name string, typeID athrift.TMessageType, seqID int32) error { - version := uint32(athrift.VERSION_1) | uint32(typeID) - e := p.WriteI32(int32(version)) - if e != nil { - return e - } - e = p.WriteString(name) - if e != nil { - return e - } - e = p.WriteI32(seqID) - return e -} - -// WriteMessageEnd ... -func (p *BinaryProtocol) WriteMessageEnd() error { - return nil -} - -// WriteStructBegin ... -func (p *BinaryProtocol) WriteStructBegin(name string) error { - return nil -} - -// WriteStructEnd ... -func (p *BinaryProtocol) WriteStructEnd() error { - return nil -} - -// WriteFieldBegin ... -func (p *BinaryProtocol) WriteFieldBegin(name string, typeID athrift.TType, id int16) error { - e := p.WriteByte(int8(typeID)) - if e != nil { - return e - } - e = p.WriteI16(id) - return e -} - -// WriteFieldEnd ... -func (p *BinaryProtocol) WriteFieldEnd() error { - return nil -} - -// WriteFieldStop ... -func (p *BinaryProtocol) WriteFieldStop() error { - e := p.WriteByte(athrift.STOP) - return e -} - -// WriteMapBegin ... -func (p *BinaryProtocol) WriteMapBegin(keyType, valueType athrift.TType, size int) error { - e := p.WriteByte(int8(keyType)) - if e != nil { - return e - } - e = p.WriteByte(int8(valueType)) - if e != nil { - return e - } - e = p.WriteI32(int32(size)) - return e -} - -// WriteMapEnd ... -func (p *BinaryProtocol) WriteMapEnd() error { - return nil -} - -// WriteListBegin ... -func (p *BinaryProtocol) WriteListBegin(elemType athrift.TType, size int) error { - e := p.WriteByte(int8(elemType)) - if e != nil { - return e - } - e = p.WriteI32(int32(size)) - return e -} - -// WriteListEnd ... -func (p *BinaryProtocol) WriteListEnd() error { - return nil -} - -// WriteSetBegin ... -func (p *BinaryProtocol) WriteSetBegin(elemType athrift.TType, size int) error { - e := p.WriteByte(int8(elemType)) - if e != nil { - return e - } - e = p.WriteI32(int32(size)) - return e -} - -// WriteSetEnd ... -func (p *BinaryProtocol) WriteSetEnd() error { - return nil -} - -// WriteBool ... -func (p *BinaryProtocol) WriteBool(value bool) error { - if value { - return p.WriteByte(1) - } - return p.WriteByte(0) -} - -// WriteByte ... -func (p *BinaryProtocol) WriteByte(value int8) error { - v, err := p.malloc(1) - if err != nil { - return err - } - v[0] = byte(value) - return err -} - -// WriteI16 ... -func (p *BinaryProtocol) WriteI16(value int16) error { - v, err := p.malloc(2) - if err != nil { - return err - } - binary.BigEndian.PutUint16(v, uint16(value)) - return err -} - -// WriteI32 ... -func (p *BinaryProtocol) WriteI32(value int32) error { - v, err := p.malloc(4) - if err != nil { - return err - } - binary.BigEndian.PutUint32(v, uint32(value)) - return err -} - -// WriteI64 ... -func (p *BinaryProtocol) WriteI64(value int64) error { - v, err := p.malloc(8) - if err != nil { - return err - } - binary.BigEndian.PutUint64(v, uint64(value)) - return err -} - -// WriteDouble ... -func (p *BinaryProtocol) WriteDouble(value float64) error { - return p.WriteI64(int64(math.Float64bits(value))) -} - -// WriteString ... -func (p *BinaryProtocol) WriteString(value string) error { - len := len(value) - e := p.WriteI32(int32(len)) - if e != nil { - return e - } - _, e = p.trans.WriteString(value) - return e -} - -// WriteBinary ... -func (p *BinaryProtocol) WriteBinary(value []byte) error { - e := p.WriteI32(int32(len(value))) - if e != nil { - return e - } - _, e = p.trans.WriteBinary(value) - return e -} - -// malloc ... -func (p *BinaryProtocol) malloc(size int) ([]byte, error) { - buf, err := p.trans.Malloc(size) - if err != nil { - return buf, perrors.NewProtocolError(err) - } - return buf, nil -} - -/** - * Reading methods - */ - -// ReadMessageBegin ... -func (p *BinaryProtocol) ReadMessageBegin() (name string, typeID athrift.TMessageType, seqID int32, err error) { - size, e := p.ReadI32() - if e != nil { - return "", typeID, 0, perrors.NewProtocolError(e) - } - if size > 0 { - return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Missing version in ReadMessageBegin") - } - typeID = athrift.TMessageType(size & 0x0ff) - version := int64(int64(size) & athrift.VERSION_MASK) - if version != athrift.VERSION_1 { - return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in ReadMessageBegin") - } - name, e = p.ReadString() - if e != nil { - return name, typeID, seqID, perrors.NewProtocolError(e) - } - seqID, e = p.ReadI32() - if e != nil { - return name, typeID, seqID, perrors.NewProtocolError(e) - } - return name, typeID, seqID, nil -} - -// ReadMessageEnd ... -func (p *BinaryProtocol) ReadMessageEnd() error { - return nil -} - -// ReadStructBegin ... -func (p *BinaryProtocol) ReadStructBegin() (name string, err error) { - return -} - -// ReadStructEnd ... -func (p *BinaryProtocol) ReadStructEnd() error { - return nil -} - -// ReadFieldBegin ... -func (p *BinaryProtocol) ReadFieldBegin() (name string, typeID athrift.TType, id int16, err error) { - t, err := p.ReadByte() - typeID = athrift.TType(t) - if err != nil { - return name, typeID, id, err - } - if t != athrift.STOP { - id, err = p.ReadI16() - } - return name, typeID, id, err -} - -// ReadFieldEnd ... -func (p *BinaryProtocol) ReadFieldEnd() error { - return nil -} - -// ReadMapBegin ... -func (p *BinaryProtocol) ReadMapBegin() (kType, vType athrift.TType, size int, err error) { - k, e := p.ReadByte() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - kType = athrift.TType(k) - v, e := p.ReadByte() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - vType = athrift.TType(v) - size32, e := p.ReadI32() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - if size32 < 0 { - err = perrors.InvalidDataLength - return - } - size = int(size32) - return kType, vType, size, nil -} - -// ReadMapEnd ... -func (p *BinaryProtocol) ReadMapEnd() error { - return nil -} - -// ReadListBegin ... -func (p *BinaryProtocol) ReadListBegin() (elemType athrift.TType, size int, err error) { - b, e := p.ReadByte() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - elemType = athrift.TType(b) - size32, e := p.ReadI32() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - if size32 < 0 { - err = perrors.InvalidDataLength - return - } - size = int(size32) - - return -} - -// ReadListEnd ... -func (p *BinaryProtocol) ReadListEnd() error { - return nil -} - -// ReadSetBegin ... -func (p *BinaryProtocol) ReadSetBegin() (elemType athrift.TType, size int, err error) { - b, e := p.ReadByte() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - elemType = athrift.TType(b) - size32, e := p.ReadI32() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - if size32 < 0 { - err = perrors.InvalidDataLength - return - } - size = int(size32) - return elemType, size, nil -} - -// ReadSetEnd ... -func (p *BinaryProtocol) ReadSetEnd() error { - return nil -} - -// ReadBool ... -func (p *BinaryProtocol) ReadBool() (bool, error) { - b, e := p.ReadByte() - v := true - if b != 1 { - v = false - } - return v, e -} - -// ReadByte ... -func (p *BinaryProtocol) ReadByte() (value int8, err error) { - buf, err := p.next(1) - if err != nil { - return value, err - } - return int8(buf[0]), err -} - -// ReadI16 ... -func (p *BinaryProtocol) ReadI16() (value int16, err error) { - buf, err := p.next(2) - if err != nil { - return value, err - } - value = int16(binary.BigEndian.Uint16(buf)) - return value, err -} - -// ReadI32 ... -func (p *BinaryProtocol) ReadI32() (value int32, err error) { - buf, err := p.next(4) - if err != nil { - return value, err - } - value = int32(binary.BigEndian.Uint32(buf)) - return value, err -} - -// ReadI64 ... -func (p *BinaryProtocol) ReadI64() (value int64, err error) { - buf, err := p.next(8) - if err != nil { - return value, err - } - value = int64(binary.BigEndian.Uint64(buf)) - return value, err -} - -// ReadDouble ... -func (p *BinaryProtocol) ReadDouble() (value float64, err error) { - buf, err := p.next(8) - if err != nil { - return value, err - } - value = math.Float64frombits(binary.BigEndian.Uint64(buf)) - return value, err -} - -// ReadString ... -func (p *BinaryProtocol) ReadString() (value string, err error) { - size, e := p.ReadI32() - if e != nil { - return "", e - } - if size < 0 { - err = perrors.InvalidDataLength - return - } - value, err = p.trans.ReadString(int(size)) - if err != nil { - return value, perrors.NewProtocolError(err) - } - return value, nil -} - -// ReadBinary ... -func (p *BinaryProtocol) ReadBinary() ([]byte, error) { - size, e := p.ReadI32() - if e != nil { - return nil, e - } - if size < 0 { - return nil, perrors.InvalidDataLength - } - return p.trans.ReadBinary(int(size)) -} - -// Flush ... -func (p *BinaryProtocol) Flush(ctx context.Context) (err error) { - err = p.trans.Flush() - if err != nil { - return perrors.NewProtocolError(err) - } - return nil -} - -// Skip ... -func (p *BinaryProtocol) Skip(fieldType athrift.TType) (err error) { - return athrift.SkipDefaultDepth(p, fieldType) -} - -// ttransportByteBuffer ... -// for exposing remote.ByteBuffer via p.Transport(), -// mainly for testing purpose, see internal/mocks/athrift/utils.go -type ttransportByteBuffer struct { - remote.ByteBuffer -} - -func (ttransportByteBuffer) Close() error { panic("not implemented") } -func (ttransportByteBuffer) Flush(ctx context.Context) (err error) { panic("not implemented") } -func (ttransportByteBuffer) IsOpen() bool { panic("not implemented") } -func (ttransportByteBuffer) Open() error { panic("not implemented") } -func (p ttransportByteBuffer) RemainingBytes() uint64 { return uint64(p.ReadableLen()) } - -// Transport ... -func (p *BinaryProtocol) Transport() athrift.TTransport { - return ttransportByteBuffer{p.trans} -} - -// ByteBuffer ... -func (p *BinaryProtocol) ByteBuffer() remote.ByteBuffer { - return p.trans -} - -// next ... -func (p *BinaryProtocol) next(size int) ([]byte, error) { - buf, err := p.trans.Next(size) - if err != nil { - return buf, perrors.NewProtocolError(err) - } - return buf, nil +func NewBinaryProtocol(t remote.ByteBuffer) *athrift.BinaryProtocol { + return athrift.NewBinaryProtocol(t) } diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index e2732f7b5e..12d1b668bb 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -72,7 +72,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ // fallback to old thrift way (slow) buf := bytes.NewBuffer(make([]byte, 0, marshalThriftBufferSize)) - if err := apache.ThriftWrite(apache.NewBufferTransport(buf), data); err != nil { + if err := apache.ThriftWrite(buf, data); err != nil { return nil, err } return buf.Bytes(), nil @@ -212,7 +212,7 @@ func decodeBasicThriftData(trans remote.ByteBuffer, data interface{}) error { if err = verifyUnmarshalBasicThriftDataType(data); err != nil { return err } - if err = apache.ThriftRead(apache.NewDefaultTransport(trans), data); err != nil { + if err = apache.ThriftRead(trans, data); err != nil { return remote.NewTransError(remote.ProtocolError, err) } return nil diff --git a/pkg/utils/thrift.go b/pkg/utils/thrift.go index 806e805f8a..0ce7af5eb4 100644 --- a/pkg/utils/thrift.go +++ b/pkg/utils/thrift.go @@ -48,7 +48,7 @@ func (t *ThriftMessageCodec) Encode(method string, msgType athrift.TMessageType, _ = thrift.Binary.WriteMessageBegin(b, method, thrift.TMessageType(msgType), seqID) buf := &bytes.Buffer{} buf.Write(b) - if err := apache.ThriftWrite(apache.NewBufferTransport(buf), msg); err != nil { + if err := apache.ThriftWrite(buf, msg); err != nil { return nil, err } return buf.Bytes(), nil @@ -71,7 +71,7 @@ func (t *ThriftMessageCodec) Decode(b []byte, msg athrift.TStruct) (method strin } return } - err = apache.ThriftRead(apache.NewBufferTransport(bytes.NewBuffer(b)), msg) + err = apache.ThriftRead(bytes.NewBuffer(b), msg) return } @@ -79,7 +79,7 @@ func (t *ThriftMessageCodec) Decode(b []byte, msg athrift.TStruct) (method strin // Notice: Binary generic use Encode instead of Serialize. func (t *ThriftMessageCodec) Serialize(msg athrift.TStruct) ([]byte, error) { buf := &bytes.Buffer{} - if err := apache.ThriftWrite(apache.NewBufferTransport(buf), msg); err != nil { + if err := apache.ThriftWrite(buf, msg); err != nil { return nil, err } return buf.Bytes(), nil @@ -89,7 +89,7 @@ func (t *ThriftMessageCodec) Serialize(msg athrift.TStruct) ([]byte, error) { // Notice: Binary generic use Decode instead of Deserialize. func (t *ThriftMessageCodec) Deserialize(msg athrift.TStruct, b []byte) (err error) { buf := bytes.NewBuffer(b) - return apache.ThriftRead(apache.NewBufferTransport(buf), msg) + return apache.ThriftRead(buf, msg) } // MarshalError convert go error to thrift exception, and encode exception over buffered binary transport. diff --git a/pkg/utils/thrift_test.go b/pkg/utils/thrift_test.go index 072fb1bd61..6fea7afc0f 100644 --- a/pkg/utils/thrift_test.go +++ b/pkg/utils/thrift_test.go @@ -47,7 +47,7 @@ func TestRPCCodec(t *testing.T) { // decode method, seqID, err := rc.Decode(buf, bthrift.ToApacheCodec(&argsDecode1)) - test.Assert(t, err == nil) + test.Assert(t, err == nil, err) test.Assert(t, method == "mockMethod") test.Assert(t, seqID == 100) test.Assert(t, argsDecode1.Req.Msg == req1.Msg) From 0824d3cd61369c4598e9de7f66b3269ace2528c4 Mon Sep 17 00:00:00 2001 From: QihengZhou Date: Mon, 19 Aug 2024 10:06:34 +0800 Subject: [PATCH 44/70] feat(grpc): add GetTrailerMetadataFromCtx (#1491) --- pkg/remote/trans/nphttp2/meta_api.go | 9 ++++ pkg/remote/trans/nphttp2/meta_api_test.go | 51 +++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 pkg/remote/trans/nphttp2/meta_api_test.go diff --git a/pkg/remote/trans/nphttp2/meta_api.go b/pkg/remote/trans/nphttp2/meta_api.go index a0808b6e69..5553861d41 100644 --- a/pkg/remote/trans/nphttp2/meta_api.go +++ b/pkg/remote/trans/nphttp2/meta_api.go @@ -122,6 +122,15 @@ func GetHeaderMetadataFromCtx(ctx context.Context) *metadata.MD { return nil } +// GetTrailerMetadataFromCtx is used to get the metadata of stream Trailer from ctx. +func GetTrailerMetadataFromCtx(ctx context.Context) *metadata.MD { + trailer := ctx.Value(trailerKey{}) + if trailer != nil { + return trailer.(*metadata.MD) + } + return nil +} + // set header and trailer to the ctx by default. func receiveHeaderAndTrailer(ctx context.Context, conn net.Conn) context.Context { if md, err := conn.(hasHeader).Header(); err == nil { diff --git a/pkg/remote/trans/nphttp2/meta_api_test.go b/pkg/remote/trans/nphttp2/meta_api_test.go new file mode 100644 index 0000000000..849e438ad9 --- /dev/null +++ b/pkg/remote/trans/nphttp2/meta_api_test.go @@ -0,0 +1,51 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nphttp2 + +import ( + "context" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" +) + +func TestGetTrailerMetadataFromCtx(t *testing.T) { + // success + md := metadata.Pairs("k", "v") + ctx := GRPCTrailer(context.Background(), &md) + m := *GetTrailerMetadataFromCtx(ctx) + test.Assert(t, len(m["k"]) == 1) + test.Assert(t, m["k"][0] == "v") + + // failure + m2 := GetTrailerMetadataFromCtx(context.Background()) + test.Assert(t, m2 == nil) +} + +func TestGetHeaderMetadataFromCtx(t *testing.T) { + // success + md := metadata.Pairs("k", "v") + ctx := GRPCHeader(context.Background(), &md) + m := *GetHeaderMetadataFromCtx(ctx) + test.Assert(t, len(m["k"]) == 1) + test.Assert(t, m["k"][0] == "v") + + // failure + m2 := GetHeaderMetadataFromCtx(context.Background()) + test.Assert(t, m2 == nil) +} From 1256b7ddcaa69a904d3a8d967154423abc6db420 Mon Sep 17 00:00:00 2001 From: QihengZhou Date: Mon, 19 Aug 2024 20:50:10 +0800 Subject: [PATCH 45/70] perf: add option to enable spancache for fastpb (#1497) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 77a5258b4b..8e29d7b943 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/bytedance/sonic v1.12.1 github.com/cloudwego/configmanager v0.2.2 github.com/cloudwego/dynamicgo v0.3.0 - github.com/cloudwego/fastpb v0.0.4 + github.com/cloudwego/fastpb v0.0.5 github.com/cloudwego/frugal v0.2.0 github.com/cloudwego/gopkg v0.1.1-0.20240816085453-9fbe8155005d github.com/cloudwego/localsession v0.0.2 diff --git a/go.sum b/go.sum index df310cee6e..4b76b022fe 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,8 @@ github.com/cloudwego/configmanager v0.2.2 h1:sVrJB8gWYTlPV2OS3wcgJSO9F2/9Zbkmcm1 github.com/cloudwego/configmanager v0.2.2/go.mod h1:ppiyU+5TPLonE8qMVi/pFQk2eL3Q4P7d4hbiNJn6jwI= github.com/cloudwego/dynamicgo v0.3.0 h1:2/jOD3cMn8YVWGmVybrn74YulmhxW8d4BPyy9pja5eo= github.com/cloudwego/dynamicgo v0.3.0/go.mod h1:vPHEegW2xqjuDE8NAui+2D93RivFv18eWsyD9VRtORM= -github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4mg= -github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= +github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZU= +github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= github.com/cloudwego/gopkg v0.1.1-0.20240816085453-9fbe8155005d h1:QBV/89XA0Mwlk6LQgLIDIf1vDMWSn9O2Xx1lJX7PRGI= From c8843e4365a6bbe956934c2fd121642de083747d Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Thu, 22 Aug 2024 17:33:01 +0800 Subject: [PATCH 46/70] fix: return an unknown service/method exception to client correctly under multi_service server scenario (#1503) --- pkg/generic/json_test/generic_test.go | 13 +++++ .../json_test/idl/mock_unknown_method.thrift | 49 +++++++++++++++++++ pkg/remote/codec/util.go | 14 +++--- pkg/remote/message.go | 3 ++ 4 files changed, 72 insertions(+), 7 deletions(-) create mode 100644 pkg/generic/json_test/idl/mock_unknown_method.thrift diff --git a/pkg/generic/json_test/generic_test.go b/pkg/generic/json_test/generic_test.go index dbe1804719..a439592b91 100644 --- a/pkg/generic/json_test/generic_test.go +++ b/pkg/generic/json_test/generic_test.go @@ -556,6 +556,19 @@ func testRegression(t *testing.T) { svr.Stop() } +func TestUnknownError(t *testing.T) { + addr := test.GetLocalAddress() + svr := initMockServer(t, new(mockImpl), addr) + + cli := initThriftClient(transport.TTHeader, t, addr, "./idl/mock_unknown_method.thrift", nil, nil, false) + resp, err := cli.GenericCall(context.Background(), "UnknownMethod", reqMsg) + test.Assert(t, resp == nil) + test.Assert(t, err != nil) + test.DeepEqual(t, err.Error(), "remote or network error[remote]: unknown service , method UnknownMethod") + + svr.Stop() +} + func initThriftMockClient(t *testing.T, tp transport.Protocol, enableDynamicGo bool, address string) genericclient.Client { var p generic.DescriptorProvider var err error diff --git a/pkg/generic/json_test/idl/mock_unknown_method.thrift b/pkg/generic/json_test/idl/mock_unknown_method.thrift new file mode 100644 index 0000000000..8a160d135d --- /dev/null +++ b/pkg/generic/json_test/idl/mock_unknown_method.thrift @@ -0,0 +1,49 @@ +include "base.thrift" +include "self_ref.thrift" +include "extend.thrift" +namespace go kitex.test.server + +enum FOO { + A = 1; +} + +struct InnerBase { + 255: base.Base Base, +} + +struct ExampleReq { + 1: required string Msg, + 2: FOO Foo, + 3: InnerBase InnerBase, + 4: optional i8 I8, + 5: optional i16 I16, + 6: optional i32 I32, + 7: optional i64 I64, + 8: optional double Double, + 255: base.Base Base, +} +struct ExampleResp { + 1: required string Msg, + 2: string required_field, + 3: optional i64 num (api.js_conv="true"), + 4: optional i8 I8, + 5: optional i16 I16, + 6: optional i32 I32, + 7: optional i64 I64, + 8: optional double Double, + 255: base.BaseResp BaseResp, +} +exception Exception { + 1: i32 code + 2: string msg +} + +struct A { + 1: A self + 2: self_ref.A a +} + +service ExampleService extends extend.ExtendService { + ExampleResp ExampleMethod(1: ExampleReq req)throws(1: Exception err) + ExampleResp UnknownMethod(1: ExampleReq req) +} \ No newline at end of file diff --git a/pkg/remote/codec/util.go b/pkg/remote/codec/util.go index 78b13347f2..cee7082665 100644 --- a/pkg/remote/codec/util.go +++ b/pkg/remote/codec/util.go @@ -45,17 +45,17 @@ func SetOrCheckMethodName(methodName string, message remote.Message) error { if message.RPCRole() == remote.Client { return fmt.Errorf("wrong method name, expect=%s, actual=%s", callMethodName, methodName) } + inkSetter, ok := ink.(rpcinfo.InvocationSetter) + if !ok { + return errors.New("the interface Invocation doesn't implement InvocationSetter") + } + inkSetter.SetMethodName(methodName) svcInfo, err := message.SpecifyServiceInfo(ink.ServiceName(), methodName) if err != nil { return err } - if ink, ok := ink.(rpcinfo.InvocationSetter); ok { - ink.SetMethodName(methodName) - ink.SetPackageName(svcInfo.GetPackageName()) - ink.SetServiceName(svcInfo.ServiceName) - } else { - return errors.New("the interface Invocation doesn't implement InvocationSetter") - } + inkSetter.SetPackageName(svcInfo.GetPackageName()) + inkSetter.SetServiceName(svcInfo.ServiceName) // unknown method doesn't set methodName for RPCInfo.To(), or lead inconsistent with old version rpcinfo.AsMutableEndpointInfo(ri.To()).SetMethod(methodName) diff --git a/pkg/remote/message.go b/pkg/remote/message.go index c2401fb210..9d9e53fd74 100644 --- a/pkg/remote/message.go +++ b/pkg/remote/message.go @@ -197,6 +197,9 @@ func (m *message) SpecifyServiceInfo(svcName, methodName string) (*serviceinfo.S if svcInfo == nil { return nil, NewTransErrorWithMsg(UnknownService, fmt.Sprintf("unknown service %s, method %s", svcName, methodName)) } + if _, ok := svcInfo.Methods[methodName]; !ok { + return nil, NewTransErrorWithMsg(UnknownMethod, fmt.Sprintf("unknown method %s (service %s)", methodName, svcName)) + } m.targetSvcInfo = svcInfo return svcInfo, nil } From 1937f68a32268a96e6c97c2a8c917b00c9dcb6d4 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 23 Aug 2024 09:15:26 +0400 Subject: [PATCH 47/70] perf(grpc): zero allocation in hot path (#1504) * use `internal/utils/safemcache` for `Malloc` / `Free` * disable dynamic window * reuse `dataFrame` and `itemNode` objects --- go.mod | 2 +- go.sum | 2 + internal/utils/safemcache/safemcache.go | 83 +++++++++++++++++++ internal/utils/safemcache/safemcache_test.go | 46 ++++++++++ pkg/remote/codec/grpc/grpc.go | 15 ++-- pkg/remote/codec/grpc/grpc_compress.go | 13 ++- pkg/remote/codec/thrift/thrift_data.go | 19 +++-- pkg/remote/codec/thrift/thrift_frugal.go | 5 +- pkg/remote/trans/nphttp2/grpc/controlbuf.go | 76 +++++++++++++---- pkg/remote/trans/nphttp2/grpc/http2_client.go | 20 ++--- pkg/remote/trans/nphttp2/grpc/http2_server.go | 22 ++--- .../trans/nphttp2/grpc/transport_test.go | 16 ++-- pkg/remote/trans/nphttp2/server_conn.go | 2 +- 13 files changed, 254 insertions(+), 67 deletions(-) create mode 100644 internal/utils/safemcache/safemcache.go create mode 100644 internal/utils/safemcache/safemcache_test.go diff --git a/go.mod b/go.mod index 8e29d7b943..27d228b424 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/apache/thrift v0.13.0 - github.com/bytedance/gopkg v0.1.0 + github.com/bytedance/gopkg v0.1.1-0.20240822091137-ff3e2edbc319 github.com/bytedance/sonic v1.12.1 github.com/cloudwego/configmanager v0.2.2 github.com/cloudwego/dynamicgo v0.3.0 diff --git a/go.sum b/go.sum index 4b76b022fe..08de356b34 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQ github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/gopkg v0.1.1-0.20240822091137-ff3e2edbc319 h1:XMLnw5HdHWpmbyiIMWlC7c6GrkHoQHIMZX3vHeYMZiw= +github.com/bytedance/gopkg v0.1.1-0.20240822091137-ff3e2edbc319/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.12.1 h1:jWl5Qz1fy7X1ioY74WqO0KjAMtAGQs4sYnjiEBiyX24= github.com/bytedance/sonic v1.12.1/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= diff --git a/internal/utils/safemcache/safemcache.go b/internal/utils/safemcache/safemcache.go new file mode 100644 index 0000000000..584123ee63 --- /dev/null +++ b/internal/utils/safemcache/safemcache.go @@ -0,0 +1,83 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package safemcache wraps mcache for unsafe context. +// It's only used by GRPC for now, should be removed in the future after refactoring GRPC. +package safemcache + +import ( + "unsafe" + + "github.com/bytedance/gopkg/lang/mcache" +) + +const ( + footerLen = 8 + footerMagic = uint64(0xBADC0DEBADC0DEFF) +) + +type sliceHeader struct { + Data unsafe.Pointer + Len int + Cap int +} + +func (h *sliceHeader) Footer() uint64 { + return *(*uint64)(unsafe.Add(h.Data, h.Cap-footerLen)) +} + +func (h *sliceHeader) SetFooter(v uint64) { + *(*uint64)(unsafe.Add(h.Data, h.Cap-footerLen)) = v +} + +// Malloc warps `mcache.Malloc` for unsafe context. +// You should use `mcache.Malloc` directly if lifecycle of buf is clear +// It appends a magic number to the end of buffer and checks it when `Free`. +// Use `Cap` to get the cap of a buf created by `Malloc` +func Malloc(size int) []byte { + ret := mcache.Malloc(size + footerLen) + h := (*sliceHeader)(unsafe.Pointer(&ret)) + h.SetFooter(footerMagic) + return ret[:size] +} + +// Cap returns the max cap of a buf can be resized to. +// See comment of `Malloc` for details +func Cap(buf []byte) int { + if cap(buf) < footerLen { + return cap(buf) // not created by `Malloc`? + } + h := (*sliceHeader)(unsafe.Pointer(&buf)) + if h.Footer() == footerMagic { + return cap(buf) - footerLen + } + return cap(buf) +} + +// Free does nothing if buf is not created by `Malloc`. +// see comment of `Malloc` for details +func Free(buf []byte) { + c := cap(buf) + if c < footerLen { + return + } + h := (*sliceHeader)(unsafe.Pointer(&buf)) + if h.Footer() != footerMagic { + return + } + h.SetFooter(0) // reset footer before returning it to pool + mcache.Free(buf) +} diff --git a/internal/utils/safemcache/safemcache_test.go b/internal/utils/safemcache/safemcache_test.go new file mode 100644 index 0000000000..7f3eadbe82 --- /dev/null +++ b/internal/utils/safemcache/safemcache_test.go @@ -0,0 +1,46 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package safemcache + +import ( + "testing" + "unsafe" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestMallocFree(t *testing.T) { + // case: normal + b := Malloc(1) + test.Assert(t, len(b) == 1) + + h := (*sliceHeader)(unsafe.Pointer(&b)) + test.Assert(t, h.Footer() == footerMagic) + + Free(b) + test.Assert(t, h.Footer() == 0) + + // case: magic not match + b = Malloc(1) + test.Assert(t, len(b) == 1) + h = (*sliceHeader)(unsafe.Pointer(&b)) + test.Assert(t, h.Footer() == footerMagic) + + h.SetFooter(2) + Free(b) // it will not work + test.Assert(t, h.Footer() == 2) +} diff --git a/pkg/remote/codec/grpc/grpc.go b/pkg/remote/codec/grpc/grpc.go index 53a0e3708a..5cbc94fc2a 100644 --- a/pkg/remote/codec/grpc/grpc.go +++ b/pkg/remote/codec/grpc/grpc.go @@ -22,10 +22,10 @@ import ( "errors" "fmt" - "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/fastpb" "google.golang.org/protobuf/proto" + "github.com/cloudwego/kitex/internal/utils/safemcache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" @@ -74,7 +74,7 @@ func NewGRPCCodec(opts ...CodecOption) remote.Codec { } func mallocWithFirstByteZeroed(size int) []byte { - data := mcache.Malloc(size) + data := safemcache.Malloc(size) data[0] = 0 // compressed flag = false return data } @@ -111,7 +111,7 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size)) return writer.WriteData(payload) } - payload = mcache.Malloc(size) + payload = safemcache.Malloc(size) t.FastWrite(payload) case marshaler: size := t.Size() @@ -123,7 +123,7 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size)) return writer.WriteData(payload) } - payload = mcache.Malloc(size) + payload = safemcache.Malloc(size) if _, err = t.MarshalTo(payload); err != nil { return err } @@ -156,9 +156,9 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo if err != nil { return err } - var header [dataFrameHeaderLen]byte + header := safemcache.Malloc(dataFrameHeaderLen) if isCompressed { - payload, err = compress(compressor, payload) + payload, err = compress(compressor, payload) // compress will `Free` payload if err != nil { return err } @@ -167,12 +167,11 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo header[0] = 0 } binary.BigEndian.PutUint32(header[1:dataFrameHeaderLen], uint32(len(payload))) - err = writer.WriteHeader(header[:]) + err = writer.WriteHeader(header) if err != nil { return err } return writer.WriteData(payload) - // TODO: recycle payload? } func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) { diff --git a/pkg/remote/codec/grpc/grpc_compress.go b/pkg/remote/codec/grpc/grpc_compress.go index a8fe32a0a9..d6603ca1a5 100644 --- a/pkg/remote/codec/grpc/grpc_compress.go +++ b/pkg/remote/codec/grpc/grpc_compress.go @@ -23,13 +23,10 @@ import ( "errors" "io" - "github.com/bytedance/gopkg/lang/mcache" - - "github.com/cloudwego/kitex/pkg/rpcinfo" - - "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" - + "github.com/cloudwego/kitex/internal/utils/safemcache" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" + "github.com/cloudwego/kitex/pkg/rpcinfo" ) func getSendCompressor(ctx context.Context) (encoding.Compressor, error) { @@ -64,7 +61,9 @@ func decodeGRPCFrame(ctx context.Context, in remote.ByteBuffer) ([]byte, error) func compress(compressor encoding.Compressor, data []byte) ([]byte, error) { if len(data) != 0 { - defer mcache.Free(data) + // data is NOT always created by `safemcache.Malloc` or `mcache.Malloc` + // use `safemcache.Free` for safe ... + defer safemcache.Free(data) } cbuf := &bytes.Buffer{} z, err := compressor.Compress(cbuf) diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index 12d1b668bb..4ff0ca9d9c 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -22,10 +22,10 @@ import ( "fmt" "io" - "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/apache" + "github.com/cloudwego/kitex/internal/utils/safemcache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) @@ -33,7 +33,17 @@ import ( const marshalThriftBufferSize = 1024 // MarshalThriftData only encodes the data (without the prepending methodName, msgType, seqId) -// It will allocate a new buffer and encode to it +// NOTE: +// it's used by grpc only, +// coz kitex grpc doesn't implements remote.Message and remote.ByteBuffer for rpc. +// +// for `FastWrite` or `FrugalWrite`, +// the buf is created by `github.com/bytedance/gopkg/lang/mcache`, `Free` it at your own risk. +// +// for internals, actually, +// coz it's hard to control the lifecycle of a returned buf, we use a safe version of `mcache` which is +// compatible with `mcache` to make sure `Free` would not have any side effects. +// see `github.com/cloudwego/kitex/internal/utils/safemcache` for details. func MarshalThriftData(ctx context.Context, codec remote.PayloadCodec, data interface{}) ([]byte, error) { c, ok := codec.(*thriftCodec) if !ok { @@ -42,8 +52,7 @@ func MarshalThriftData(ctx context.Context, codec remote.PayloadCodec, data inte return c.marshalThriftData(ctx, data) } -// marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId) -// It will allocate a new buffer and encode to it +// NOTE: only used by `MarshalThriftData` func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([]byte, error) { // TODO(xiaost): Refactor the code after v0.11.0 is released. Unifying checking and fallback logic. @@ -55,7 +64,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ if c.IsSet(FastWrite) { if msg, ok := data.(thrift.FastCodec); ok { payloadSize := msg.BLength() - payload := mcache.Malloc(payloadSize) + payload := safemcache.Malloc(payloadSize) msg.FastWriteNocopy(payload, nil) return payload, nil } diff --git a/pkg/remote/codec/thrift/thrift_frugal.go b/pkg/remote/codec/thrift/thrift_frugal.go index aeae5c19ad..173095a28f 100644 --- a/pkg/remote/codec/thrift/thrift_frugal.go +++ b/pkg/remote/codec/thrift/thrift_frugal.go @@ -20,10 +20,10 @@ import ( "fmt" "reflect" - "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/frugal" "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/internal/utils/safemcache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) @@ -76,9 +76,10 @@ func (c thriftCodec) hyperMarshal(out remote.ByteBuffer, methodName string, msgT return nil } +// NOTE: only used by `marshalThriftData` func (c thriftCodec) hyperMarshalBody(data interface{}) (buf []byte, err error) { objectLen := frugal.EncodedSize(data) - buf = mcache.Malloc(objectLen) + buf = safemcache.Malloc(objectLen) _, err = frugal.EncodeObject(buf, nil, data) return buf, err } diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf.go b/pkg/remote/trans/nphttp2/grpc/controlbuf.go index 4fd0672aa0..7dd4701124 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf.go @@ -27,10 +27,10 @@ import ( "sync" "sync/atomic" - "github.com/bytedance/gopkg/lang/mcache" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "github.com/cloudwego/kitex/internal/utils/safemcache" "github.com/cloudwego/kitex/pkg/klog" ) @@ -43,13 +43,26 @@ type itemNode struct { next *itemNode } +const maxFreeItemNodes = 100 + type itemList struct { head *itemNode tail *itemNode + + free *itemNode + nfree int } func (il *itemList) enqueue(i interface{}) { - n := &itemNode{it: i} + var n *itemNode + if il.free != nil { + // pop the 1st free item + n, il.free = il.free, il.free.next + *n = itemNode{it: i} + il.nfree-- + } else { + n = &itemNode{it: i} + } if il.tail == nil { il.head, il.tail = n, n return @@ -68,11 +81,18 @@ func (il *itemList) dequeue() interface{} { if il.head == nil { return nil } - i := il.head.it + unused := il.head + i := unused.it il.head = il.head.next if il.head == nil { il.tail = nil } + + if il.nfree < maxFreeItemNodes { + // add to head of free list + il.free, unused.next = unused, il.free + il.nfree++ + } return i } @@ -140,12 +160,39 @@ type dataFrame struct { // you can assign the header to h and the payload to the d; // or just assign the header + payload together to the d. // In other words, h = nil means d = header + payload. - h []byte - d []byte - dcache []byte // dcache is the origin d created by mcache, this ptr is only used for kitex - // onEachWrite is called every time - // a part of d is written out. - onEachWrite func() + h []byte + d []byte + + // the header and data are most likely from pkg/remote/codec/grpc which created by `safemcache`. + // we keep original []byte for mcache recycling coz h and d will move forward when writes. + // make sure only use `safemcache.Free` to recycle the buffers, + // coz it's NOT always created by `safemcache.Malloc` or `mcache.Malloc`. + originH []byte + originD []byte + + // resetPingStrikes is stored 1 every time a part of d is written out. + // not holding setResetPingStrikes() for performance concern, + // coz it will cause one closure allocation for each dataFrame. + // it replaces the original impl of `onEachWrite` which calls `setResetPingStrikes` of `http2Server` + resetPingStrikes *uint32 +} + +var poolDataFrame = sync.Pool{ + New: func() interface{} { + return &dataFrame{} + }, +} + +func newDataFrame() *dataFrame { + p := poolDataFrame.Get().(*dataFrame) + *p = dataFrame{} // reset all fields + return p +} + +func (p *dataFrame) Release() { + safemcache.Free(p.originH) + safemcache.Free(p.originD) + poolDataFrame.Put(p) } func (*dataFrame) isTransportResponseFrame() bool { return false } @@ -848,7 +895,10 @@ func (l *loopyWriter) processData() (bool, error) { if err := l.framer.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { return false, err } + str.itl.dequeue() // remove the empty data item from stream + dataItem.Release() + if str.itl.isEmpty() { str.state = empty } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. @@ -902,8 +952,8 @@ func (l *loopyWriter) processData() (bool, error) { if dataItem.endStream && len(dataItem.h)+len(dataItem.d) <= size { endStream = true } - if dataItem.onEachWrite != nil { - dataItem.onEachWrite() + if dataItem.resetPingStrikes != nil { + atomic.StoreUint32(dataItem.resetPingStrikes, 1) } if err := l.framer.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil { return false, err @@ -914,10 +964,8 @@ func (l *loopyWriter) processData() (bool, error) { dataItem.d = dataItem.d[dSize:] if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out. - if len(dataItem.dcache) > 0 { - mcache.Free(dataItem.dcache) - } str.itl.dequeue() + dataItem.Release() } if str.itl.isEmpty() { str.state = empty diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 9f141e0ac0..109a5538dc 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -191,7 +191,9 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, t.initialWindowSize = opts.InitialWindowSize dynamicWindow = false } - if dynamicWindow { + if false && dynamicWindow { + // we force disable dynamic window here coz it's sending too many ping frames... + // and it may not work as expected when running on top of netpoll. t.bdpEst = &bdpEstimator{ bdp: initialWindowSize, updateFlowControl: t.updateFlowControl, @@ -656,15 +658,13 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { } else if s.getState() != streamActive { return errStreamDone } - df := &dataFrame{ - streamID: s.id, - endStream: opts.Last, - h: hdr, - d: data, - } - if len(hdr) == 0 && len(data) != 0 { - df.dcache = data - } + df := newDataFrame() + df.streamID = s.id + df.endStream = opts.Last + df.h = hdr + df.d = data + df.originH = df.h + df.originD = df.d if hdr != nil || data != nil { // If it's not an empty data frame, check quota. if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { return err diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index 07cddf24d2..515c47ef9b 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -219,7 +219,9 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ bufferPool: newBufferPool(), } t.controlBuf = newControlBuffer(t.done) - if dynamicWindow { + if false && dynamicWindow { + // we force disable dynamic window here coz it's sending too many ping frames... + // and it may not work as expected when running on top of netpoll. t.bdpEst = &bdpEstimator{ bdp: initialWindowSize, updateFlowControl: t.updateFlowControl, @@ -716,6 +718,8 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } func (t *http2Server) setResetPingStrikes() { + // NOTE: if you're going to change this func + // update `resetPingStrikes` logic of `dataFrame` as well atomic.StoreUint32(&t.resetPingStrikes, 1) } @@ -831,15 +835,13 @@ func (t *http2Server) Write(s *Stream, hdr, data []byte, opts *Options) error { return ContextErr(s.ctx.Err()) } } - df := &dataFrame{ - streamID: s.id, - h: hdr, - d: data, - onEachWrite: t.setResetPingStrikes, - } - if len(hdr) == 0 && len(data) != 0 { - df.dcache = data - } + df := newDataFrame() + df.streamID = s.id + df.h = hdr + df.d = data + df.originH = df.h + df.originD = df.d + df.resetPingStrikes = &t.resetPingStrikes if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { select { case <-t.done: diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 19adc3b24f..7225018ad3 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -179,10 +179,9 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { } } conn.controlBuf.put(&dataFrame{ - streamID: s.id, - h: nil, - d: p, - onEachWrite: func() {}, + streamID: s.id, + h: nil, + d: p, }) sent += len(p) } @@ -973,11 +972,10 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { t.Fatalf("Failed to open stream: %v", err) } ct.controlBuf.put(&dataFrame{ - streamID: s.id, - endStream: false, - h: nil, - d: make([]byte, http2MaxFrameLen), - onEachWrite: func() {}, + streamID: s.id, + endStream: false, + h: nil, + d: make([]byte, http2MaxFrameLen), }) // Loop until the server side stream is created. var ss *Stream diff --git a/pkg/remote/trans/nphttp2/server_conn.go b/pkg/remote/trans/nphttp2/server_conn.go index f124ec1d30..271c3358c7 100644 --- a/pkg/remote/trans/nphttp2/server_conn.go +++ b/pkg/remote/trans/nphttp2/server_conn.go @@ -97,7 +97,7 @@ func (c *serverConn) Write(b []byte) (n int, err error) { func (c *serverConn) WriteFrame(hdr, data []byte) (n int, err error) { // server sets the END_STREAM flag in trailer when writeStatus - err = c.tr.Write(c.s, hdr, data, &grpc.Options{}) + err = c.tr.Write(c.s, hdr, data, nil) return len(hdr) + len(data), err } From be1a732d38026c9c1bffbd417801b51e31a3c52c Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Fri, 23 Aug 2024 15:12:25 +0800 Subject: [PATCH 48/70] chore(generic): add an external method to create service info for generic streaming client (#1465) --- client/genericclient/generic_stream_service.go | 4 ++++ client/genericclient/stream.go | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/client/genericclient/generic_stream_service.go b/client/genericclient/generic_stream_service.go index a3b995a305..7074207f2d 100644 --- a/client/genericclient/generic_stream_service.go +++ b/client/genericclient/generic_stream_service.go @@ -21,6 +21,10 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) +func StreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo { + return newClientStreamingServiceInfo(g) +} + func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo { readerWriter := g.MessageReaderWriter() if readerWriter == nil { diff --git a/client/genericclient/stream.go b/client/genericclient/stream.go index 05d260f2a6..c568e0b281 100644 --- a/client/genericclient/stream.go +++ b/client/genericclient/stream.go @@ -47,8 +47,7 @@ type BidirectionalStreaming interface { } func NewStreamingClient(destService string, g generic.Generic, opts ...client.Option) (Client, error) { - svcInfo := newClientStreamingServiceInfo(g) - return NewStreamingClientWithServiceInfo(destService, g, svcInfo, opts...) + return NewStreamingClientWithServiceInfo(destService, g, StreamingServiceInfo(g), opts...) } func NewStreamingClientWithServiceInfo(destService string, g generic.Generic, svcInfo *serviceinfo.ServiceInfo, opts ...client.Option) (Client, error) { From 2da97233b6fcbb3a623a259bbefcb455b4799b85 Mon Sep 17 00:00:00 2001 From: Guangming Luo Date: Mon, 26 Aug 2024 14:10:57 +0800 Subject: [PATCH 49/70] chore: update bytedance/gopkg to upgrade go/x/net for security (#1508) --- README.md | 2 +- README_cn.md | 2 +- go.mod | 10 +++++----- go.sum | 22 ++++++++++------------ 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index c0b2b63876..4225a6a2da 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ English | [中文](README_cn.md) ![Stars](https://img.shields.io/github/stars/cloudwego/kitex) ![Forks](https://img.shields.io/github/forks/cloudwego/kitex) -Kitex [kaɪt'eks] is a **high-performance** and **strong-extensibility** Golang RPC framework that helps developers build microservices. If the performance and extensibility are the main concerns when you develop microservices, Kitex can be a good choice. +Kitex [kaɪt'eks] is a **high-performance** and **strong-extensibility** Go RPC framework that helps developers build microservices. If the performance and extensibility are the main concerns when you develop microservices, Kitex can be a good choice. ## Basic Features diff --git a/README_cn.md b/README_cn.md index 055681e386..3a2fa09c6f 100644 --- a/README_cn.md +++ b/README_cn.md @@ -11,7 +11,7 @@ ![Stars](https://img.shields.io/github/stars/cloudwego/kitex) ![Forks](https://img.shields.io/github/forks/cloudwego/kitex) -Kitex[kaɪt'eks] 字节跳动内部的 Golang 微服务 RPC 框架,具有**高性能**、**强可扩展**的特点,在字节内部已广泛使用。如今越来越多的微服务选择使用 Golang,如果对微服务性能有要求,又希望定制扩展融入自己的治理体系,Kitex 会是一个不错的选择。 +Kitex[kaɪt'eks] 字节跳动内部的 Go 微服务 RPC 框架,具有**高性能**、**强可扩展**的特点,在字节内部已广泛使用。如今越来越多的微服务选择使用 Go,如果对微服务性能有要求,又希望定制扩展融入自己的治理体系,Kitex 会是一个不错的选择。 ## 框架特点 diff --git a/go.mod b/go.mod index 27d228b424..fcfc4bf970 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/apache/thrift v0.13.0 - github.com/bytedance/gopkg v0.1.1-0.20240822091137-ff3e2edbc319 + github.com/bytedance/gopkg v0.1.1 github.com/bytedance/sonic v1.12.1 github.com/cloudwego/configmanager v0.2.2 github.com/cloudwego/dynamicgo v0.3.0 @@ -20,9 +20,9 @@ require ( github.com/jhump/protoreflect v1.8.2 github.com/json-iterator/go v1.1.12 github.com/tidwall/gjson v1.17.3 - golang.org/x/net v0.17.0 - golang.org/x/sync v0.1.0 - golang.org/x/sys v0.13.0 + golang.org/x/net v0.24.0 + golang.org/x/sync v0.8.0 + golang.org/x/sys v0.19.0 golang.org/x/tools v0.6.0 google.golang.org/genproto v0.0.0-20210513213006-bf773b8c8384 google.golang.org/protobuf v1.28.1 @@ -48,5 +48,5 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.2.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/text v0.14.0 // indirect ) diff --git a/go.sum b/go.sum index 08de356b34..3ab2eee836 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,8 @@ github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= -github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= -github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= -github.com/bytedance/gopkg v0.1.1-0.20240822091137-ff3e2edbc319 h1:XMLnw5HdHWpmbyiIMWlC7c6GrkHoQHIMZX3vHeYMZiw= -github.com/bytedance/gopkg v0.1.1-0.20240822091137-ff3e2edbc319/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/gopkg v0.1.1 h1:3azzgSkiaw79u24a+w9arfH8OfnQQ4MHUt9lJFREEaE= +github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.12.1 h1:jWl5Qz1fy7X1ioY74WqO0KjAMtAGQs4sYnjiEBiyX24= github.com/bytedance/sonic v1.12.1/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= @@ -164,8 +162,8 @@ golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLd golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -175,8 +173,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -193,8 +191,8 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -202,8 +200,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= From b129365d3dc876390d1328cc3857d9fb44858579 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:19:53 +0800 Subject: [PATCH 50/70] fix(generic): judge business error directly (#1501) --- internal/generic/thrift/binary.go | 35 +++++++++++++++++++++++++++++++ pkg/generic/binarythrift_codec.go | 29 +++++++++++-------------- 2 files changed, 47 insertions(+), 17 deletions(-) create mode 100644 internal/generic/thrift/binary.go diff --git a/internal/generic/thrift/binary.go b/internal/generic/thrift/binary.go new file mode 100644 index 0000000000..d4aa2e3b2b --- /dev/null +++ b/internal/generic/thrift/binary.go @@ -0,0 +1,35 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "context" + "io" + + "github.com/cloudwego/gopkg/protocol/thrift/base" +) + +// WriteBinary implement of MessageWriter +type WriteBinary struct{} + +func NewWriteBinary() *WriteBinary { + return &WriteBinary{} +} + +func (w *WriteBinary) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { + return nil +} diff --git a/pkg/generic/binarythrift_codec.go b/pkg/generic/binarythrift_codec.go index c6aedd5862..eeca8304d3 100644 --- a/pkg/generic/binarythrift_codec.go +++ b/pkg/generic/binarythrift_codec.go @@ -21,15 +21,17 @@ import ( "encoding/binary" "fmt" - gthrift "github.com/cloudwego/gopkg/protocol/thrift" - + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" "github.com/cloudwego/kitex/pkg/serviceinfo" ) -var _ remote.PayloadCodec = &binaryThriftCodec{} +var ( + _ remote.PayloadCodec = &binaryThriftCodec{} + wb = thrift.NewWriteBinary() +) type binaryReqType = []byte @@ -51,22 +53,15 @@ func (c *binaryThriftCodec) Marshal(ctx context.Context, msg remote.Message, out var transBuff []byte var ok bool if msg.RPCRole() == remote.Server { + // Business error only works properly when using TTHeader and HTTP2 transmission protocols + // If there is a business error, data.(*Result).Success will be nil, and an empty payload will be constructed here to return + if msg.RPCInfo().Invocation().BizStatusErr() != nil { + msg.Data().(WithCodec).SetCodec(wb) + return thriftCodec.Marshal(ctx, msg, out) + } gResult := data.(*Result) transBinary := gResult.Success - // handle biz error - if transBinary == nil { - sz := gthrift.Binary.MessageBeginLength(msg.RPCInfo().Invocation().MethodName()) - sz += gthrift.Binary.FieldStopLength() - b, err := out.Malloc(sz) - if err != nil { - return perrors.NewProtocolError(fmt.Errorf("binary thrift generic marshal, remote.ByteBuffer Malloc err: %w", err)) - } - b = gthrift.Binary.AppendMessageBegin(b[:0], - msg.RPCInfo().Invocation().MethodName(), gthrift.TMessageType(msg.MessageType()), msg.RPCInfo().Invocation().SeqID()) - b = gthrift.Binary.AppendFieldStop(b) - _ = b - return nil - } else if transBuff, ok = transBinary.(binaryReqType); !ok { + if transBuff, ok = transBinary.(binaryReqType); !ok { return perrors.NewProtocolErrorWithMsg("invalid marshal result in rawThriftBinaryCodec: must be []byte") } } else { From 13d0c3e795ebd995cfa7e214fa639176d77f7bfc Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Mon, 26 Aug 2024 19:50:04 +0800 Subject: [PATCH 51/70] feat(tool): support generating multiple handlers for multiple services (#1425) --- tool/cmd/kitex/args/args.go | 1 + tool/internal_pkg/generator/generator.go | 128 +++++++++++++----- tool/internal_pkg/generator/generator_test.go | 2 +- tool/internal_pkg/generator/type.go | 5 + tool/internal_pkg/pluginmode/protoc/plugin.go | 17 ++- .../pluginmode/thriftgo/convertor.go | 4 + .../pluginmode/thriftgo/plugin.go | 14 +- tool/internal_pkg/tpl/main.go | 32 +++++ 8 files changed, 165 insertions(+), 38 deletions(-) diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index 4e81845c05..f7f627225f 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -128,6 +128,7 @@ func (a *Arguments) buildFlags(version string) *flag.FlagSet { "Skip dependency checking.") f.BoolVar(&a.Rapid, "rapid", false, "Use embedded thriftgo.") + f.Var(&a.BuiltinTpl, "tpl", "Specify kitex built-in template.") a.RecordCmd = os.Args a.Version = version diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 7563bbba43..053e4e12a6 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -47,6 +47,9 @@ const ( ExtensionFilename = "extensions.yaml" DefaultThriftPluginTimeLimit = time.Minute + + // built in tpls + MultipleServicesTpl = "multiple_services" ) var ( @@ -141,6 +144,7 @@ type Config struct { NoDependencyCheck bool Rapid bool + BuiltinTpl util.StringSlice // specify the built-in template to use } // Pack packs the Config into a slice of "key=val" strings. @@ -289,6 +293,15 @@ func (c *Config) ApplyExtension() error { return nil } +func (c *Config) IsUsingMultipleServicesTpl() bool { + for _, part := range c.BuiltinTpl { + if part == MultipleServicesTpl { + return true + } + } + return false +} + // NewGenerator . func NewGenerator(config *Config, middlewares []Middleware) Generator { mws := append(globalMiddlewares, middlewares...) @@ -338,11 +351,20 @@ func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error }, } if !g.Config.GenerateInvoker { - tasks = append(tasks, &Task{ - Name: MainFileName, - Path: util.JoinPath(g.OutputPath, MainFileName), - Text: tpl.MainTpl, - }) + if !g.Config.IsUsingMultipleServicesTpl() { + tasks = append(tasks, &Task{ + Name: MainFileName, + Path: util.JoinPath(g.OutputPath, MainFileName), + Text: tpl.MainTpl, + }) + } else { + // using multiple services main.go template + tasks = append(tasks, &Task{ + Name: MainFileName, + Path: util.JoinPath(g.OutputPath, MainFileName), + Text: tpl.MainMultipleServicesTpl, + }) + } } for _, t := range tasks { if util.Exists(t.Path) { @@ -360,37 +382,62 @@ func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error fs = append(fs, f) } - handlerFilePath := filepath.Join(g.OutputPath, HandlerFileName) - if util.Exists(handlerFilePath) { - comp := newCompleter( - pkg.ServiceInfo.AllMethods(), - handlerFilePath, - pkg.ServiceInfo.ServiceName) - f, err := comp.CompleteMethods() + if !g.Config.IsUsingMultipleServicesTpl() { + f, err := g.generateHandler(pkg, pkg.ServiceInfo, HandlerFileName) if err != nil { - if err == errNoNewMethod { - return fs, nil - } return nil, err } - fs = append(fs, f) - } else { - task := Task{ - Name: HandlerFileName, - Path: handlerFilePath, - Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl, + // when there is no new method, f would be nil + if f != nil { + fs = append(fs, f) } - g.setImports(task.Name, pkg) - handle := func(task *Task, pkg *PackageInfo) (*File, error) { - return task.Render(pkg) + } else { + for _, svc := range pkg.Services { + // set the target service + pkg.ServiceInfo = svc + handlerFileName := "handler_" + svc.ServiceName + ".go" + f, err := g.generateHandler(pkg, svc, handlerFileName) + if err != nil { + return nil, err + } + // when there is no new method, f would be nil + if f != nil { + fs = append(fs, f) + } } - f, err := g.chainMWs(handle)(&task, pkg) - if err != nil { + } + return +} + +// generateHandler generates the handler file based on the pkg and the target service +func (g *generator) generateHandler(pkg *PackageInfo, svc *ServiceInfo, handlerFileName string) (*File, error) { + handlerFilePath := filepath.Join(g.OutputPath, handlerFileName) + if util.Exists(handlerFilePath) { + comp := newCompleter( + svc.AllMethods(), + handlerFilePath, + svc.ServiceName) + f, err := comp.CompleteMethods() + if err != nil && err != errNoNewMethod { return nil, err } - fs = append(fs, f) + return f, nil } - return + + task := Task{ + Name: HandlerFileName, + Path: handlerFilePath, + Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl, + } + g.setImports(task.Name, pkg) + handle := func(task *Task, pkg *PackageInfo) (*File, error) { + return task.Render(pkg) + } + f, err := g.chainMWs(handle)(&task, pkg) + if err != nil { + return nil, err + } + return f, nil } func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) { @@ -416,12 +463,6 @@ func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) { Text: tpl.ServerTpl, Ext: ext.ExtendServer, }, - { - Name: InvokerFileName, - Path: util.JoinPath(output, InvokerFileName), - Text: tpl.InvokerTpl, - Ext: ext.ExtendInvoker, - }, { Name: ServiceFileName, Path: util.JoinPath(output, svcPkg+".go"), @@ -429,6 +470,16 @@ func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) { }, } + // do not generate invoker.go in service package by default + if g.Config.GenerateInvoker { + tasks = append(tasks, &Task{ + Name: InvokerFileName, + Path: util.JoinPath(output, InvokerFileName), + Text: tpl.InvokerTpl, + Ext: ext.ExtendInvoker, + }) + } + var fs []*File for _, t := range tasks { if err := t.Build(); err != nil { @@ -563,7 +614,14 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { } case MainFileName: pkg.AddImport("log", "log") - pkg.AddImport(pkg.PkgRefName, util.JoinPath(pkg.ImportPath, strings.ToLower(pkg.ServiceName))) + if !g.Config.IsUsingMultipleServicesTpl() { + pkg.AddImport(pkg.PkgRefName, util.JoinPath(pkg.ImportPath, strings.ToLower(pkg.ServiceName))) + } else { + pkg.AddImport("server", "github.com/cloudwego/kitex/server") + for _, svc := range pkg.Services { + pkg.AddImport(svc.RefName, util.JoinPath(pkg.ImportPath, strings.ToLower(svc.ServiceName))) + } + } } } diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index 43f9b1ac42..4c208583a4 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -69,7 +69,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false", "BuiltinTpl="}, }, } for _, tt := range tests { diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index 367ac997e4..444c277410 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -37,6 +37,7 @@ type PackageInfo struct { Namespace string // a dot-separated string for generating service package under kitex_gen Dependencies map[string]string // package name => import path, used for searching imports *ServiceInfo // the target service + Services []*ServiceInfo // all services defined in a IDL for multiple services scenario // the following fields will be filled and used by the generator Codec string @@ -106,6 +107,10 @@ type ServiceInfo struct { Protocol string HandlerReturnKeepResp bool UseThriftReflection bool + // for multiple services scenario, the reference name for the service + RefName string + // identify whether this service would generate a corresponding handler. + GenerateHandler bool } // AllMethods returns all methods that the service have. diff --git a/tool/internal_pkg/pluginmode/protoc/plugin.go b/tool/internal_pkg/pluginmode/protoc/plugin.go index a5bbff5c28..a7fb50fc3f 100644 --- a/tool/internal_pkg/pluginmode/protoc/plugin.go +++ b/tool/internal_pkg/pluginmode/protoc/plugin.go @@ -217,7 +217,19 @@ func (pp *protocPlugin) process(gen *protogen.Plugin) { gen.Error(errors.New("no service defined")) return } - pp.ServiceInfo = pp.Services[len(pp.Services)-1] + if !pp.IsUsingMultipleServicesTpl() { + // if -tpl multiple_services is not set, specify the last service as the target service + pp.ServiceInfo = pp.Services[len(pp.Services)-1] + } else { + var svcs []*generator.ServiceInfo + for _, svc := range pp.Services { + if svc.GenerateHandler { + svc.RefName = "service" + svc.ServiceName + svcs = append(svcs, svc) + } + } + pp.PackageInfo.Services = svcs + } fs, err := pp.kg.GenerateMainPackage(&pp.PackageInfo) if err != nil { pp.err = err @@ -291,6 +303,9 @@ func (pp *protocPlugin) convertTypes(file *protogen.File) (ss []*generator.Servi for _, m := range si.Methods { BuildStreaming(m, si.HasStreaming) } + if file.Generate { + si.GenerateHandler = true + } ss = append(ss, si) } // combine service diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor.go b/tool/internal_pkg/pluginmode/thriftgo/convertor.go index 7b51e67f83..e7d6e9e3ea 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/convertor.go +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor.go @@ -324,6 +324,9 @@ func (c *converter) convertTypes(req *plugin.Request) error { return fmt.Errorf("%s: makeService '%s': %w", ast.Filename, svc.Name, err) } si.ServiceFilePath = ast.Filename + if ast == req.AST { + si.GenerateHandler = true + } all[ast.Filename] = append(all[ast.Filename], si) c.svc2ast[si] = ast } @@ -379,6 +382,7 @@ func (c *converter) convertTypes(req *plugin.Request) error { Methods: methods, ServiceFilePath: ast.Filename, HasStreaming: hasStreaming, + GenerateHandler: true, } if c.IsHessian2() { diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index a21c342e65..13afd349bc 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -87,7 +87,19 @@ func HandleRequest(req *plugin.Request) *plugin.Response { if len(conv.Services) == 0 { return conv.failResp(errors.New("no service defined in the IDL")) } - conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] + if !conv.Config.IsUsingMultipleServicesTpl() { + // if -tpl multiple_services is not set, specify the last service as the target service + conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] + } else { + var svcs []*generator.ServiceInfo + for _, svc := range conv.Services { + if svc.GenerateHandler { + svc.RefName = "service" + svc.ServiceName + svcs = append(svcs, svc) + } + } + conv.Package.Services = svcs + } fs, err := gen.GenerateMainPackage(&conv.Package) if err != nil { return conv.failResp(err) diff --git a/tool/internal_pkg/tpl/main.go b/tool/internal_pkg/tpl/main.go index b288a102e5..eb471df6b7 100644 --- a/tool/internal_pkg/tpl/main.go +++ b/tool/internal_pkg/tpl/main.go @@ -39,3 +39,35 @@ func main() { } } ` + +var MainMultipleServicesTpl string = `package main + +import ( + {{- range $path, $aliases := .Imports}} + {{- if not $aliases}} + "{{$path}}" + {{- else}} + {{- range $alias, $is := $aliases}} + {{$alias}} "{{$path}}" + {{- end}} + {{- end}} + {{- end}} +) + +func main() { + svr := server.NewServer() + + {{- range $idx, $svc := .Services}} + if err := {{$svc.RefName}}.RegisterService(svr, new({{$svc.ServiceName}}Impl)); err != nil { + panic(err) + } + {{- end}} + + err := svr.Run() + + if err != nil { + log.Println(err.Error()) + } +} + +` From 237fae01442b0551e79c6471dd313043c1ba7fc2 Mon Sep 17 00:00:00 2001 From: YangruiEmma Date: Mon, 26 Aug 2024 21:25:45 +0800 Subject: [PATCH 52/70] perf(thrift): encodeBasicThrift write logic didn't use kitex BinaryProtocol (#1511) --- pkg/remote/codec/thrift/thrift.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index e7a9fed13f..ce175342f9 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -211,7 +211,7 @@ func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string } _ = thrift.Binary.WriteMessageBegin(b, method, thrift.TMessageType(msgType), seqID) - if err := apache.ThriftWrite(apache.NewDefaultTransport(out), data); err != nil { + if err := apache.ThriftWrite(out, data); err != nil { return err } return nil From d967b72baf12801e792d8e201ba0e17135befbb5 Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Tue, 27 Aug 2024 20:48:44 +0800 Subject: [PATCH 53/70] feat(tool): support updating import path for PkgInfo (#1513) --- tool/internal_pkg/generator/generator.go | 2 +- tool/internal_pkg/generator/type.go | 56 ++++++++++++++++++++++-- tool/internal_pkg/generator/type_test.go | 50 +++++++++++++++++++++ 3 files changed, 103 insertions(+), 5 deletions(-) diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 053e4e12a6..8291ba123a 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -617,7 +617,7 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { if !g.Config.IsUsingMultipleServicesTpl() { pkg.AddImport(pkg.PkgRefName, util.JoinPath(pkg.ImportPath, strings.ToLower(pkg.ServiceName))) } else { - pkg.AddImport("server", "github.com/cloudwego/kitex/server") + pkg.AddImports("server") for _, svc := range pkg.Services { pkg.AddImport(svc.RefName, util.JoinPath(pkg.ImportPath, strings.ToLower(svc.ServiceName))) } diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index 444c277410..bdbb92ebf3 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -60,10 +60,7 @@ func (p *PackageInfo) AddImport(pkg, path string) { p.Imports = make(map[string]map[string]bool) } if pkg != "" { - if p.ExternalKitexGen != "" && strings.Contains(path, KitexGenPath) { - parts := strings.Split(path, KitexGenPath) - path = util.JoinPath(p.ExternalKitexGen, parts[len(parts)-1]) - } + path = p.toExternalGenPath(path) if path == pkg { p.Imports[path] = nil } else { @@ -86,6 +83,57 @@ func (p *PackageInfo) AddImports(pkgs ...string) { } } +// UpdateImportPath changed the mapping between alias -> import path +// For instance: +// +// Original import: alias "original_path" +// Invocation: UpdateImport("alias", "new_path") +// New import: alias "new_path" +// +// if pkg == newPath, then alias would be removed in import sentence: +// +// Original import: context "path/to/custom/context" +// Invocation: UpdateImport("context", "context") +// New import: context +func (p *PackageInfo) UpdateImportPath(pkg, newPath string) { + if p.Imports == nil || pkg == "" || newPath == "" { + return + } + + newPath = p.toExternalGenPath(newPath) + var prevPath string +OutLoop: + for path, pkgSet := range p.Imports { + for pkgKey := range pkgSet { + if pkgKey == pkg { + prevPath = path + break OutLoop + } + } + } + if prevPath == "" { + return + } + + delete(p.Imports, prevPath) + if newPath == pkg { // remove the alias + p.Imports[newPath] = nil + } else { // change the path -> alias mapping + p.Imports[newPath] = map[string]bool{ + pkg: true, + } + } +} + +func (p *PackageInfo) toExternalGenPath(path string) string { + if p.ExternalKitexGen == "" || !strings.Contains(path, KitexGenPath) { + return path + } + parts := strings.Split(path, KitexGenPath) + newPath := util.JoinPath(p.ExternalKitexGen, parts[len(parts)-1]) + return newPath +} + // PkgInfo . type PkgInfo struct { PkgName string diff --git a/tool/internal_pkg/generator/type_test.go b/tool/internal_pkg/generator/type_test.go index 16c4e3d290..1070403ef5 100644 --- a/tool/internal_pkg/generator/type_test.go +++ b/tool/internal_pkg/generator/type_test.go @@ -108,3 +108,53 @@ func TestServiceInfo_FixHasStreamingForExtendedService(t *testing.T) { test.Assert(t, s.Base.HasStreaming) }) } + +func TestPkgInfo_UpdateImportPath(t *testing.T) { + testcases := []struct { + desc string + imports map[string]map[string]bool + pkg string + newPath string + expect func(t *testing.T, pkgInfo *PackageInfo, pkg, newPath string) + }{ + { + desc: "update path for same alias", + imports: map[string]map[string]bool{ + "path/to/server": { + "server": true, + }, + }, + pkg: "server", + newPath: "new/path/to/server", + expect: func(t *testing.T, pkgInfo *PackageInfo, pkg, newPath string) { + pkgSet := pkgInfo.Imports[newPath] + test.Assert(t, pkgSet != nil) + test.Assert(t, pkgSet[pkg]) + }, + }, + { + desc: "remove alias", + imports: map[string]map[string]bool{ + "path/to/context": { + "context": true, + }, + }, + pkg: "context", + newPath: "context", + expect: func(t *testing.T, pkgInfo *PackageInfo, pkg, newPath string) { + pkgSet := pkgInfo.Imports[newPath] + test.Assert(t, pkgSet == nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + pkgInfo := &PackageInfo{ + Imports: tc.imports, + } + pkgInfo.UpdateImportPath(tc.pkg, tc.newPath) + tc.expect(t, pkgInfo, tc.pkg, tc.newPath) + }) + } +} From 4aa56335a676e7ee4f910b78ca9e6598733f8f18 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Wed, 28 Aug 2024 08:34:02 +0400 Subject: [PATCH 54/70] feat(grpc): server returns cancel reason (#1514) --- pkg/remote/trans/nphttp2/grpc/context.go | 54 +++++++++++++++ pkg/remote/trans/nphttp2/grpc/context_test.go | 43 ++++++++++++ pkg/remote/trans/nphttp2/grpc/http2_server.go | 69 ++++++++++++------- pkg/remote/trans/nphttp2/grpc/transport.go | 18 +++-- .../trans/nphttp2/grpc/transport_test.go | 4 +- 5 files changed, 155 insertions(+), 33 deletions(-) create mode 100644 pkg/remote/trans/nphttp2/grpc/context.go create mode 100644 pkg/remote/trans/nphttp2/grpc/context_test.go diff --git a/pkg/remote/trans/nphttp2/grpc/context.go b/pkg/remote/trans/nphttp2/grpc/context.go new file mode 100644 index 0000000000..e35cdb653a --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/context.go @@ -0,0 +1,54 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "context" + "sync/atomic" +) + +// contextWithCancelReason implements context.Context +// with a cancel func for passing cancel reason +// NOTE: use context.WithCancelCause when go1.20? +type contextWithCancelReason struct { + context.Context + + cancel context.CancelFunc + reason atomic.Value +} + +func (c *contextWithCancelReason) Err() error { + err := c.reason.Load() + if err != nil { + return err.(error) + } + return c.Context.Err() +} + +func (c *contextWithCancelReason) CancelWithReason(reason error) { + if reason != nil { + c.reason.CompareAndSwap(nil, reason) + } + c.cancel() +} + +type cancelWithReason func(reason error) + +func newContextWithCancelReason(ctx context.Context, cancel context.CancelFunc) (context.Context, cancelWithReason) { + ret := &contextWithCancelReason{Context: ctx, cancel: cancel} + return ret, ret.CancelWithReason +} diff --git a/pkg/remote/trans/nphttp2/grpc/context_test.go b/pkg/remote/trans/nphttp2/grpc/context_test.go new file mode 100644 index 0000000000..d731c781fc --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/context_test.go @@ -0,0 +1,43 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestContextWithCancelReason(t *testing.T) { + ctx0, cancel0 := context.WithCancel(context.Background()) + ctx, cancel := newContextWithCancelReason(ctx0, cancel0) + + // cancel contextWithCancelReason + expectErr := errors.New("testing") + cancel(expectErr) + test.Assert(t, ctx0.Err() == context.Canceled) + test.Assert(t, ctx.Err() == expectErr) + + // cancel underlying context + ctx0, cancel0 = context.WithCancel(context.Background()) + ctx, _ = newContextWithCancelReason(ctx0, cancel0) + cancel0() + test.Assert(t, ctx0.Err() == context.Canceled) + test.Assert(t, ctx.Err() == context.Canceled) +} diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index 515c47ef9b..25af3c4ffc 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -43,6 +43,7 @@ import ( "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/grpcframe" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" @@ -53,9 +54,19 @@ var ( // ErrIllegalHeaderWrite indicates that setting header is illegal because of // the stream's state. ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + // ErrHeaderListSizeLimitViolation indicates that the header list size is larger // than the limit set by peer. ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + + // errors used for cancelling stream. + // the code should be codes.Canceled coz it's NOT returned from remote + errConnectionEOF = status.New(codes.Canceled, "transport: connection EOF").Err() + errStreamClosing = status.New(codes.Canceled, "transport: stream is closing").Err() + errMaxStreamsExceeded = status.New(codes.Canceled, "transport: max streams exceeded").Err() + errNotReachable = status.New(codes.Canceled, "transport: server not reachable").Err() + errMaxAgeClosing = status.New(codes.Canceled, "transport: closing server transport due to maximum connection age").Err() + errIdleClosing = status.New(codes.Canceled, "transport: closing server transport due to idleness").Err() ) func init() { @@ -232,7 +243,7 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ defer func() { if err != nil { - t.Close() + t.closeWithErr(err) } }() @@ -281,8 +292,9 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ return t, nil } -// operateHeader takes action on the decoded headers. -func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) { +// operateHeaders takes action on the decoded headers. Returns an error if fatal +// error encountered and transport needs to close, otherwise returns nil. +func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) error { streamID := frame.Header().StreamID state := &decodeState{ serverSide: true, @@ -296,7 +308,7 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f onWrite: func() {}, }) } - return false + return nil } buf := newRecvBuffer() @@ -314,11 +326,13 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f // s is just created by the caller. No lock needed. s.state = streamReadDone } + var cancel context.CancelFunc if state.data.timeoutSet { - s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout) + s.ctx, cancel = context.WithTimeout(t.ctx, state.data.timeout) } else { - s.ctx, s.cancel = context.WithCancel(t.ctx) + s.ctx, cancel = context.WithCancel(t.ctx) } + s.ctx, s.cancel = newContextWithCancelReason(s.ctx, cancel) // Attach the received metadata to the context. if len(state.data.mdata) > 0 { s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) @@ -327,8 +341,8 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f t.mu.Lock() if t.state != reachable { t.mu.Unlock() - s.cancel() - return false + s.cancel(errNotReachable) + return nil } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() @@ -338,15 +352,12 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f rstCode: http2.ErrCodeRefusedStream, onWrite: func() {}, }) - s.cancel() - return false + s.cancel(errMaxStreamsExceeded) + return nil } if streamID%2 != 1 || streamID <= t.maxStreamID { t.mu.Unlock() - // illegal gRPC stream id. - klog.CtxErrorf(s.ctx, "transport: http2Server.HandleStreams received an illegal stream id: %v", streamID) - s.cancel() - return true + return fmt.Errorf("received an illegal stream id: %v. headers frame: %+v", streamID, frame) } t.maxStreamID = streamID t.activeStreams[streamID] = s @@ -377,7 +388,7 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f wq: s.wq, }) handle(s) - return false + return nil } // HandleStreams receives incoming streams using the given handler. This is @@ -396,6 +407,9 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. s := t.activeStreams[se.StreamID] t.mu.Unlock() if s != nil { + // it will be codes.Internal error for GRPC + // TODO: map http2.StreamError to status.Error? + s.cancel(err) t.closeStream(s, true, se.Code, false) } else { t.controlBuf.put(&cleanupStream{ @@ -408,17 +422,18 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. continue } if err == io.EOF || err == io.ErrUnexpectedEOF || errors.Is(err, netpoll.ErrEOF) { - t.Close() + t.closeWithErr(errConnectionEOF) return } klog.CtxWarnf(t.ctx, "transport: http2Server.HandleStreams failed to read frame: %v", err) - t.Close() + t.closeWithErr(err) return } switch frame := frame.(type) { case *grpcframe.MetaHeadersFrame: - if t.operateHeaders(frame, handle, traceCtx) { - t.Close() + if err := t.operateHeaders(frame, handle, traceCtx); err != nil { + klog.CtxErrorf(t.ctx, "transport: http2Server.HandleStreams fatal err: %v", err) + t.closeWithErr(err) break } case *grpcframe.DataFrame: @@ -826,7 +841,7 @@ func (t *http2Server) Write(s *Stream, hdr, data []byte, opts *Options) error { // Writing headers checks for this condition. if s.getState() == streamDone { // TODO(mmukhi, dfawley): Should the server write also return io.EOF? - s.cancel() + s.cancel(errStreamClosing) select { case <-t.done: return ErrConnClosing @@ -908,7 +923,7 @@ func (t *http2Server) keepalive() { case <-ageTimer.C: // Close the connection after grace period. klog.Infof("transport: closing server transport due to maximum connection age.") - t.Close() + t.closeWithErr(errMaxAgeClosing) case <-t.done: } return @@ -925,7 +940,7 @@ func (t *http2Server) keepalive() { } if outstandingPing && kpTimeoutLeft <= 0 { klog.Infof("transport: closing server transport due to idleness.") - t.Close() + t.closeWithErr(errIdleClosing) return } if !outstandingPing { @@ -950,6 +965,10 @@ func (t *http2Server) keepalive() { // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. func (t *http2Server) Close() error { + return t.closeWithErr(nil) +} + +func (t *http2Server) closeWithErr(reason error) error { t.mu.Lock() if t.state == closing { t.mu.Unlock() @@ -962,10 +981,12 @@ func (t *http2Server) Close() error { t.controlBuf.finish() close(t.done) err := t.conn.Close() + // Cancel all active streams. for _, s := range streams { - s.cancel() + s.cancel(reason) } + return err } @@ -974,7 +995,7 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) { // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be // called to interrupt the potential blocking on other goroutines. - s.cancel() + s.cancel(nil) // more details about the reason? t.mu.Lock() if _, ok := t.activeStreams[s.id]; ok { diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index 786afb8454..6afdeb7be7 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -233,13 +233,13 @@ const ( // Stream represents an RPC in the transport layer. type Stream struct { id uint32 - st ServerTransport // nil for client side Stream - ct *http2Client // nil for server side Stream - ctx context.Context // the associated context of the stream - cancel context.CancelFunc // always nil for client side Stream - done chan struct{} // closed at the end of stream to unblock writers. On the client side. - ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) - method string // the associated RPC method of the stream + st ServerTransport // nil for client side Stream + ct *http2Client // nil for server side Stream + ctx context.Context // the associated context of the stream + cancel cancelWithReason // always nil for client side Stream + done chan struct{} // closed at the end of stream to unblock writers. On the client side. + ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) + method string // the associated RPC method of the stream recvCompress string sendCompress string buf *recvBuffer @@ -798,6 +798,10 @@ func ContextErr(err error) error { case context.Canceled: return status.New(codes.Canceled, err.Error()).Err() } + statusErr, ok := err.(*status.Error) + if ok { // only returned by contextWithCancelReason + return statusErr + } return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) } diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 7225018ad3..9e80c169ab 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -993,8 +993,8 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { ct.Close() select { case <-ss.Context().Done(): - if ss.Context().Err() != context.Canceled { - t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled) + if ss.Context().Err() != errConnectionEOF { + t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), errConnectionEOF) } case <-time.After(3 * time.Second): t.Fatalf("%s", "Failed to cancel the context of the sever side stream.") From b870445fa0e116e9e59a1eaa54ae1d8cad1bad95 Mon Sep 17 00:00:00 2001 From: QihengZhou Date: Wed, 28 Aug 2024 14:20:52 +0800 Subject: [PATCH 55/70] fix(gRPC): pass error when client transport is closed (#1515) --- pkg/remote/trans/nphttp2/grpc/http2_client.go | 54 ++++++++++--------- .../trans/nphttp2/grpc/keepalive_test.go | 28 +++++----- pkg/remote/trans/nphttp2/grpc/transport.go | 2 +- .../trans/nphttp2/grpc/transport_test.go | 39 +++++++------- 4 files changed, 63 insertions(+), 60 deletions(-) diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 109a5538dc..7af2571fdf 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -221,12 +221,14 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, // Send connection preface to server. n, err := t.conn.Write(ClientPreface) if err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + err = connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + t.Close(err) + return nil, err } if n != ClientPrefaceLen { - t.Close() - return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, ClientPrefaceLen) + err = connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, ClientPrefaceLen) + t.Close(err) + return nil, err } ss := []http2.Setting{ @@ -237,15 +239,17 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, } err = t.framer.WriteSettings(ss...) if err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + err = connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + t.Close(err) + return nil, err } // Adjust the connection flow control window if needed. if delta := uint32(icwz - defaultWindowSize); delta > 0 { if err := t.framer.WriteWindowUpdate(0, delta); err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err) + err = connectionErrorf(true, err, "transport: failed to write window update: %v", err) + t.Close(err) + return nil, err } } @@ -594,7 +598,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This method blocks until the addrConn that initiated this transport is // re-connected. This happens because t.onClose() begins reconnect logic at the // addrConn level and blocks until the addrConn is successfully connected. -func (t *http2Client) Close() error { +func (t *http2Client) Close(err error) error { t.mu.Lock() // Make sure we only Close once. if t.state == closing { @@ -617,12 +621,13 @@ func (t *http2Client) Close() error { t.mu.Unlock() t.controlBuf.finish() t.cancel() - err := t.conn.Close() + cErr := t.conn.Close() + // Notify all active streams. for _, s := range streams { - t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, status.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) + t.closeStream(s, err, false, http2.ErrCodeNo, status.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) } - return err + return cErr } // GracefulClose sets the state to draining, which prevents new streams from @@ -641,7 +646,7 @@ func (t *http2Client) GracefulClose() { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "no active streams left to process while draining")) return } t.controlBuf.put(&incomingGoAway{}) @@ -886,7 +891,7 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { id := f.LastStreamID if id > 0 && id%2 != 1 { t.mu.Unlock() - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id)) return } // A client can receive multiple GoAways from the server (see @@ -904,7 +909,7 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)) return } default: @@ -936,7 +941,7 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) } } @@ -1031,10 +1036,8 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.ReadFrame() if err != nil { - // TODO(emma): comment this log temporarily, because when use short connection, 'resource temporarily unavailable' error will happen - // if the log need to be output, connection info should be appended - // klog.Errorf("KITEX: grpc readFrame failed, error=%s", err.Error()) - t.Close() // this kicks off resetTransport, so must be last before return + err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) + t.Close(err) // this kicks off resetTransport, so must be last before return return } t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!) @@ -1043,7 +1046,8 @@ func (t *http2Client) reader() { } sf, ok := frame.(*grpcframe.SettingsFrame) if !ok { - t.Close() // this kicks off resetTransport, so must be last before return + err = connectionErrorf(true, err, "first frame received is not a setting frame") + t.Close(err) // this kicks off resetTransport, so must be last before return return } t.handleSettings(sf, true) @@ -1076,10 +1080,8 @@ func (t *http2Client) reader() { continue } else { // Transport error. - // TODO(emma): comment this log temporarily, because when use short connection, 'resource temporarily unavailable' error will happen - // if the log need to be output, connection info should be appended - // klog.Errorf("KITEX: grpc readFrame failed, error=%s", err.Error()) - t.Close() + err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) + t.Close(err) return } } @@ -1137,7 +1139,7 @@ func (t *http2Client) keepalive() { continue } if outstandingPing && timeoutLeft <= 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")) return } t.mu.Lock() diff --git a/pkg/remote/trans/nphttp2/grpc/keepalive_test.go b/pkg/remote/trans/nphttp2/grpc/keepalive_test.go index 965cc0c462..74ae89c0a0 100644 --- a/pkg/remote/trans/nphttp2/grpc/keepalive_test.go +++ b/pkg/remote/trans/nphttp2/grpc/keepalive_test.go @@ -47,7 +47,7 @@ func TestMaxConnectionIdle(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -85,7 +85,7 @@ func TestMaxConnectionIdleBusyClient(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -162,7 +162,7 @@ func TestKeepaliveServerClosesUnresponsiveClient(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -226,7 +226,7 @@ func TestKeepaliveServerWithResponsiveClient(t *testing.T) { }, }) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -258,7 +258,7 @@ func TestKeepaliveClientClosesUnresponsiveServer(t *testing.T) { if client == nil { t.Fatalf("setUpWithNoPingServer failed, return nil client") } - defer client.Close() + defer client.Close(errSelfCloseForTest) conn, ok := <-connCh if !ok { @@ -293,7 +293,7 @@ func TestKeepaliveClientOpenWithUnresponsiveServer(t *testing.T) { if client == nil { t.Fatalf("setUpWithNoPingServer failed, return nil client") } - defer client.Close() + defer client.Close(errSelfCloseForTest) conn, ok := <-connCh if !ok { @@ -326,7 +326,7 @@ func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { if client == nil { t.Fatalf("setUpWithNoPingServer failed, return nil client") } - defer client.Close() + defer client.Close(errSelfCloseForTest) conn, ok := <-connCh if !ok { @@ -362,7 +362,7 @@ func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { }, }) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -400,7 +400,7 @@ func TestKeepaliveClientFrequency(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -444,7 +444,7 @@ func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -487,7 +487,7 @@ func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -536,7 +536,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -569,7 +569,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() @@ -608,7 +608,7 @@ func TestKeepaliveServerEnforcementWithDormantKeepaliveOnClient(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(errSelfCloseForTest) server.stop() }() diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index 6afdeb7be7..f6e0706227 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -646,7 +646,7 @@ type ClientTransport interface { // Close tears down this transport. Once it returns, the transport // should not be accessed any more. The caller must make sure this // is called only once. - Close() error + Close(err error) error // GracefulClose starts to tear down the transport: the transport will stop // accepting new RPCs and NewStream will return error. Once all streams are diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 9e80c169ab..ec98259a52 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -65,6 +65,7 @@ var ( expectedRequestLarge = make([]byte, initialWindowSize*2) expectedResponseLarge = make([]byte, initialWindowSize*2) expectedInvalidHeaderField = "invalid/content-type" + errSelfCloseForTest = errors.New("self-close in test") ) func init() { @@ -515,7 +516,7 @@ func TestInflightStreamClosing(t *testing.T) { serverConfig := &ServerConfig{} server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("self-close in test")) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -584,7 +585,7 @@ func TestClientSendAndReceive(t *testing.T) { if recvErr != io.EOF { t.Fatalf("Error: %v; want ", recvErr) } - ct.Close() + ct.Close(errSelfCloseForTest) server.stop() } @@ -597,7 +598,7 @@ func TestClientErrorNotify(t *testing.T) { }() // ct.reader should detect the error and activate ct.Error(). <-ct.Error() - ct.Close() + ct.Close(nil) } func performOneRPC(ct ClientTransport) { @@ -633,7 +634,7 @@ func TestClientMix(t *testing.T) { }(s) go func(ct ClientTransport) { <-ct.Error() - ct.Close() + ct.Close(errSelfCloseForTest) }(ct) for i := 0; i < 1000; i++ { time.Sleep(1 * time.Millisecond) @@ -671,7 +672,7 @@ func TestLargeMessage(t *testing.T) { }() } wg.Wait() - ct.Close() + ct.Close(errSelfCloseForTest) server.stop() } @@ -687,7 +688,7 @@ func TestLargeMessageWithDelayRead(t *testing.T) { } server, ct := setUpWithOptions(t, 0, sc, delayRead, co) defer server.stop() - defer ct.Close() + defer ct.Close(errSelfCloseForTest) server.mu.Lock() ready := server.ready server.mu.Unlock() @@ -864,7 +865,7 @@ func TestLargeMessageSuspension(t *testing.T) { if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) } - ct.Close() + ct.Close(errSelfCloseForTest) server.stop() } @@ -873,7 +874,7 @@ func TestMaxStreams(t *testing.T) { MaxStreams: 1, } server, ct := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) - defer ct.Close() + defer ct.Close(errSelfCloseForTest) defer server.stop() callHdr := &CallHdr{ Host: "localhost", @@ -933,7 +934,7 @@ func TestMaxStreams(t *testing.T) { // Close the first stream created so that the new stream can finally be created. ct.CloseStream(s, nil) <-done - ct.Close() + ct.Close(errSelfCloseForTest) <-ct.writerDone if ct.maxConcurrentStreams != 1 { t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams) @@ -990,7 +991,7 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { sc.mu.Unlock() break } - ct.Close() + ct.Close(errSelfCloseForTest) select { case <-ss.Context().Done(): if ss.Context().Err() != errConnectionEOF { @@ -1010,7 +1011,7 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) { } server, client := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) defer server.stop() - defer client.Close() + defer client.Close(errSelfCloseForTest) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() @@ -1098,7 +1099,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { } server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer server.stop() - defer client.Close() + defer client.Close(errSelfCloseForTest) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1454,7 +1455,7 @@ func TestEncodingRequiredStatus(t *testing.T) { if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) { t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) } - ct.Close() + ct.Close(errSelfCloseForTest) server.stop() } @@ -1475,14 +1476,14 @@ func TestInvalidHeaderField(t *testing.T) { if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { t.Fatalf("Read got error %v, want error with code %v and contains %q", err, codes.Internal, expectedInvalidHeaderField) } - ct.Close() + ct.Close(errSelfCloseForTest) server.stop() } func TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, invalidHeaderField) defer server.stop() - defer ct.Close() + defer ct.Close(errSelfCloseForTest) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) @@ -1588,7 +1589,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) } server, client := setUpWithOptions(t, 0, sc, pingpong, co) defer server.stop() - defer client.Close() + defer client.Close(errSelfCloseForTest) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1670,7 +1671,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) } // Close down both server and client so that their internals can be read without data // races. - client.Close() + client.Close(errSelfCloseForTest) st.Close() <-st.readerDone <-st.writerDone @@ -1869,7 +1870,7 @@ func TestPingPong1MB(t *testing.T) { func runPingPongTest(t *testing.T, msgSize int) { server, client := setUp(t, 0, 0, pingpong) defer server.stop() - defer client.Close() + defer client.Close(errSelfCloseForTest) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1956,7 +1957,7 @@ func TestHeaderTblSize(t *testing.T) { }() server, ct := setUp(t, 0, math.MaxUint32, normal) - defer ct.Close() + defer ct.Close(errSelfCloseForTest) defer server.stop() ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() From ebd3dfe0ce568238794841dbff53b6c4344de46e Mon Sep 17 00:00:00 2001 From: QihengZhou Date: Wed, 28 Aug 2024 15:53:14 +0800 Subject: [PATCH 56/70] feat: customized payload validator (#1478) --- pkg/consts/ctx.go | 1 + pkg/kerrors/kerrors.go | 2 + pkg/remote/codec/default_codec.go | 110 ++++++++++++- pkg/remote/codec/default_codec_test.go | 173 ++++++++++++++++++++ pkg/remote/codec/header_codec_test.go | 28 +++- pkg/remote/codec/validate.go | 217 +++++++++++++++++++++++++ pkg/remote/codec/validate_test.go | 176 ++++++++++++++++++++ pkg/remote/trans/netpoll/bytebuf.go | 8 + pkg/remote/transmeta/metakey.go | 2 + pkg/stats/event.go | 28 ++-- 10 files changed, 722 insertions(+), 23 deletions(-) create mode 100644 pkg/remote/codec/validate.go create mode 100644 pkg/remote/codec/validate_test.go diff --git a/pkg/consts/ctx.go b/pkg/consts/ctx.go index 30c9c4ee7d..53250b87f2 100644 --- a/pkg/consts/ctx.go +++ b/pkg/consts/ctx.go @@ -19,6 +19,7 @@ package consts // Method key used in context. const ( CtxKeyMethod = "K_METHOD" + CtxKeyLogID = "K_LOGID" ) const ( diff --git a/pkg/kerrors/kerrors.go b/pkg/kerrors/kerrors.go index 30f352b8d9..783658f27e 100644 --- a/pkg/kerrors/kerrors.go +++ b/pkg/kerrors/kerrors.go @@ -46,6 +46,8 @@ var ( ErrRPCFinish = &basicError{"rpc call finished"} // ErrRoute happens when router fail to route this call ErrRoute = &basicError{"rpc route failed"} + // ErrPayloadValidation happens when payload validation failed + ErrPayloadValidation = &basicError{"payload validation error"} ) // More detailed error types diff --git a/pkg/remote/codec/default_codec.go b/pkg/remote/codec/default_codec.go index a4d29c4523..2dde97837e 100644 --- a/pkg/remote/codec/default_codec.go +++ b/pkg/remote/codec/default_codec.go @@ -22,9 +22,12 @@ import ( "fmt" "sync/atomic" + "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" + netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/retry" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -59,7 +62,7 @@ var ( func NewDefaultCodec() remote.Codec { // No size limit by default return &defaultCodec{ - maxSize: 0, + CodecConfig{MaxSize: 0}, } } @@ -67,13 +70,37 @@ func NewDefaultCodec() remote.Codec { // maxSize is in bytes func NewDefaultCodecWithSizeLimit(maxSize int) remote.Codec { return &defaultCodec{ - maxSize: maxSize, + CodecConfig{MaxSize: maxSize}, } } -type defaultCodec struct { +// NewDefaultCodecWithConfig creates the default protocol sniffing codec supporting thrift and protobuf with the input config. +func NewDefaultCodecWithConfig(cfg CodecConfig) remote.Codec { + if cfg.CRC32Check { + // TODO: crc32 has higher priority now. + cfg.PayloadValidator = NewCRC32PayloadValidator() + } + return &defaultCodec{cfg} +} + +// CodecConfig is the config of defaultCodec +type CodecConfig struct { // maxSize limits the max size of the payload - maxSize int + MaxSize int + + // If crc32Check is true, the codec will validate the payload using crc32c. + // Only effective when transport is TTHeader. + // Payload is all the data after TTHeader. + CRC32Check bool + + // PayloadValidator is used to validate payload with customized checksum logic. + // It prepares a value based on payload in sender-side and validates the value in receiver-side. + // It can only be used when ttheader is enabled. + PayloadValidator PayloadValidator +} + +type defaultCodec struct { + CodecConfig } // EncodePayload encode payload @@ -111,6 +138,7 @@ func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message return perrors.NewProtocolErrorWithMsg("no buffer allocated for the framed length field") } payloadLen = out.MallocLen() - headerLen + // FIXME: if the `out` buffer using copy to grow when the capacity is not enough, setting the pre-allocated `framedLenField` may not take effect. binary.BigEndian.PutUint32(framedLenField, uint32(payloadLen)) } else if message.ProtocolInfo().CodecType == serviceinfo.Protobuf { return perrors.NewProtocolErrorWithMsg("protobuf just support 'framed' trans proto") @@ -118,16 +146,19 @@ func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message if tp&transport.TTHeader == transport.TTHeader { payloadLen = out.MallocLen() - Size32 } - err = checkPayloadSize(payloadLen, c.maxSize) + err = checkPayloadSize(payloadLen, c.MaxSize) return err } // EncodeMetaAndPayload encode meta and payload func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error { - var err error - var totalLenField []byte tp := message.ProtocolInfo().TransProto + if c.PayloadValidator != nil && tp&transport.TTHeader == transport.TTHeader { + return c.encodeMetaAndPayloadWithPayloadValidator(ctx, message, out, me) + } + var err error + var totalLenField []byte // 1. encode header and return totalLenField if needed // totalLenField will be filled after payload encoded if tp&transport.TTHeader == transport.TTHeader { @@ -144,6 +175,7 @@ func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote. if totalLenField == nil { return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field") } + // FIXME: if the `out` buffer using copy to grow when the capacity is not enough, setting the pre-allocated `totalLenField` may not take effect. payloadLen := out.MallocLen() - Size32 binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen)) } @@ -176,6 +208,11 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i if flagBuf, err = in.Peek(2 * Size32); err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error())) } + if c.PayloadValidator != nil { + if pErr := payloadChecksumValidate(ctx, c.PayloadValidator, in, message); pErr != nil { + return pErr + } + } } else if isMeshHeader(flagBuf) { message.Tags()[remote.MeshHeader] = true // MeshHeader @@ -186,7 +223,7 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("meshHeader read payload first 8 byte failed: %s", err.Error())) } } - return checkPayload(flagBuf, message, in, isTTHeader, c.maxSize) + return checkPayload(flagBuf, message, in, isTTHeader, c.MaxSize) } // DecodePayload decode payload @@ -229,6 +266,55 @@ func (c *defaultCodec) Name() string { return "default" } +// encodeMetaAndPayloadWithPayloadValidator encodes payload and meta with checksum of the payload. +func (c *defaultCodec) encodeMetaAndPayloadWithPayloadValidator(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) (err error) { + writer := netpoll.NewLinkBuffer() + payloadOut := netpolltrans.NewWriterByteBuffer(writer) + defer func() { + payloadOut.Release(err) + }() + + // 1. encode payload and calculate value via payload validator + if err = me.EncodePayload(ctx, message, payloadOut); err != nil { + return err + } + // get the payload from buffer + // use copy api here because the payload will be used as an argument of Generate function in validator + payload, err := getWrittenBytes(writer) + if err != nil { + return err + } + if c.PayloadValidator != nil { + if err = payloadChecksumGenerate(ctx, c.PayloadValidator, payload, message); err != nil { + return err + } + } + // set payload length before encode TTHeader + message.SetPayloadLen(len(payload)) + + // 2. encode header and return totalLenField if needed + totalLenField, err := ttHeaderCodec.encode(ctx, message, out) + if err != nil { + return err + } + + // 3. write payload to the buffer after TTHeader + if ncWriter, ok := out.(remote.NocopyWrite); ok { + err = ncWriter.WriteDirect(payload, 0) + } else { + _, err = out.WriteBinary(payload) + } + + // 4. fill totalLen field for header if needed + // FIXME: if the `out` buffer using copy to grow when the capacity is not enough, setting the pre-allocated `totalLenField` may not take effect. + if totalLenField == nil { + return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field") + } + payloadLen := out.MallocLen() - Size32 + binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen)) + return err +} + // Select to use thrift or protobuf according to the protocol. func (c *defaultCodec) encodePayload(ctx context.Context, message remote.Message, out remote.ByteBuffer) error { pCodec, err := remote.GetPayloadCodec(message) @@ -373,3 +459,11 @@ func checkPayloadSize(payloadLen, maxSize int) error { } return nil } + +// getWrittenBytes gets all written bytes from linkbuffer. +func getWrittenBytes(lb *netpoll.LinkBuffer) (buf []byte, err error) { + if err = lb.Flush(); err != nil { + return nil, err + } + return lb.Bytes(), nil +} diff --git a/pkg/remote/codec/default_codec_test.go b/pkg/remote/codec/default_codec_test.go index 8c6367e12b..b08e2bd000 100644 --- a/pkg/remote/codec/default_codec_test.go +++ b/pkg/remote/codec/default_codec_test.go @@ -20,6 +20,8 @@ import ( "context" "encoding/binary" "errors" + "fmt" + "strings" "testing" "github.com/cloudwego/netpoll" @@ -262,6 +264,111 @@ func TestDefaultSizedCodec_Encode_Decode(t *testing.T) { } } +func TestDefaultCodecWithCRC32_Encode_Decode(t *testing.T) { + remote.PutPayloadCode(serviceinfo.Thrift, mpc) + + dc := NewDefaultCodecWithConfig(CodecConfig{CRC32Check: true}) + ctx := context.Background() + intKVInfo := prepareIntKVInfo() + strKVInfo := prepareStrKVInfo() + payloadLen := 32 * 1024 + sendMsg := initClientSendMsg(transport.TTHeaderFramed, payloadLen) + sendMsg.TransInfo().PutTransIntInfo(intKVInfo) + sendMsg.TransInfo().PutTransStrInfo(strKVInfo) + + // test encode err + badOut := netpolltrans.NewReaderByteBuffer(netpoll.NewLinkBuffer()) + err := dc.Encode(ctx, sendMsg, badOut) + test.Assert(t, err != nil) + + // encode with netpollBytebuffer + writer := netpoll.NewLinkBuffer() + npBuffer := netpolltrans.NewReaderWriterByteBuffer(writer) + err = dc.Encode(ctx, sendMsg, npBuffer) + test.Assert(t, err == nil, err) + + // decode, succeed + recvMsg := initServerRecvMsg() + buf, err := getWrittenBytes(writer) + test.Assert(t, err == nil, err) + in := remote.NewReaderBuffer(buf) + err = dc.Decode(ctx, recvMsg, in) + test.Assert(t, err == nil, err) + intKVInfoRecv := recvMsg.TransInfo().TransIntInfo() + strKVInfoRecv := recvMsg.TransInfo().TransStrInfo() + test.DeepEqual(t, intKVInfoRecv, intKVInfo) + test.DeepEqual(t, strKVInfoRecv, strKVInfo) + test.Assert(t, sendMsg.RPCInfo().Invocation().SeqID() == recvMsg.RPCInfo().Invocation().SeqID()) + + // decode, crc32c check failed + test.Assert(t, err == nil, err) + bufLen := len(buf) + modifiedBuf := make([]byte, bufLen) + copy(modifiedBuf, buf) + for i := bufLen - 1; i > bufLen-10; i-- { + modifiedBuf[i] = 123 + } + in = remote.NewReaderBuffer(modifiedBuf) + err = dc.Decode(ctx, recvMsg, in) + test.Assert(t, err != nil, err) +} + +func TestDefaultCodecWithCustomizedValidator(t *testing.T) { + remote.PutPayloadCode(serviceinfo.Thrift, mpc) + + dc := NewDefaultCodecWithConfig(CodecConfig{PayloadValidator: &mockPayloadValidator{}}) + ctx := context.Background() + intKVInfo := prepareIntKVInfo() + strKVInfo := prepareStrKVInfo() + payloadLen := 32 * 1024 + sendMsg := initClientSendMsg(transport.TTHeaderFramed, payloadLen) + sendMsg.TransInfo().PutTransIntInfo(intKVInfo) + sendMsg.TransInfo().PutTransStrInfo(strKVInfo) + + // test encode err + badOut := netpolltrans.NewReaderByteBuffer(netpoll.NewLinkBuffer()) + err := dc.Encode(ctx, sendMsg, badOut) + test.Assert(t, err != nil) + + // test encode err because of limit + exceedLimitCtx := context.WithValue(ctx, mockExceedLimitKey, "true") + npBuffer := netpolltrans.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer()) + err = dc.Encode(exceedLimitCtx, sendMsg, npBuffer) + test.Assert(t, err != nil, err) + test.Assert(t, strings.Contains(err.Error(), "limit"), err) + + // encode with netpollBytebuffer + writer := netpoll.NewLinkBuffer() + npBuffer = netpolltrans.NewReaderWriterByteBuffer(writer) + err = dc.Encode(ctx, sendMsg, npBuffer) + test.Assert(t, err == nil, err) + + // decode, succeed + recvMsg := initServerRecvMsg() + buf, err := getWrittenBytes(writer) + test.Assert(t, err == nil, err) + in := remote.NewReaderBuffer(buf) + err = dc.Decode(ctx, recvMsg, in) + test.Assert(t, err == nil, err) + intKVInfoRecv := recvMsg.TransInfo().TransIntInfo() + strKVInfoRecv := recvMsg.TransInfo().TransStrInfo() + test.DeepEqual(t, intKVInfoRecv, intKVInfo) + test.DeepEqual(t, strKVInfoRecv, strKVInfo) + test.Assert(t, sendMsg.RPCInfo().Invocation().SeqID() == recvMsg.RPCInfo().Invocation().SeqID()) + + // decode, check failed + test.Assert(t, err == nil, err) + bufLen := len(buf) + modifiedBuf := make([]byte, bufLen) + copy(modifiedBuf, buf) + for i := bufLen - 1; i > bufLen-10; i-- { + modifiedBuf[i] = 123 + } + in = remote.NewReaderBuffer(modifiedBuf) + err = dc.Decode(ctx, recvMsg, in) + test.Assert(t, err != nil, err) +} + func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) { for _, tb := range transportBuffers { t.Run(tb.Name, func(t *testing.T) { @@ -295,6 +402,43 @@ func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) { } } +func BenchmarkDefaultEncodeDecode(b *testing.B) { + ctx := context.Background() + remote.PutPayloadCode(serviceinfo.Thrift, mpc) + type factory func() remote.Codec + testCases := map[string]factory{"normal": NewDefaultCodec, "crc32c": func() remote.Codec { return NewDefaultCodecWithConfig(CodecConfig{CRC32Check: true}) }} + + for name, f := range testCases { + b.Run(name, func(b *testing.B) { + msgLen := 1 + for i := 0; i < 6; i++ { + b.ReportAllocs() + b.ResetTimer() + b.Run(fmt.Sprintf("payload-%d", msgLen), func(b *testing.B) { + for j := 0; j < b.N; j++ { + codec := f() + sendMsg := initClientSendMsg(transport.TTHeader, msgLen) + // encode + writer := netpoll.NewLinkBuffer() + out := netpolltrans.NewWriterByteBuffer(writer) + err := codec.Encode(ctx, sendMsg, out) + test.Assert(b, err == nil, err) + + // decode + recvMsg := initServerRecvMsgWithMockMsg() + buf, err := getWrittenBytes(writer) + test.Assert(b, err == nil, err) + in := remote.NewReaderBuffer(buf) + err = codec.Decode(ctx, recvMsg, in) + test.Assert(b, err == nil, err) + } + }) + msgLen *= 10 + } + }) + } +} + var mpc remote.PayloadCodec = mockPayloadCodec{} type mockPayloadCodec struct{} @@ -303,6 +447,23 @@ func (m mockPayloadCodec) Marshal(ctx context.Context, message remote.Message, o WriteUint32(ThriftV1Magic+uint32(message.MessageType()), out) WriteString(message.RPCInfo().Invocation().MethodName(), out) WriteUint32(uint32(message.RPCInfo().Invocation().SeqID()), out) + var ( + dataLen uint32 + dataStr string + ) + // write data + if data := message.Data(); data != nil { + if mm, ok := data.(*mockMsg); ok { + if len(mm.msg) != 0 { + dataStr = mm.msg + dataLen = uint32(len(mm.msg)) + } + } + } + WriteUint32(dataLen, out) + if dataLen > 0 { + WriteString(dataStr, out) + } return nil } @@ -333,6 +494,18 @@ func (m mockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message, if err = SetOrCheckSeqID(int32(seqID), message); err != nil && msgType != uint32(remote.Exception) { return err } + // read data + dataLen, err := PeekUint32(in) + if err != nil { + return err + } + if dataLen == 0 { + // no data + return nil + } + if _, _, err = ReadString(in); err != nil { + return err + } return nil } diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index af60e4c91c..fb4e8ad051 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -313,6 +313,15 @@ var ( rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) ) +type mockMsg struct { + msg string +} + +func initServerRecvMsgWithMockMsg() remote.Message { + req := &mockMsg{} + return remote.NewMessage(req, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Call, remote.Server) +} + func initServerRecvMsg() remote.Message { svcInfo := mocks.ServiceInfo() svcSearcher := mocksremote.NewDefaultSvcSearcher() @@ -320,23 +329,32 @@ func initServerRecvMsg() remote.Message { return msg } -func initClientSendMsg(tp transport.Protocol) remote.Message { - var req interface{} +func initClientSendMsg(tp transport.Protocol, payloadLen ...int) remote.Message { + req := &mockMsg{} + if len(payloadLen) != 0 { + req.msg = string(make([]byte, payloadLen[0])) + } + svcInfo := mocks.ServiceInfo() + mi := svcInfo.MethodInfo(mockCliRPCInfo.Invocation().MethodName()) + mi.NewArgs() msg := remote.NewMessage(req, svcInfo, mockCliRPCInfo, remote.Call, remote.Client) msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) return msg } -func initServerSendMsg(tp transport.Protocol) remote.Message { - var resp interface{} +func initServerSendMsg(tp transport.Protocol, payloadLen ...int) remote.Message { + resp := &mockMsg{} + if len(payloadLen) != 0 { + resp.msg = string(make([]byte, payloadLen[0])) + } msg := remote.NewMessage(resp, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Reply, remote.Server) msg.SetProtocolInfo(remote.NewProtocolInfo(tp, mocks.ServiceInfo().PayloadCodec)) return msg } func initClientRecvMsg() remote.Message { - var resp interface{} + resp := &mockMsg{} svcInfo := mocks.ServiceInfo() msg := remote.NewMessage(resp, svcInfo, mockCliRPCInfo, remote.Reply, remote.Client) return msg diff --git a/pkg/remote/codec/validate.go b/pkg/remote/codec/validate.go new file mode 100644 index 0000000000..e63ca88d58 --- /dev/null +++ b/pkg/remote/codec/validate.go @@ -0,0 +1,217 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package codec + +import ( + "context" + "encoding/binary" + "encoding/hex" + "fmt" + "hash/crc32" + "sync" + + "github.com/cloudwego/kitex/pkg/consts" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/perrors" + "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/stats" +) + +const ( + PayloadValidatorPrefix = "PV_" + maxPayloadChecksumLength = 4096 // maximum 4k +) + +// PayloadValidator is the interface for validating the payload of RPC requests, which allows customized Checksum function. +type PayloadValidator interface { + // Key returns a key for your validator, which will be the key in ttheader + Key(ctx context.Context) string + + // Generate generates the checksum of the payload. + // The value will not be set to the request header if "need" is false. + // DO NOT modify the input payload since it might be obtained by nocopy API from the underlying buffer. + Generate(ctx context.Context, outboundPayload []byte) (need bool, checksum string, err error) + + // Validate validates the input payload with the attached checksum. + // Return pass if validation succeed, or return error. + // DO NOT modify the input payload since it might be obtained by nocopy API from the underlying buffer. + Validate(ctx context.Context, expectedValue string, inboundPayload []byte) (pass bool, err error) +} + +func getValidatorKey(ctx context.Context, p PayloadValidator) string { + if _, ok := p.(*crcPayloadValidator); ok { + return p.Key(ctx) + } + key := p.Key(ctx) + return PayloadValidatorPrefix + key +} + +func payloadChecksumGenerate(ctx context.Context, pv PayloadValidator, outboundPayload []byte, message remote.Message) (err error) { + rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumGenerateStart, nil) + defer func() { + rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumGenerateFinish, err) + }() + + need, value, pErr := pv.Generate(ctx, outboundPayload) + if pErr != nil { + err = kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("generate failed, err=%v", pErr)) + return err + } + if need { + if len(value) > maxPayloadChecksumLength { + err = kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("payload checksum value exceeds the limit, actual length=%d, limit=%d", len(value), maxPayloadChecksumLength)) + return err + } + key := getValidatorKey(ctx, pv) + strInfo := message.TransInfo().TransStrInfo() + if strInfo != nil { + strInfo[key] = value + } + } + return nil +} + +func payloadChecksumValidate(ctx context.Context, pv PayloadValidator, in remote.ByteBuffer, message remote.Message) (err error) { + rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumValidateStart, nil) + defer func() { + rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumValidateFinish, err) + }() + + // this return ctx can only be used in Validate part since Decode has no return argument for context + ctx = fillRPCInfoBeforeValidate(ctx, message) + + // get key and value + key := getValidatorKey(ctx, pv) + strInfo := message.TransInfo().TransStrInfo() + if strInfo == nil { + return nil + } + expectedValue := strInfo[key] + payloadLen := message.PayloadLen() // total length + payload, err := in.Peek(payloadLen) + if err != nil { + return err + } + + // validate + pass, err := pv.Validate(ctx, expectedValue, payload) + if err != nil { + return kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("validation failed, err=%v", err)) + } + if !pass { + return kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("validation failed")) + } + return nil +} + +// fillRPCInfoBeforeValidate reads header and set into the RPCInfo, which allows Validate() to use RPCInfo. +func fillRPCInfoBeforeValidate(ctx context.Context, message remote.Message) context.Context { + if message.RPCRole() != remote.Server { + // only fill when server-side reading the request header + // TODO: client-side can read from the response header + return ctx + } + ri := message.RPCInfo() + if ri == nil { + return ctx + } + transInfo := message.TransInfo() + if transInfo == nil { + return ctx + } + intInfo := transInfo.TransIntInfo() + if intInfo == nil { + return ctx + } + from := rpcinfo.AsMutableEndpointInfo(ri.From()) + if from != nil { + if v := intInfo[transmeta.FromService]; v != "" { + from.SetServiceName(v) + } + if v := intInfo[transmeta.FromMethod]; v != "" { + from.SetMethod(v) + } + } + to := rpcinfo.AsMutableEndpointInfo(ri.To()) + if to != nil { + // server-side reads "to_method" from ttheader since "method" is set in thrift payload, which has not been unmarshalled + if v := intInfo[transmeta.ToMethod]; v != "" { + to.SetMethod(v) + } + if v := intInfo[transmeta.ToService]; v != "" { + to.SetServiceName(v) + } + } + if logid := intInfo[transmeta.LogID]; logid != "" { + ctx = context.WithValue(ctx, consts.CtxKeyLogID, logid) + } + return ctx +} + +// NewCRC32PayloadValidator returns a new crcPayloadValidator +func NewCRC32PayloadValidator() PayloadValidator { + crc32TableOnce.Do(func() { + crc32cTable = crc32.MakeTable(crc32.Castagnoli) + }) + return &crcPayloadValidator{} +} + +type crcPayloadValidator struct{} + +var _ PayloadValidator = &crcPayloadValidator{} + +func (p *crcPayloadValidator) Key(ctx context.Context) string { + return transmeta.HeaderCRC32C +} + +func (p *crcPayloadValidator) Generate(ctx context.Context, outPayload []byte) (need bool, value string, err error) { + return true, getCRC32C(outPayload), nil +} + +func (p *crcPayloadValidator) Validate(ctx context.Context, expectedValue string, inputPayload []byte) (pass bool, err error) { + if expectedValue == "" { + // If the expectedValue parsed from TTHeader is empty, it means that the checksum was not set on sender-side + // return true in this case + return true, nil + } + realValue := getCRC32C(inputPayload) + if realValue != expectedValue { + return false, perrors.NewProtocolErrorWithType(perrors.InvalidData, expectedValue) + } + return true, nil +} + +// crc32cTable is used for crc32c check +var ( + crc32cTable *crc32.Table + crc32TableOnce sync.Once +) + +// getCRC32C calculates the crc32c checksum of the input bytes. +// the checksum will be converted into big-endian format and encoded into hex string. +func getCRC32C(payload []byte) string { + if crc32cTable == nil { + return "" + } + csb := make([]byte, Size32) + var checksum uint32 + checksum = crc32.Update(checksum, crc32cTable, payload) + binary.BigEndian.PutUint32(csb, checksum) + return hex.EncodeToString(csb) +} diff --git a/pkg/remote/codec/validate_test.go b/pkg/remote/codec/validate_test.go new file mode 100644 index 0000000000..b17378b4f1 --- /dev/null +++ b/pkg/remote/codec/validate_test.go @@ -0,0 +1,176 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package codec + +import ( + "context" + "errors" + "strconv" + "testing" + + "github.com/bytedance/gopkg/util/xxhash3" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/transport" +) + +var _ PayloadValidator = &mockPayloadValidator{} + +type mockPayloadValidator struct{} + +const ( + mockGenerateSkipKey = "mockGenerateSkip" + mockGenerateErrorKey = "mockGenerateError" + mockExceedLimitKey = "mockExceedLimit" +) + +func (m *mockPayloadValidator) Key(ctx context.Context) string { + return "mockValidator" +} + +func (m *mockPayloadValidator) Generate(ctx context.Context, outPayload []byte) (need bool, value string, err error) { + if l := ctx.Value(mockGenerateSkipKey); l != nil { + return false, "", nil + } + if l := ctx.Value(mockExceedLimitKey); l != nil { + // return value with length exceeding the limit + return true, string(make([]byte, maxPayloadChecksumLength+1)), nil + } + if l := ctx.Value(mockGenerateErrorKey); l != nil { + return false, "", errors.New("mockGenerateError") + } + hash := xxhash3.Hash(outPayload) + return true, strconv.FormatInt(int64(hash), 10), nil +} + +func (m *mockPayloadValidator) Validate(ctx context.Context, expectedValue string, inputPayload []byte) (pass bool, err error) { + _, value, err := m.Generate(ctx, inputPayload) + if err != nil { + return false, err + } + return value == expectedValue, nil +} + +func TestPayloadValidator(t *testing.T) { + p := &mockPayloadValidator{} + payload := preparePayload() + + need, value, err := p.Generate(context.Background(), payload) + test.Assert(t, err == nil, err) + test.Assert(t, need) + + pass, err := p.Validate(context.Background(), value, payload) + test.Assert(t, err == nil, err) + test.Assert(t, pass, true) +} + +func TestPayloadChecksumGenerate(t *testing.T) { + payload := preparePayload() + pv := &mockPayloadValidator{} + + // success + message := initClientSendMsg(transport.TTHeader) + strInfo := message.TransInfo().TransStrInfo() + ctx := context.Background() + err := payloadChecksumGenerate(ctx, pv, payload, message) + test.Assert(t, err == nil, err) + test.Assert(t, len(strInfo) != 0) + test.Assert(t, strInfo[getValidatorKey(ctx, pv)] != "") + + // success, no need to generate + message = initClientSendMsg(transport.TTHeader) + strInfo = message.TransInfo().TransStrInfo() + ctx = context.WithValue(context.Background(), mockGenerateSkipKey, "true") + err = payloadChecksumGenerate(ctx, pv, payload, message) + test.Assert(t, err == nil, err) + test.Assert(t, len(strInfo) == 0) // no checksum in strinfo + + // failed, generate error + message = initClientSendMsg(transport.TTHeader) + ctx = context.WithValue(context.Background(), mockGenerateErrorKey, "true") + err = payloadChecksumGenerate(ctx, pv, payload, message) + test.Assert(t, err != nil, err) + test.Assert(t, errors.Is(err, kerrors.ErrPayloadValidation)) + + // failed, exceed limit + message = initClientSendMsg(transport.TTHeader) + ctx = context.WithValue(context.Background(), mockExceedLimitKey, "true") + err = payloadChecksumGenerate(ctx, pv, payload, message) + test.Assert(t, err != nil, err) + test.Assert(t, errors.Is(err, kerrors.ErrPayloadValidation)) +} + +func TestPayloadChecksumValidate(t *testing.T) { + // prepare + payload := preparePayload() + pv := &mockPayloadValidator{} + sendMsg := initClientSendMsg(transport.TTHeader) + ctx := context.Background() + err := payloadChecksumGenerate(ctx, pv, payload, sendMsg) + test.Assert(t, err == nil, err) + + // success + in := remote.NewReaderBuffer(payload) + message := initClientRecvMsg() + message.TransInfo().PutTransStrInfo(sendMsg.TransInfo().TransStrInfo()) // put header strinfo + message.SetPayloadLen(len(payload)) + err = payloadChecksumValidate(ctx, pv, in, message) + test.Assert(t, err == nil, err) + + // validate failed, checksum validation error + in = remote.NewReaderBuffer(payload) + message = initClientRecvMsg() + // don't put header strinfo + message.SetPayloadLen(len(payload)) + err = payloadChecksumValidate(context.Background(), pv, in, message) + test.Assert(t, err != nil) +} + +func TestCRCPayloadValidator(t *testing.T) { + // prepare + payload := preparePayload() + p := NewCRC32PayloadValidator() + + // success + ctx := context.Background() + need, value, err := p.Generate(ctx, payload) + test.Assert(t, err == nil, err) + test.Assert(t, need) + pass, err := p.Validate(ctx, value, payload) + test.Assert(t, err == nil, err) + test.Assert(t, pass == true) + + // failure, checksum mismatches + pass, err = p.Validate(ctx, value+"0", payload) + test.Assert(t, err != nil, err) + test.Assert(t, pass == false) + + // success when value is empty string, which means no checksum from sender + pass, err = p.Validate(ctx, "", payload) + test.Assert(t, err == nil, err) + test.Assert(t, pass == true) +} + +func preparePayload() []byte { + payload := make([]byte, 1024) + for i := 0; i < len(payload); i++ { + payload[i] = byte(i) + } + return payload +} diff --git a/pkg/remote/trans/netpoll/bytebuf.go b/pkg/remote/trans/netpoll/bytebuf.go index 80612c7e97..4a6bb3b5ee 100644 --- a/pkg/remote/trans/netpoll/bytebuf.go +++ b/pkg/remote/trans/netpoll/bytebuf.go @@ -249,3 +249,11 @@ func (b *netpollByteBuffer) zero() { b.status = 0 b.readSize = 0 } + +// GetWrittenBytes gets all written bytes from linkbuffer. +func GetWrittenBytes(lb *netpoll.LinkBuffer) (buf []byte, err error) { + if err = lb.Flush(); err != nil { + return nil, err + } + return lb.Bytes(), nil +} diff --git a/pkg/remote/transmeta/metakey.go b/pkg/remote/transmeta/metakey.go index a2ad43ce3d..f1a43afbd3 100644 --- a/pkg/remote/transmeta/metakey.go +++ b/pkg/remote/transmeta/metakey.go @@ -63,6 +63,8 @@ const ( // the connection peer will shutdown later,so it send back the header to tell client to close the connection. HeaderConnectionReadyToReset = "crrst" HeaderProcessAtTime = "K_ProcessAtTime" + // HeaderCRC32C is used to store the crc32c checksum of payload + HeaderCRC32C = "crc32c" ) // key of acl token diff --git a/pkg/stats/event.go b/pkg/stats/event.go index 86b1ae8a65..3ac6f71270 100644 --- a/pkg/stats/event.go +++ b/pkg/stats/event.go @@ -72,6 +72,10 @@ const ( writeFinish streamRecv streamSend + checksumGenerateStart + checksumGenerateFinish + checksumValidateStart + checksumValidateFinish // NOTE: add new events before this line predefinedEventNum @@ -82,16 +86,20 @@ var ( RPCStart = newEvent(rpcStart, LevelBase) RPCFinish = newEvent(rpcFinish, LevelBase) - ServerHandleStart = newEvent(serverHandleStart, LevelDetailed) - ServerHandleFinish = newEvent(serverHandleFinish, LevelDetailed) - ClientConnStart = newEvent(clientConnStart, LevelDetailed) - ClientConnFinish = newEvent(clientConnFinish, LevelDetailed) - ReadStart = newEvent(readStart, LevelDetailed) - ReadFinish = newEvent(readFinish, LevelDetailed) - WaitReadStart = newEvent(waitReadStart, LevelDetailed) - WaitReadFinish = newEvent(waitReadFinish, LevelDetailed) - WriteStart = newEvent(writeStart, LevelDetailed) - WriteFinish = newEvent(writeFinish, LevelDetailed) + ServerHandleStart = newEvent(serverHandleStart, LevelDetailed) + ServerHandleFinish = newEvent(serverHandleFinish, LevelDetailed) + ClientConnStart = newEvent(clientConnStart, LevelDetailed) + ClientConnFinish = newEvent(clientConnFinish, LevelDetailed) + ReadStart = newEvent(readStart, LevelDetailed) + ReadFinish = newEvent(readFinish, LevelDetailed) + WaitReadStart = newEvent(waitReadStart, LevelDetailed) + WaitReadFinish = newEvent(waitReadFinish, LevelDetailed) + WriteStart = newEvent(writeStart, LevelDetailed) + WriteFinish = newEvent(writeFinish, LevelDetailed) + ChecksumValidateStart = newEvent(checksumValidateStart, LevelDetailed) + ChecksumValidateFinish = newEvent(checksumValidateFinish, LevelDetailed) + ChecksumGenerateStart = newEvent(checksumGenerateStart, LevelDetailed) + ChecksumGenerateFinish = newEvent(checksumGenerateFinish, LevelDetailed) // Streaming Events StreamRecv = newEvent(streamRecv, LevelDetailed) From 633f72c97968c026ff97cc06c4cff075dcd45c6f Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Wed, 28 Aug 2024 17:07:05 +0800 Subject: [PATCH 57/70] fix(streaming): resolve ctx diverge in server-side streaming (#1471) --- internal/mocks/serviceinfo.go | 1 + pkg/remote/trans/nphttp2/grpc/transport.go | 4 +- pkg/remote/trans/nphttp2/server_conn.go | 21 +- pkg/remote/trans/nphttp2/server_conn_test.go | 99 ++++++- pkg/remote/trans/nphttp2/server_handler.go | 257 +++++++++--------- .../trans/nphttp2/server_handler_test.go | 2 +- pkg/remote/trans/nphttp2/stream_test.go | 3 +- server/invoke.go | 1 + server/middlewares.go | 20 ++ server/option_test.go | 5 - server/server.go | 27 +- server/server_test.go | 209 ++++++++++++++ server/stream.go | 12 + 13 files changed, 512 insertions(+), 149 deletions(-) diff --git a/internal/mocks/serviceinfo.go b/internal/mocks/serviceinfo.go index a2f82c185c..4782e0d1e0 100644 --- a/internal/mocks/serviceinfo.go +++ b/internal/mocks/serviceinfo.go @@ -36,6 +36,7 @@ const ( MockExceptionMethod string = "mockException" MockErrorMethod string = "mockError" MockOnewayMethod string = "mockOneway" + MockStreamingMethod string = "mockStreaming" ) // ServiceInfo return mock serviceInfo diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index f6e0706227..e2b5f0ef1f 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -485,7 +485,7 @@ func StreamWrite(s *Stream, buffer *bytes.Buffer) { } // CreateStream only used for unit test. Create an independent stream out of http2client / http2server -func CreateStream(id uint32, requestRead func(i int)) *Stream { +func CreateStream(ctx context.Context, id uint32, requestRead func(i int), method string) *Stream { recvBuffer := newRecvBuffer() trReader := &transportReader{ reader: &recvBufferReader{ @@ -499,6 +499,8 @@ func CreateStream(id uint32, requestRead func(i int)) *Stream { stream := &Stream{ id: id, + ctx: ctx, + method: method, buf: recvBuffer, trReader: trReader, wq: newWriteQuota(defaultWriteQuota, nil), diff --git a/pkg/remote/trans/nphttp2/server_conn.go b/pkg/remote/trans/nphttp2/server_conn.go index 271c3358c7..51c5003fc5 100644 --- a/pkg/remote/trans/nphttp2/server_conn.go +++ b/pkg/remote/trans/nphttp2/server_conn.go @@ -30,6 +30,8 @@ import ( "github.com/cloudwego/kitex/pkg/streaming" ) +type serverConnKey struct{} + type serverConn struct { tr grpc.ServerTransport s *grpc.Stream @@ -61,25 +63,12 @@ func (c *serverConn) ReadFrame() (hdr, data []byte, err error) { // GetServerConn gets the GRPC Connection from server stream. // This function is only used in server handler for grpc unknown handler proxy: https://www.cloudwego.io/docs/kitex/tutorials/advanced-feature/grpcproxy/ -// And the input stream type should always be streamWithMiddleware. func GetServerConn(st streaming.Stream) (GRPCConn, error) { - mwStream, ok := st.(*streamWithMiddleware) - if !ok { - return nil, status.Errorf(codes.Internal, "failed to get streamWithMiddleware") - } - - serverStream, ok := mwStream.Stream.(*stream) - - if !ok { - // err! - return nil, status.Errorf(codes.Internal, "failed to get server conn from server stream.") - } - grpcServerConn, ok := serverStream.conn.(GRPCConn) + rawStream, ok := st.Context().Value(serverConnKey{}).(*stream) if !ok { - // err! - return nil, status.Errorf(codes.Internal, "failed to trans conn to grpc conn.") + return nil, status.Errorf(codes.Internal, "the ctx of Stream is not provided by Kitex Server") } - return grpcServerConn, nil + return rawStream.conn.(GRPCConn), nil } // impl net.Conn diff --git a/pkg/remote/trans/nphttp2/server_conn_test.go b/pkg/remote/trans/nphttp2/server_conn_test.go index 7719539c3d..1b0d11560f 100644 --- a/pkg/remote/trans/nphttp2/server_conn_test.go +++ b/pkg/remote/trans/nphttp2/server_conn_test.go @@ -17,11 +17,19 @@ package nphttp2 import ( + "context" + "net" "testing" "time" + "github.com/cloudwego/kitex/internal/mocks" + mock_remote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streaming" ) func TestServerConn(t *testing.T) { @@ -30,7 +38,7 @@ func TestServerConn(t *testing.T) { npConn.mockSettingFrame() tr, err := newMockServerTransport(npConn) test.Assert(t, err == nil, err) - s := grpc.CreateStream(1, func(i int) {}) + s := grpc.CreateStream(context.Background(), 1, func(i int) {}, "") serverConn := newServerConn(tr, s) defer serverConn.Close() @@ -74,3 +82,92 @@ func TestServerConn(t *testing.T) { err = serverConn.SetWriteDeadline(time.Now()) test.Assert(t, err == nil, err) } + +type customStream struct { + streaming.Stream + ctx context.Context +} + +func (s *customStream) Context() context.Context { + return s.ctx +} + +func TestGetServerConn(t *testing.T) { + testcases := []struct { + desc string + ep endpoint.Endpoint + }{ + { + desc: "normal scenario", + ep: func(ctx context.Context, req, resp interface{}) (err error) { + arg, ok := req.(*streaming.Args) + test.Assert(t, ok) + test.Assert(t, arg.Stream != nil) + // ensure that the Stream exposed to users makes GetServerConn worked + _, err = GetServerConn(arg.Stream) + test.Assert(t, err == nil, err) + return nil + }, + }, + { + desc: "users wrap Stream and rewrite Context() method with the original ctx", + ep: func(ctx context.Context, req, resp interface{}) (err error) { + arg, ok := req.(*streaming.Args) + test.Assert(t, ok) + test.Assert(t, arg.Stream != nil) + cs := &customStream{ + Stream: arg.Stream, + ctx: context.WithValue(ctx, "key", "val"), + } + _, err = GetServerConn(cs) + test.Assert(t, err == nil) + return nil + }, + }, + { + desc: "users wrap Stream and rewrite Context() method without the original ctx", + ep: func(ctx context.Context, req, resp interface{}) (err error) { + arg, ok := req.(*streaming.Args) + test.Assert(t, ok) + test.Assert(t, arg.Stream != nil) + cs := &customStream{ + Stream: arg.Stream, + ctx: context.Background(), + } + _, err = GetServerConn(cs) + test.Assert(t, err != nil) + return nil + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + transHdl, err := newSvrTransHandler(&remote.ServerOption{ + SvcSearcher: mock_remote.NewDefaultSvcSearcher(), + GRPCCfg: grpc.DefaultServerConfig(), + InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { + return newMockRPCInfo() + }, + TracerCtl: &rpcinfo.TraceController{}, + }) + test.Assert(t, err == nil, err) + transHdl.inkHdlFunc = tc.ep + + npConn := newMockNpConn(mockAddr0) + npConn.mockSettingFrame() + ctx := context.Background() + ctx, err = transHdl.OnActive(ctx, npConn) + test.Assert(t, err == nil, err) + svrTrans, ok := ctx.Value(ctxKeySvrTransport).(*SvrTrans) + test.Assert(t, ok) + test.Assert(t, svrTrans.tr != nil) + defer svrTrans.tr.Close() + + s := grpc.CreateStream(ctx, 1, func(i int) {}, mocks.MockServiceName+"/"+mocks.MockStreamingMethod) + srvConn := newServerConn(svrTrans.tr, s) + + transHdl.handleFunc(s, svrTrans, srvConn) + }) + } +} diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index c3b4d3b6b6..f249f84242 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -122,136 +122,149 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { tr.HandleStreams(func(s *grpcTransport.Stream) { gofunc.GoFunc(ctx, func() { - ri := svrTrans.pool.Get().(rpcinfo.RPCInfo) - rCtx := rpcinfo.NewCtxWithRPCInfo(s.Context(), ri) - defer func() { - // reset rpcinfo for performance (PR #584) - if rpcinfo.PoolEnabled() { - ri = t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr()) - svrTrans.pool.Put(ri) - } - }() - - ink := ri.Invocation().(rpcinfo.InvocationSetter) - sm := s.Method() - if sm != "" && sm[0] == '/' { - sm = sm[1:] - } - pos := strings.LastIndex(sm, "/") - if pos == -1 { - errDesc := fmt.Sprintf("malformed method name, method=%q", s.Method()) - tr.WriteStatus(s, status.New(codes.Internal, errDesc)) - return - } - methodName := sm[pos+1:] - ink.SetMethodName(methodName) - - if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil { - if err := mutableTo.SetMethod(methodName); err != nil { - errDesc := fmt.Sprintf("setMethod failed in streaming, method=%s, error=%s", methodName, err.Error()) - _ = tr.WriteStatus(s, status.New(codes.Internal, errDesc)) - return - } - } + t.handleFunc(s, svrTrans, conn) + }) + }, func(ctx context.Context, method string) context.Context { + return ctx + }) + return nil +} - var serviceName string - idx := strings.LastIndex(sm[:pos], ".") - if idx == -1 { - ink.SetPackageName("") - serviceName = sm[0:pos] - } else { - ink.SetPackageName(sm[:idx]) - serviceName = sm[idx+1 : pos] - } - ink.SetServiceName(serviceName) - - // set grpc transport flag before execute metahandler - rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(transport.GRPC) - var err error - for _, shdlr := range t.opt.StreamingMetaHandlers { - rCtx, err = shdlr.OnReadStream(rCtx) - if err != nil { - tr.WriteStatus(s, convertStatus(err)) - return - } - } - rCtx = t.startTracer(rCtx, ri) - defer func() { - panicErr := recover() - if panicErr != nil { - if conn != nil { - klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, close conn, remoteAddress=%s, error=%s\nstack=%s", conn.RemoteAddr(), panicErr, string(debug.Stack())) - } else { - klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack())) - } - } - t.finishTracer(rCtx, ri, err, panicErr) - }() - - // set recv grpc compressor at server to decode the pack from client - remote.SetRecvCompressor(ri, s.RecvCompress()) - // set send grpc compressor at server to encode reply pack - remote.SetSendCompressor(ri, s.SendCompress()) - - svcInfo := t.svcSearcher.SearchService(serviceName, methodName, true) - var methodInfo serviceinfo.MethodInfo - if svcInfo != nil { - methodInfo = svcInfo.MethodInfo(methodName) - } +func (t *svrTransHandler) handleFunc(s *grpcTransport.Stream, svrTrans *SvrTrans, conn net.Conn) { + tr := svrTrans.tr + ri := svrTrans.pool.Get().(rpcinfo.RPCInfo) + rCtx := rpcinfo.NewCtxWithRPCInfo(s.Context(), ri) + defer func() { + // reset rpcinfo for performance (PR #584) + if rpcinfo.PoolEnabled() { + ri = t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr()) + svrTrans.pool.Put(ri) + } + }() + + ink := ri.Invocation().(rpcinfo.InvocationSetter) + sm := s.Method() + if sm != "" && sm[0] == '/' { + sm = sm[1:] + } + pos := strings.LastIndex(sm, "/") + if pos == -1 { + errDesc := fmt.Sprintf("malformed method name, method=%q", s.Method()) + tr.WriteStatus(s, status.New(codes.Internal, errDesc)) + return + } + methodName := sm[pos+1:] + ink.SetMethodName(methodName) + + if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil { + if err := mutableTo.SetMethod(methodName); err != nil { + errDesc := fmt.Sprintf("setMethod failed in streaming, method=%s, error=%s", methodName, err.Error()) + _ = tr.WriteStatus(s, status.New(codes.Internal, errDesc)) + return + } + } - rawStream := NewStream(rCtx, svcInfo, newServerConn(tr, s), t) - st := newStreamWithMiddleware(rawStream, t.opt.RecvEndpoint, t.opt.SendEndpoint) - - // bind stream into ctx, in order to let user set header and trailer by provided api in meta_api.go - rCtx = streaming.NewCtxWithStream(rCtx, st) - - if methodInfo == nil { - unknownServiceHandlerFunc := t.opt.GRPCUnknownServiceHandler - if unknownServiceHandlerFunc != nil { - rpcinfo.Record(rCtx, ri, stats.ServerHandleStart, nil) - err = unknownServiceHandlerFunc(rCtx, methodName, st) - if err != nil { - err = kerrors.ErrBiz.WithCause(err) - } - } else { - if svcInfo == nil { - err = remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", serviceName)) - } else { - err = remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", methodName)) - } - } + var serviceName string + idx := strings.LastIndex(sm[:pos], ".") + if idx == -1 { + ink.SetPackageName("") + serviceName = sm[0:pos] + } else { + ink.SetPackageName(sm[:idx]) + serviceName = sm[idx+1 : pos] + } + ink.SetServiceName(serviceName) + + // set grpc transport flag before execute metahandler + rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(transport.GRPC) + var err error + for _, shdlr := range t.opt.StreamingMetaHandlers { + rCtx, err = shdlr.OnReadStream(rCtx) + if err != nil { + tr.WriteStatus(s, convertStatus(err)) + return + } + } + rCtx = t.startTracer(rCtx, ri) + defer func() { + panicErr := recover() + if panicErr != nil { + if conn != nil { + klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, close conn, remoteAddress=%s, error=%s\nstack=%s", conn.RemoteAddr(), panicErr, string(debug.Stack())) } else { - if streaming.UnaryCompatibleMiddleware(methodInfo.StreamingMode(), t.opt.CompatibleMiddlewareForUnary) { - // making streaming unary APIs capable of using the same server middleware as non-streaming APIs - // note: rawStream skips recv/send middleware for unary API requests to avoid confusion - err = invokeStreamUnaryHandler(rCtx, rawStream, methodInfo, t.inkHdlFunc, ri) - } else { - err = t.inkHdlFunc(rCtx, &streaming.Args{Stream: st}, nil) - } + klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack())) } + } + t.finishTracer(rCtx, ri, err, panicErr) + }() + + // set recv grpc compressor at server to decode the pack from client + remote.SetRecvCompressor(ri, s.RecvCompress()) + // set send grpc compressor at server to encode reply pack + remote.SetSendCompressor(ri, s.SendCompress()) + + svcInfo := t.svcSearcher.SearchService(serviceName, methodName, true) + var methodInfo serviceinfo.MethodInfo + if svcInfo != nil { + methodInfo = svcInfo.MethodInfo(methodName) + } + rawStream := &stream{ + ctx: rCtx, + svcInfo: svcInfo, + conn: newServerConn(tr, s), + handler: t, + } + // inject rawStream so that GetServerConn only relies on it + rCtx = context.WithValue(rCtx, serverConnKey{}, rawStream) + st := newStreamWithMiddleware(rawStream, t.opt.RecvEndpoint, t.opt.SendEndpoint) + // bind stream into ctx, in order to let user set header and trailer by provided api in meta_api.go + rCtx = streaming.NewCtxWithStream(rCtx, st) + // GetServerConn could retrieve rawStream by Stream.Context().Value(serverConnKey{}) + rawStream.ctx = rCtx + + if methodInfo == nil { + unknownServiceHandlerFunc := t.opt.GRPCUnknownServiceHandler + if unknownServiceHandlerFunc != nil { + rpcinfo.Record(rCtx, ri, stats.ServerHandleStart, nil) + err = unknownServiceHandlerFunc(rCtx, methodName, st) if err != nil { - tr.WriteStatus(s, convertStatus(err)) - t.OnError(rCtx, err, conn) - return + err = kerrors.ErrBiz.WithCause(err) } - if bizStatusErr := ri.Invocation().BizStatusErr(); bizStatusErr != nil { - var st *status.Status - if sterr, ok := bizStatusErr.(status.Iface); ok { - st = sterr.GRPCStatus() - } else { - st = status.New(codes.Internal, bizStatusErr.BizMessage()) - } - s.SetBizStatusErr(bizStatusErr) - tr.WriteStatus(s, st) - return + } else { + if svcInfo == nil { + err = remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", serviceName)) + } else { + err = remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", methodName)) } - tr.WriteStatus(s, status.New(codes.OK, "")) - }) - }, func(ctx context.Context, method string) context.Context { - return ctx - }) - return nil + } + } else { + if streaming.UnaryCompatibleMiddleware(methodInfo.StreamingMode(), t.opt.CompatibleMiddlewareForUnary) { + // making streaming unary APIs capable of using the same server middleware as non-streaming APIs + // note: rawStream skips recv/send middleware for unary API requests to avoid confusion + err = invokeStreamUnaryHandler(rCtx, rawStream, methodInfo, t.inkHdlFunc, ri) + } else { + err = t.inkHdlFunc(rCtx, &streaming.Args{Stream: st}, nil) + } + } + + if err != nil { + tr.WriteStatus(s, convertStatus(err)) + t.OnError(rCtx, err, conn) + return + } + if bizStatusErr := ri.Invocation().BizStatusErr(); bizStatusErr != nil { + var st *status.Status + if sterr, ok := bizStatusErr.(status.Iface); ok { + st = sterr.GRPCStatus() + } else { + st = status.New(codes.Internal, bizStatusErr.BizMessage()) + } + s.SetBizStatusErr(bizStatusErr) + tr.WriteStatus(s, st) + return + } + tr.WriteStatus(s, status.New(codes.OK, "")) } // invokeStreamUnaryHandler allows unary APIs over HTTP2 to use the same server middleware as non-streaming APIs. diff --git a/pkg/remote/trans/nphttp2/server_handler_test.go b/pkg/remote/trans/nphttp2/server_handler_test.go index 629faecea9..79d518944b 100644 --- a/pkg/remote/trans/nphttp2/server_handler_test.go +++ b/pkg/remote/trans/nphttp2/server_handler_test.go @@ -74,7 +74,7 @@ func TestServerHandler(t *testing.T) { npConn.mockSettingFrame() tr, err := newMockServerTransport(npConn) test.Assert(t, err == nil, err) - s := grpc.CreateStream(1, func(i int) {}) + s := grpc.CreateStream(context.Background(), 1, func(i int) {}, "") serverConn := newServerConn(tr, s) defer serverConn.Close() diff --git a/pkg/remote/trans/nphttp2/stream_test.go b/pkg/remote/trans/nphttp2/stream_test.go index fa5c2f4d00..82181aa606 100644 --- a/pkg/remote/trans/nphttp2/stream_test.go +++ b/pkg/remote/trans/nphttp2/stream_test.go @@ -17,6 +17,7 @@ package nphttp2 import ( + "context" "testing" "github.com/cloudwego/kitex/internal/test" @@ -31,7 +32,7 @@ func TestStream(t *testing.T) { conn.mockSettingFrame() tr, err := newMockServerTransport(conn) test.Assert(t, err == nil, err) - s := grpc.CreateStream(1, func(i int) {}) + s := grpc.CreateStream(context.Background(), 1, func(i int) {}, "") serverConn := newServerConn(tr, s) defer serverConn.Close() diff --git a/server/invoke.go b/server/invoke.go index d26c8b2aa7..83ea70752a 100644 --- a/server/invoke.go +++ b/server/invoke.go @@ -104,6 +104,7 @@ func (s *tInvoker) Init() (err error) { if len(s.server.svcs.svcMap) == 0 { return errors.New("run: no service. Use RegisterService to set one") } + s.buildFullInvokeChain() s.initBasicRemoteOption() // for server trans info handler if len(s.server.opt.MetaHandlers) > 0 { diff --git a/server/middlewares.go b/server/middlewares.go index a8783b3030..2c9b12f363 100644 --- a/server/middlewares.go +++ b/server/middlewares.go @@ -21,6 +21,7 @@ import ( "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streaming" ) func serverTimeoutMW(initCtx context.Context) endpoint.Middleware { @@ -45,3 +46,22 @@ func serverTimeoutMW(initCtx context.Context) endpoint.Middleware { } } } + +// newCtxInjectMW must be placed at the end +func newCtxInjectMW() endpoint.Middleware { + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request, response interface{}) (err error) { + args, ok := request.(*streaming.Args) + if !ok { + return next(ctx, request, response) + } + // use contextStream to wrap the original Stream and rewrite Context() + // so that we can get this ctx by Stream.Context() + args.Stream = contextStream{ + Stream: args.Stream, + ctx: ctx, + } + return next(ctx, request, response) + } + } +} diff --git a/server/option_test.go b/server/option_test.go index ea7818a9cb..44fe27694d 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -49,23 +49,18 @@ func TestOptionDebugInfo(t *testing.T) { var opts []Option md := newMockDiagnosis() opts = append(opts, WithDiagnosisService(md)) - buildMw := 0 opts = append(opts, WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { - buildMw++ return func(ctx context.Context, req, resp interface{}) (err error) { return next(ctx, req, resp) } })) opts = append(opts, WithMiddlewareBuilder(func(ctx context.Context) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { - buildMw++ return next } })) svr := NewServer(opts...) - // check middleware build - test.Assert(t, buildMw == 2) // check probe result pp := md.ProbePairs() diff --git a/server/server.go b/server/server.go index 2556b1c81b..6e3c1fb185 100644 --- a/server/server.go +++ b/server/server.go @@ -100,8 +100,6 @@ func (s *server) init() { ds.RegisterProbeFunc(diagnosis.ChangeEventsKey, s.opt.Events.Dump) } backup.Init(s.opt.BackupOpt) - s.buildInvokeChain() - s.buildStreamInvokeChain() } func fillContext(opt *internal_server.Options) context.Context { @@ -172,9 +170,32 @@ func (s *server) initOrResetRPCInfoFunc() func(rpcinfo.RPCInfo, net.Addr) rpcinf func (s *server) buildInvokeChain() { innerHandlerEp := s.invokeHandleEndpoint() + // TODO(DMwangnima): this is a workaround to fix ctx diverge problem temporarily, + // it should be removed after the new streaming api with ctx published. + // if there is streaming method, make sure the ctxInjectMW is the last middleware + s.fixStreamCtxDiverge() s.eps = endpoint.Chain(s.mws...)(innerHandlerEp) } +// buildFullInvokeChain builds the invoke chain for ping-pong and streaming +func (s *server) buildFullInvokeChain() { + s.buildInvokeChain() + s.buildStreamInvokeChain() +} + +// fixStreamCtxDiverge is a workaround to resolve stream ctx diverge problem in server side +// when there is streaming method, add the ctxInjectMW to wrap the Stream +func (s *server) fixStreamCtxDiverge() { + for _, svc := range s.svcs.svcMap { + for _, method := range svc.svcInfo.Methods { + if method.IsStreaming() { + s.mws = append(s.mws, newCtxInjectMW()) + return + } + } + } +} + // RegisterService should not be called by users directly. func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, opts ...RegisterOption) error { s.Lock() @@ -208,6 +229,8 @@ func (s *server) Run() (err error) { s.Lock() s.isRun = true s.Unlock() + // build invoker chain here since we need to get some svc information to add MW + s.buildFullInvokeChain() if err = s.check(); err != nil { return err } diff --git a/server/server_test.go b/server/server_test.go index ede109d170..e91d3a73ce 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -48,6 +48,7 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" + "github.com/cloudwego/kitex/pkg/streaming" "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/kitex/transport" @@ -1126,3 +1127,211 @@ func withGRPCTransport() Option { o.RemoteOpt.SvrHandlerFactory = nphttp2.NewSvrTransHandlerFactory() }} } + +type mockStream struct { + streaming.Stream + ctx context.Context +} + +func (s *mockStream) Context() context.Context { + return s.ctx +} + +type streamingMethodArg struct { + methodName string + mode serviceinfo.StreamingMode + streamHdlr serviceinfo.MethodHandler +} + +func newStreamingServer(svcName string, args []streamingMethodArg, mws []endpoint.Middleware) *server { + methods := make(map[string]serviceinfo.MethodInfo) + for _, arg := range args { + methods[arg.methodName] = serviceinfo.NewMethodInfo(arg.streamHdlr, nil, nil, false, serviceinfo.WithStreamingMode(arg.mode)) + } + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: svcName, + Methods: methods, + } + svr := &server{ + svcs: &services{ + svcMap: map[string]*service{ + svcName: { + svcInfo: svcInfo, + }, + }, + }, + mws: mws, + opt: internal_server.NewOptions(nil), + } + return svr +} + +func TestStreamCtxDiverge(t *testing.T) { + testcases := []struct { + methodName string + mode serviceinfo.StreamingMode + }{ + { + methodName: "ClientStreaming", + mode: serviceinfo.StreamingClient, + }, + { + methodName: "ServerStreaming", + mode: serviceinfo.StreamingServer, + }, + { + methodName: "BidiStreaming", + mode: serviceinfo.StreamingBidirectional, + }, + } + + testKey := "key" + testVal := "val" + testService := "test" + streamHdlr := func(ctx context.Context, handler, arg, result interface{}) error { + st, ok := arg.(*streaming.Args) + test.Assert(t, ok) + val, ok := st.Stream.Context().Value(testKey).(string) + test.Assert(t, ok) + test.Assert(t, val == testVal) + return nil + } + + var args []streamingMethodArg + for _, tc := range testcases { + args = append(args, streamingMethodArg{ + methodName: tc.methodName, + mode: tc.mode, + streamHdlr: streamHdlr, + }) + } + mws := []endpoint.Middleware{ + // treat it as user middleware + // user would modify the ctx here + func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, req, resp interface{}) (err error) { + ctx = context.WithValue(ctx, testKey, testVal) + return next(ctx, req, resp) + } + }, + } + svr := newStreamingServer(testService, args, mws) + svr.buildInvokeChain() + + for _, tc := range testcases { + t.Run(tc.methodName, func(t *testing.T) { + ri := svr.initOrResetRPCInfoFunc()(nil, nil) + ink, ok := ri.Invocation().(rpcinfo.InvocationSetter) + test.Assert(t, ok) + ink.SetServiceName(testService) + ink.SetMethodName(tc.methodName) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + mock := &mockStream{ + ctx: ctx, + } + err := svr.eps(ctx, &streaming.Args{Stream: mock}, nil) + test.Assert(t, err == nil, err) + }) + } +} + +func Test_server_fixStreamCtxDiverge(t *testing.T) { + methodInfoFunc := func(mode serviceinfo.StreamingMode) serviceinfo.MethodInfo { + return serviceinfo.NewMethodInfo(nil, nil, nil, false, serviceinfo.WithStreamingMode(mode)) + } + testcases := []struct { + desc string + svcMap map[string]*service + expectInjectMW bool + }{ + { + desc: "single service without streaming method", + svcMap: map[string]*service{ + "service": { + svcInfo: &serviceinfo.ServiceInfo{ + ServiceName: "service", + Methods: map[string]serviceinfo.MethodInfo{ + "nonStreamingMethod": methodInfoFunc(serviceinfo.StreamingNone), + }, + }, + }, + }, + }, + { + desc: "single service with streaming methods", + svcMap: map[string]*service{ + "service": { + svcInfo: &serviceinfo.ServiceInfo{ + ServiceName: "service", + Methods: map[string]serviceinfo.MethodInfo{ + "ClientStreamingMethod": methodInfoFunc(serviceinfo.StreamingClient), + "ServerStreamingMethod": methodInfoFunc(serviceinfo.StreamingServer), + "BidiStreamingMethod": methodInfoFunc(serviceinfo.StreamingBidirectional), + }, + }, + }, + }, + expectInjectMW: true, + }, + { + desc: "multiple services without streaming method", + svcMap: map[string]*service{ + "service0": { + svcInfo: &serviceinfo.ServiceInfo{ + ServiceName: "service0", + Methods: map[string]serviceinfo.MethodInfo{ + "nonStreamingMethod0": methodInfoFunc(serviceinfo.StreamingNone), + }, + }, + }, + "service1": { + svcInfo: &serviceinfo.ServiceInfo{ + ServiceName: "service1", + Methods: map[string]serviceinfo.MethodInfo{ + "nonStreamingMethod1": methodInfoFunc(serviceinfo.StreamingNone), + }, + }, + }, + }, + }, + { + desc: "multiple services with streaming methods", + svcMap: map[string]*service{ + "service0": { + svcInfo: &serviceinfo.ServiceInfo{ + ServiceName: "service0", + Methods: map[string]serviceinfo.MethodInfo{ + "ClientStreamingMethod": methodInfoFunc(serviceinfo.StreamingClient), + "ServerStreamingMethod": methodInfoFunc(serviceinfo.StreamingServer), + }, + }, + }, + "service1": { + svcInfo: &serviceinfo.ServiceInfo{ + ServiceName: "service1", + Methods: map[string]serviceinfo.MethodInfo{ + "BidiStreamingMethod": methodInfoFunc(serviceinfo.StreamingBidirectional), + }, + }, + }, + }, + expectInjectMW: true, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + svr := &server{ + svcs: &services{ + svcMap: tc.svcMap, + }, + } + svr.fixStreamCtxDiverge() + if tc.expectInjectMW { + test.Assert(t, len(svr.mws) == 1) + } else { + test.Assert(t, len(svr.mws) == 0) + } + }) + } +} diff --git a/server/stream.go b/server/stream.go index 60cf863ca6..d171fe5883 100644 --- a/server/stream.go +++ b/server/stream.go @@ -44,3 +44,15 @@ func (s *server) invokeSendEndpoint() endpoint.SendEndpoint { return stream.SendMsg(req) } } + +// contextStream is responsible for solving ctx diverge in server side streaming. +// it receives the ctx from previous middlewares and the Stream that exposed to users,then rewrite +// Context() method so that users could call Stream.Context() in handler to get the processed ctx. +type contextStream struct { + streaming.Stream + ctx context.Context +} + +func (cs contextStream) Context() context.Context { + return cs.ctx +} From 85535233c8dd3a91fd2e4e3d37dd4704979fe69d Mon Sep 17 00:00:00 2001 From: YangruiEmma Date: Wed, 28 Aug 2024 20:56:15 +0800 Subject: [PATCH 58/70] feat(retry): support Mixed Retry which integrating Failure Retry and Backup Request (#1509) Co-authored-by: Li2CO3 --- client/callopt/options.go | 23 +- client/callopt/options_test.go | 12 + client/option.go | 29 +- client/option_test.go | 121 +++-- pkg/retry/backup.go | 39 +- pkg/retry/backup_retryer.go | 44 +- pkg/retry/backup_test.go | 68 +++ pkg/retry/failure.go | 109 ++++- pkg/retry/failure_retryer.go | 200 ++++---- pkg/retry/failure_test.go | 600 ++++++++++++++++++++++++ pkg/retry/mixed.go | 82 ++++ pkg/retry/mixed_retryer.go | 270 +++++++++++ pkg/retry/mixed_test.go | 625 +++++++++++++++++++++++++ pkg/retry/policy.go | 163 ++----- pkg/retry/policy_test.go | 122 +++-- pkg/retry/retryer.go | 18 +- pkg/retry/retryer_test.go | 805 +++++---------------------------- pkg/retry/util.go | 2 +- 18 files changed, 2267 insertions(+), 1065 deletions(-) create mode 100644 pkg/retry/mixed.go create mode 100644 pkg/retry/mixed_retryer.go create mode 100644 pkg/retry/mixed_test.go diff --git a/client/callopt/options.go b/client/callopt/options.go index c2ec7174b3..891b42157c 100644 --- a/client/callopt/options.go +++ b/client/callopt/options.go @@ -182,21 +182,28 @@ func WithTag(key, val string) Option { } // WithRetryPolicy sets the retry policy for a RPC call. -// Build retry.Policy with retry.BuildFailurePolicy or retry.BuildBackupRequest instead of building retry.Policy directly. +// Build retry.Policy with retry.BuildFailurePolicy or retry.BuildBackupRequest or retry.BuildMixedPolicy +// instead of building retry.Policy directly. +// // Demos are provided below: // -// demo1. call with failure retry policy, default retry error is Timeout -// `resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildFailurePolicy(retry.NewFailurePolicy())))` -// demo2. call with backup request policy -// `bp := retry.NewBackupPolicy(10) -// bp.WithMaxRetryTimes(1) -// resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildBackupRequest(bp)))` +// demo1. call with failure retry policy, default retry error is Timeout +// `resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildFailurePolicy(retry.NewFailurePolicy())))` +// demo2. call with backup request policy +// `bp := retry.NewBackupPolicy(10) +// `bp.WithMaxRetryTimes(1)` +// `resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildBackupRequest(bp)))` +// demo2. call with miexed request policy +// `bp := retry.BuildMixedPolicy(10) +// `resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildMixedPolicy(retry.NewMixedPolicy(10))))` func WithRetryPolicy(p retry.Policy) Option { return Option{f: func(o *CallOptions, di *strings.Builder) { if !p.Enable { return } - if p.Type == retry.BackupType { + if p.Type == retry.MixedType { + di.WriteString("WithMixedRetry") + } else if p.Type == retry.BackupType { di.WriteString("WithBackupRequest") } else { di.WriteString("WithFailureRetry") diff --git a/client/callopt/options_test.go b/client/callopt/options_test.go index 22baa6b6a5..d0ddfb47f3 100644 --- a/client/callopt/options_test.go +++ b/client/callopt/options_test.go @@ -100,6 +100,18 @@ func TestApply(t *testing.T) { test.Assert(t, co.RetryPolicy.Enable) test.Assert(t, co.RetryPolicy.FailurePolicy != nil) + // WithRetryPolicy + option = WithRetryPolicy(retry.BuildMixedPolicy(retry.NewMixedPolicy(10))) + _, co = Apply([]Option{option}, rpcConfig, remoteInfo, client.NewConfigLocks(), http.NewDefaultResolver()) + test.Assert(t, co.RetryPolicy.Enable) + test.Assert(t, co.RetryPolicy.MixedPolicy != nil) + + // WithRetryPolicy + option = WithRetryPolicy(retry.BuildBackupRequest(retry.NewBackupPolicy(10))) + _, co = Apply([]Option{option}, rpcConfig, remoteInfo, client.NewConfigLocks(), http.NewDefaultResolver()) + test.Assert(t, co.RetryPolicy.Enable) + test.Assert(t, co.RetryPolicy.BackupPolicy != nil) + // WithRetryPolicy pass empty struct option = WithRetryPolicy(retry.Policy{}) _, co = Apply([]Option{option}, rpcConfig, remoteInfo, client.NewConfigLocks(), http.NewDefaultResolver()) diff --git a/client/option.go b/client/option.go index 29a0161908..ce2235a33d 100644 --- a/client/option.go +++ b/client/option.go @@ -325,12 +325,13 @@ func WithFailureRetry(p *retry.FailurePolicy) Option { if p == nil { return } - di.Push(fmt.Sprintf("WithFailureRetry(%+v)", *p)) + di.Push(fmt.Sprintf("WithFailureRetry(%+v)", p)) if o.RetryMethodPolicies == nil { o.RetryMethodPolicies = make(map[string]retry.Policy) } - if o.RetryMethodPolicies[retry.Wildcard].BackupPolicy != nil { - panic("BackupPolicy has been setup, cannot support Failure Retry at same time") + if o.RetryMethodPolicies[retry.Wildcard].MixedPolicy != nil || + o.RetryMethodPolicies[retry.Wildcard].BackupPolicy != nil { + panic("MixedPolicy or BackupPolicy has been setup, cannot support Failure Retry at same time") } o.RetryMethodPolicies[retry.Wildcard] = retry.BuildFailurePolicy(p) }} @@ -342,17 +343,33 @@ func WithBackupRequest(p *retry.BackupPolicy) Option { if p == nil { return } - di.Push(fmt.Sprintf("WithBackupRequest(%+v)", *p)) + di.Push(fmt.Sprintf("WithBackupRequest(%+v)", p)) if o.RetryMethodPolicies == nil { o.RetryMethodPolicies = make(map[string]retry.Policy) } - if o.RetryMethodPolicies[retry.Wildcard].FailurePolicy != nil { - panic("BackupPolicy has been setup, cannot support Failure Retry at same time") + if o.RetryMethodPolicies[retry.Wildcard].MixedPolicy != nil || + o.RetryMethodPolicies[retry.Wildcard].FailurePolicy != nil { + panic("MixedPolicy or BackupPolicy has been setup, cannot support Failure Retry at same time") } o.RetryMethodPolicies[retry.Wildcard] = retry.BuildBackupRequest(p) }} } +// WithMixedRetry sets the mixed retry policy for client, it will take effect for all methods. +func WithMixedRetry(p *retry.MixedPolicy) Option { + return Option{F: func(o *client.Options, di *utils.Slice) { + if p == nil { + return + } + di.Push(fmt.Sprintf("WithMixedRetry(%+v)", p)) + if o.RetryMethodPolicies == nil { + o.RetryMethodPolicies = make(map[string]retry.Policy) + } + // no need to check if BackupPolicy or FailurePolicy are been setup, just let mixed retry replace it + o.RetryMethodPolicies[retry.Wildcard] = retry.BuildMixedPolicy(p) + }} +} + // WithRetryMethodPolicies sets the retry policy for method. // The priority is higher than WithFailureRetry and WithBackupRequest. Only the methods which are not included by // this config will use the policy that is configured by WithFailureRetry or WithBackupRequest . diff --git a/client/option_test.go b/client/option_test.go index 883f67481c..4e17645f19 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -19,6 +19,7 @@ package client import ( "context" "crypto/tls" + "errors" "fmt" "reflect" "testing" @@ -53,46 +54,76 @@ import ( ) func TestRetryOptionDebugInfo(t *testing.T) { - fp := retry.NewFailurePolicy() - fp.WithDDLStop() - expectPolicyStr := "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:false DDLStop:true " + - "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:none CfgItems:map[]} RetrySameNode:false ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" - policyStr := fmt.Sprintf("WithFailureRetry(%+v)", fp) - test.Assert(t, policyStr == expectPolicyStr, policyStr) - - fp.WithFixedBackOff(10) - expectPolicyStr = "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:false DDLStop:true " + - "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:fixed CfgItems:map[fix_ms:10]} RetrySameNode:false ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" - policyStr = fmt.Sprintf("WithFailureRetry(%+v)", fp) - test.Assert(t, policyStr == expectPolicyStr, policyStr) - - fp.WithRandomBackOff(10, 20) - fp.DisableChainRetryStop() - expectPolicyStr = "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true DDLStop:true " + - "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:random CfgItems:map[max_ms:20 min_ms:10]} RetrySameNode:false ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" - policyStr = fmt.Sprintf("WithFailureRetry(%+v)", fp) - test.Assert(t, policyStr == expectPolicyStr, policyStr) - - fp.WithRetrySameNode() - expectPolicyStr = "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true DDLStop:true " + - "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:random CfgItems:map[max_ms:20 min_ms:10]} RetrySameNode:true ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" - policyStr = fmt.Sprintf("WithFailureRetry(%+v)", fp) - test.Assert(t, policyStr == expectPolicyStr, policyStr) - - fp.WithSpecifiedResultRetry(&retry.ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { - return false - }}) - expectPolicyStr = "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true DDLStop:true " + - "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:random CfgItems:map[max_ms:20 min_ms:10]} RetrySameNode:true ShouldResultRetry:{ErrorRetry:true, RespRetry:false}})" - policyStr = fmt.Sprintf("WithFailureRetry(%+v)", fp) - test.Assert(t, policyStr == expectPolicyStr, policyStr) + t.Run("FailurePolicy", func(t *testing.T) { + fp := retry.NewFailurePolicy() + fp.WithDDLStop() + expectPolicyStr := "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:false DDLStop:true " + + "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:none CfgItems:map[]} RetrySameNode:false ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" + opt := WithFailureRetry(fp) + err := checkOneOptionDebugInfo(t, opt, expectPolicyStr) + test.Assert(t, err == nil, err) + + fp.WithFixedBackOff(10) + expectPolicyStr = "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:false DDLStop:true " + + "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:fixed CfgItems:map[fix_ms:10]} RetrySameNode:false ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" + opt = WithFailureRetry(fp) + err = checkOneOptionDebugInfo(t, opt, expectPolicyStr) + test.Assert(t, err == nil, err) + + fp.WithRandomBackOff(10, 20) + fp.DisableChainRetryStop() + expectPolicyStr = "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true DDLStop:true " + + "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:random CfgItems:map[max_ms:20 min_ms:10]} RetrySameNode:false ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" + opt = WithFailureRetry(fp) + err = checkOneOptionDebugInfo(t, opt, expectPolicyStr) + test.Assert(t, err == nil, err) + + fp.WithRetrySameNode() + expectPolicyStr = "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true DDLStop:true " + + "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:random CfgItems:map[max_ms:20 min_ms:10]} RetrySameNode:true ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" + opt = WithFailureRetry(fp) + err = checkOneOptionDebugInfo(t, opt, expectPolicyStr) + test.Assert(t, err == nil, err) + + fp.WithSpecifiedResultRetry(&retry.ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + return false + }}) + expectPolicyStr = "WithFailureRetry({StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true DDLStop:true " + + "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:random CfgItems:map[max_ms:20 min_ms:10]} RetrySameNode:true ShouldResultRetry:{ErrorRetry:true, RespRetry:false}})" + opt = WithFailureRetry(fp) + err = checkOneOptionDebugInfo(t, opt, expectPolicyStr) + test.Assert(t, err == nil, err) + }) - bp := retry.NewBackupPolicy(20) - expectPolicyStr = "WithBackupRequest({RetryDelayMS:20 StopPolicy:{MaxRetryTimes:1 MaxDurationMS:0 DisableChainStop:false " + - "DDLStop:false CBPolicy:{ErrorRate:0.1}} RetrySameNode:false})" - policyStr = fmt.Sprintf("WithBackupRequest(%+v)", bp) - test.Assert(t, policyStr == expectPolicyStr, policyStr) - WithBackupRequest(bp) + t.Run("FailurePolicy", func(t *testing.T) { + bp := retry.NewBackupPolicy(20) + expectPolicyStr := "WithBackupRequest({RetryDelayMS:20 StopPolicy:{MaxRetryTimes:1 MaxDurationMS:0 DisableChainStop:false " + + "DDLStop:false CBPolicy:{ErrorRate:0.1}} RetrySameNode:false})" + opt := WithBackupRequest(bp) + err := checkOneOptionDebugInfo(t, opt, expectPolicyStr) + test.Assert(t, err == nil, err) + }) + + t.Run("MixedPolicy", func(t *testing.T) { + mp := retry.NewMixedPolicy(100) + mp.WithDDLStop() + expectPolicyStr := "WithMixedRetry({RetryDelayMS:100 StopPolicy:{MaxRetryTimes:1 MaxDurationMS:0 DisableChainStop:false " + + "DDLStop:true CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:none CfgItems:map[]} RetrySameNode:false " + + "ShouldResultRetry:{ErrorRetry:false, RespRetry:false}})" + opt := WithMixedRetry(mp) + err := checkOneOptionDebugInfo(t, opt, expectPolicyStr) + test.Assert(t, err == nil, err) + + mp.WithSpecifiedResultRetry(&retry.ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + return false + }}) + expectPolicyStr = "WithMixedRetry({RetryDelayMS:100 StopPolicy:{MaxRetryTimes:1 MaxDurationMS:0 DisableChainStop:false " + + "DDLStop:true CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:none CfgItems:map[]} RetrySameNode:false " + + "ShouldResultRetry:{ErrorRetry:true, RespRetry:false}})" + opt = WithMixedRetry(mp) + err = checkOneOptionDebugInfo(t, opt, expectPolicyStr) + test.Assert(t, err == nil, err) + }) } func TestRetryOption(t *testing.T) { @@ -708,3 +739,15 @@ func TestWithGRPCTLSConfig(t *testing.T) { opts := client.NewOptions([]client.Option{WithGRPCTLSConfig(cfg)}) test.Assert(t, opts.GRPCConnectOpts != nil) } + +func checkOneOptionDebugInfo(t *testing.T, opt Option, expectStr string) error { + o := &Options{} + o.Apply([]Option{opt}) + if len(o.DebugInfo) != 1 { + return errors.New("length of DebugInfo is unexpected") + } + if o.DebugInfo[0] != expectStr { + return fmt.Errorf("DebugInfo not match with expect str:\n debugInfo=%s", o.DebugInfo[0]) + } + return nil +} diff --git a/pkg/retry/backup.go b/pkg/retry/backup.go index fa3bda5da1..b189f38293 100644 --- a/pkg/retry/backup.go +++ b/pkg/retry/backup.go @@ -20,7 +20,10 @@ import ( "fmt" ) -const maxBackupRetryTimes = 2 +const ( + maxBackupRetryTimes = 2 + defaultBackupRetryTimes = 1 +) // NewBackupPolicy init backup request policy // the param delayMS is suggested to set as TP99 @@ -31,7 +34,7 @@ func NewBackupPolicy(delayMS uint32) *BackupPolicy { p := &BackupPolicy{ RetryDelayMS: delayMS, StopPolicy: StopPolicy{ - MaxRetryTimes: 1, + MaxRetryTimes: defaultBackupRetryTimes, DisableChainStop: false, CBPolicy: CBPolicy{ ErrorRate: defaultCBErrRate, @@ -71,3 +74,35 @@ func (p *BackupPolicy) WithRetrySameNode() { func (p *BackupPolicy) String() string { return fmt.Sprintf("{RetryDelayMS:%+v StopPolicy:%+v RetrySameNode:%+v}", p.RetryDelayMS, p.StopPolicy, p.RetrySameNode) } + +// Equals to check if BackupPolicy is equal +func (p *BackupPolicy) Equals(np *BackupPolicy) bool { + if p == nil { + return np == nil + } + if np == nil { + return false + } + if p.RetryDelayMS != np.RetryDelayMS { + return false + } + if p.StopPolicy != np.StopPolicy { + return false + } + if p.RetrySameNode != np.RetrySameNode { + return false + } + + return true +} + +func (p *BackupPolicy) DeepCopy() *BackupPolicy { + if p == nil { + return nil + } + return &BackupPolicy{ + RetryDelayMS: p.RetryDelayMS, + StopPolicy: p.StopPolicy, // not a pointer, will copy the value here + RetrySameNode: p.RetrySameNode, + } +} diff --git a/pkg/retry/backup_retryer.go b/pkg/retry/backup_retryer.go index 1ff8ff8e3c..21980f57ca 100644 --- a/pkg/retry/backup_retryer.go +++ b/pkg/retry/backup_retryer.go @@ -30,7 +30,6 @@ import ( "github.com/cloudwego/kitex/pkg/circuitbreak" "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/kerrors" - "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/utils" ) @@ -48,16 +47,17 @@ func newBackupRetryer(policy Policy, cbC *cbContainer) (Retryer, error) { type backupRetryer struct { enable bool + retryDelay time.Duration policy *BackupPolicy cbContainer *cbContainer - retryDelay time.Duration sync.RWMutex errMsg string } type resultWrapper struct { - ri rpcinfo.RPCInfo - err error + ri rpcinfo.RPCInfo + resp interface{} + err error } // ShouldRetry implements the Retryer interface. @@ -92,12 +92,13 @@ func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpc retryTimes := r.policy.StopPolicy.MaxRetryTimes retryDelay := r.retryDelay r.RUnlock() + var callTimes int32 = 0 var callCosts utils.StringBuilder callCosts.RawStringBuilder().Grow(32) var recordCostDoing int32 = 0 var abort int32 = 0 - finishedCount := 0 + finishedErrCount := 0 // notice: buff num of chan is very important here, it cannot less than call times, or the below chan receive will block done := make(chan *resultWrapper, retryTimes+1) cbKey, _ := r.cbContainer.cbCtl.GetKey(ctx, req) @@ -126,7 +127,7 @@ func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpc if panicInfo := recover(); panicInfo != nil { e = panicToErr(ctx, panicInfo, firstRI) } - done <- &resultWrapper{cRI, e} + done <- &resultWrapper{ri: cRI, err: e} }() ct := atomic.AddInt32(&callTimes, 1) callStart := time.Now() @@ -152,7 +153,7 @@ func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpc // There will be only one request (goroutine) pass the `checkRPCState`, others will skip decoding // and return `ErrRPCFinish`, to avoid concurrent write to response and save the cost of decoding. // We can safely ignore this error and wait for the response of the passed goroutine. - if finishedCount++; finishedCount >= retryTimes+1 { + if finishedErrCount++; finishedErrCount >= retryTimes+1 { // But if all requests return this error, it must be a bug, preventive panic to avoid dead loop panic(errUnexpectedFinish) } @@ -178,29 +179,21 @@ func (r *backupRetryer) UpdatePolicy(rp Policy) (err error) { r.Unlock() return nil } - var errMsg string if rp.BackupPolicy == nil || rp.Type != BackupType { - errMsg = "BackupPolicy is nil or retry type not match, cannot do update in backupRetryer" - err = errors.New(errMsg) + err = errors.New("BackupPolicy is nil or retry type not match, cannot do update in backupRetryer") } - if errMsg == "" && (rp.BackupPolicy.RetryDelayMS == 0 || rp.BackupPolicy.StopPolicy.MaxRetryTimes < 0 || - rp.BackupPolicy.StopPolicy.MaxRetryTimes > maxBackupRetryTimes) { - errMsg = "invalid backup request delay duration or retryTimes" - err = errors.New(errMsg) + if err == nil && rp.BackupPolicy.RetryDelayMS == 0 { + err = errors.New("invalid retry delay duration in backupRetryer") } - if errMsg == "" { - if e := checkCBErrorRate(&rp.BackupPolicy.StopPolicy.CBPolicy); e != nil { - rp.BackupPolicy.StopPolicy.CBPolicy.ErrorRate = defaultCBErrRate - errMsg = fmt.Sprintf("backupRetryer %s, use default %0.2f", e.Error(), defaultCBErrRate) - klog.Warnf(errMsg) - } + if err == nil { + err = checkStopPolicy(&rp.BackupPolicy.StopPolicy, maxBackupRetryTimes, r) } r.Lock() defer r.Unlock() r.enable = rp.Enable if err != nil { - r.errMsg = errMsg + r.errMsg = err.Error() return err } r.policy = rp.BackupPolicy @@ -220,14 +213,11 @@ func (r *backupRetryer) AppendErrMsgIfNeeded(ctx context.Context, err error, ri func (r *backupRetryer) Dump() map[string]interface{} { r.RLock() defer r.RUnlock() + dm := map[string]interface{}{"enable": r.enable, "backup_request": r.policy} if r.errMsg != "" { - return map[string]interface{}{ - "enable": r.enable, - "backupRequest": r.policy, - "errMsg": r.errMsg, - } + dm["err_msg"] = r.errMsg } - return map[string]interface{}{"enable": r.enable, "backupRequest": r.policy} + return dm } // Type implements the Retryer interface. diff --git a/pkg/retry/backup_test.go b/pkg/retry/backup_test.go index 4193f2de24..ca09999883 100644 --- a/pkg/retry/backup_test.go +++ b/pkg/retry/backup_test.go @@ -17,9 +17,15 @@ package retry import ( + "context" + "sync/atomic" "testing" + "time" "github.com/cloudwego/thriftgo/pkg/test" + + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/stats" ) // test BackupPolicy string @@ -31,3 +37,65 @@ func TestBackupPolicy_String(t *testing.T) { "DisableChainStop:false DDLStop:false CBPolicy:{ErrorRate:0.3}} RetrySameNode:true}" test.Assert(t, r.String() == msg) } + +// test BackupPolicy call while rpcTime > delayTime +func TestBackupPolicyCall(t *testing.T) { + ctx := context.Background() + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{Wildcard: { + Enable: true, + Type: 1, + BackupPolicy: &BackupPolicy{ + RetryDelayMS: 30, + StopPolicy: StopPolicy{ + MaxRetryTimes: 2, + DisableChainStop: false, + CBPolicy: CBPolicy{ + ErrorRate: defaultCBErrRate, + }, + }, + }, + }}, nil) + test.Assert(t, err == nil, err) + + callTimes := int32(0) + firstRI := genRPCInfo() + secondRI := genRPCInfoWithRemoteTag(remoteTags) + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, firstRI) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + atomic.AddInt32(&callTimes, 1) + if atomic.LoadInt32(&callTimes) == 1 { + // mock timeout for the first request and get the response of the backup request. + time.Sleep(time.Millisecond * 50) + return firstRI, nil, nil + } + return secondRI, nil, nil + }, firstRI, nil) + test.Assert(t, err == nil, err) + test.Assert(t, atomic.LoadInt32(&callTimes) == 2) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) +} + +func TestBackupRetryWithRPCInfo(t *testing.T) { + // backup retry + ctx := context.Background() + rc := NewRetryContainer() + + ri := genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + rpcinfo.Record(ctx, ri, stats.RPCStart, nil) + + // call with retry policy + var callTimes int32 + policy := BuildBackupRequest(NewBackupPolicy(10)) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &policy, retryCall(&callTimes, ri, true), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + test.Assert(t, ri.Stats().GetEvent(stats.RPCStart).Status() == stats.StatusInfo) + test.Assert(t, ri.Stats().GetEvent(stats.RPCFinish).Status() == stats.StatusInfo) + test.Assert(t, ri.To().Address().String() == "10.20.30.40:8888") + test.Assert(t, atomic.LoadInt32(&callTimes) == 2) +} diff --git a/pkg/retry/failure.go b/pkg/retry/failure.go index ffb0fdf2ac..c20e15bbcc 100644 --- a/pkg/retry/failure.go +++ b/pkg/retry/failure.go @@ -28,6 +28,13 @@ import ( const maxFailureRetryTimes = 5 +// AllErrorRetry is common choice for ShouldResultRetry. +func AllErrorRetry() *ShouldResultRetry { + return &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + return err != nil + }} +} + // NewFailurePolicy init default failure retry policy func NewFailurePolicy() *FailurePolicy { p := &FailurePolicy{ @@ -115,14 +122,104 @@ func (p *FailurePolicy) WithSpecifiedResultRetry(rr *ShouldResultRetry) { // String prints human readable information. func (p *FailurePolicy) String() string { return fmt.Sprintf("{StopPolicy:%+v BackOffPolicy:%+v RetrySameNode:%+v "+ - "ShouldResultRetry:{ErrorRetry:%t, RespRetry:%t}}", p.StopPolicy, p.BackOffPolicy, p.RetrySameNode, p.IsErrorRetryNonNil(), p.IsRespRetryNonNil()) + "ShouldResultRetry:{ErrorRetry:%t, RespRetry:%t}}", p.StopPolicy, p.BackOffPolicy, p.RetrySameNode, p.isErrorRetryWithCtxNonNil(), p.isRespRetryWithCtxNonNil()) } -// AllErrorRetry is common choice for ShouldResultRetry. -func AllErrorRetry() *ShouldResultRetry { - return &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - return err != nil - }} +// Equals to check if FailurePolicy is equal +func (p *FailurePolicy) Equals(np *FailurePolicy) bool { + if p == nil { + return np == nil + } + if np == nil { + return false + } + if p.StopPolicy != np.StopPolicy { + return false + } + if !p.BackOffPolicy.Equals(np.BackOffPolicy) { + return false + } + if p.RetrySameNode != np.RetrySameNode { + return false + } + if p.Extra != np.Extra { + return false + } + // don't need to check `ShouldResultRetry`, ShouldResultRetry is only setup by option + // in remote config case will always return false if check it + return true +} + +func (p *FailurePolicy) DeepCopy() *FailurePolicy { + if p == nil { + return nil + } + return &FailurePolicy{ + StopPolicy: p.StopPolicy, + BackOffPolicy: p.BackOffPolicy.DeepCopy(), + RetrySameNode: p.RetrySameNode, + ShouldResultRetry: p.ShouldResultRetry, // don't need DeepCopy + Extra: p.Extra, + } +} + +// isRespRetryWithCtxNonNil is used to check if RespRetryWithCtx is nil. +func (p *FailurePolicy) isRespRetryWithCtxNonNil() bool { + return p.ShouldResultRetry != nil && p.ShouldResultRetry.RespRetryWithCtx != nil +} + +// isErrorRetryWithCtxNonNil is used to check if ErrorRetryWithCtx is nil +func (p *FailurePolicy) isErrorRetryWithCtxNonNil() bool { + return p.ShouldResultRetry != nil && p.ShouldResultRetry.ErrorRetryWithCtx != nil +} + +// isRespRetryNonNil is used to check if RespRetry is nil. +// Deprecated: please use isRespRetryWithCtxNonNil instead of isRespRetryNonNil. +func (p *FailurePolicy) isRespRetryNonNil() bool { + return p.ShouldResultRetry != nil && p.ShouldResultRetry.RespRetry != nil +} + +// isErrorRetryNonNil is used to check if ErrorRetry is nil. +// Deprecated: please use IsErrorRetryWithCtxNonNil instead of isErrorRetryNonNil. +func (p *FailurePolicy) isErrorRetryNonNil() bool { + return p.ShouldResultRetry != nil && p.ShouldResultRetry.ErrorRetry != nil +} + +// isRetryForTimeout is used to check if timeout error need to retry +func (p *FailurePolicy) isRetryForTimeout() bool { + return p.ShouldResultRetry == nil || !p.ShouldResultRetry.NotRetryForTimeout +} + +// isRespRetry is used to check if the resp need to do retry. +func (p *FailurePolicy) isRespRetry(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { + // note: actually, it is better to check IsRespRetry to ignore the bad cases, + // but IsRespRetry is a deprecated definition and here will be executed for every call, depends on ConvertResultRetry to ensure the compatibility + return p.isRespRetryWithCtxNonNil() && p.ShouldResultRetry.RespRetryWithCtx(ctx, resp, ri) +} + +// isErrorRetry is used to check if the error need to do retry. +func (p *FailurePolicy) isErrorRetry(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + // note: actually, it is better to check IsErrorRetry to ignore the bad cases, + // but IsErrorRetry is a deprecated definition and here will be executed for every call, depends on ConvertResultRetry to ensure the compatibility + return p.isErrorRetryWithCtxNonNil() && p.ShouldResultRetry.ErrorRetryWithCtx(ctx, err, ri) +} + +// convertResultRetry is used to convert 'ErrorRetry and RespRetry' to 'ErrorRetryWithCtx and RespRetryWithCtx' +func (p *FailurePolicy) convertResultRetry() { + if p == nil || p.ShouldResultRetry == nil { + return + } + rr := p.ShouldResultRetry + if rr.ErrorRetry != nil && rr.ErrorRetryWithCtx == nil { + rr.ErrorRetryWithCtx = func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + return rr.ErrorRetry(err, ri) + } + } + if rr.RespRetry != nil && rr.RespRetryWithCtx == nil { + rr.RespRetryWithCtx = func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { + return rr.RespRetry(resp, ri) + } + } } // BackOff is the interface of back off implements diff --git a/pkg/retry/failure_retryer.go b/pkg/retry/failure_retryer.go index 11686eb6c4..694855f0d0 100644 --- a/pkg/retry/failure_retryer.go +++ b/pkg/retry/failure_retryer.go @@ -33,7 +33,7 @@ import ( ) func newFailureRetryer(policy Policy, r *ShouldResultRetry, cbC *cbContainer) (Retryer, error) { - fr := &failureRetryer{specifiedResultRetry: r, cbContainer: cbC} + fr := &failureRetryer{failureCommon: &failureCommon{specifiedResultRetry: r, cbContainer: cbC}} if err := fr.UpdatePolicy(policy); err != nil { return nil, fmt.Errorf("newfailureRetryer failed, err=%w", err) } @@ -41,30 +41,22 @@ func newFailureRetryer(policy Policy, r *ShouldResultRetry, cbC *cbContainer) (R } type failureRetryer struct { - enable bool - policy *FailurePolicy - backOff BackOff - cbContainer *cbContainer - specifiedResultRetry *ShouldResultRetry + enable bool + *failureCommon + policy *FailurePolicy sync.RWMutex errMsg string } -// ShouldRetry implements the Retryer interface. +// ShouldRetry to check if retry request can be called, it is checked in retryer.Do. +// If not satisfy will return the reason message func (r *failureRetryer) ShouldRetry(ctx context.Context, err error, callTimes int, req interface{}, cbKey string) (string, bool) { r.RLock() defer r.RUnlock() if !r.enable { return "", false } - if stop, msg := circuitBreakerStop(ctx, r.policy.StopPolicy, r.cbContainer, req, cbKey); stop { - return msg, false - } - if stop, msg := ddlStop(ctx, r.policy.StopPolicy); stop { - return msg, false - } - r.backOff.Wait(callTimes) - return "", true + return r.shouldRetry(ctx, callTimes, req, cbKey, r.policy) } // AllowRetry implements the Retryer interface. @@ -109,8 +101,8 @@ func (r *failureRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rp if i == 0 { callStart = startTime } else if i > 0 { - if maxDuration > 0 && time.Since(startTime) > maxDuration { - err = makeRetryErr(ctx, "exceed max duration", callTimes) + if ret, e := isExceedMaxDuration(ctx, startTime, maxDuration, callTimes); ret { + err = e break } if msg, ok := r.ShouldRetry(ctx, err, i, req, cbKey); !ok { @@ -137,20 +129,8 @@ func (r *failureRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rp if !r.cbContainer.enablePercentageLimit && r.cbContainer.cbStat { circuitbreak.RecordStat(ctx, req, nil, err, cbKey, r.cbContainer.cbCtl, r.cbContainer.cbPanel) } - if err == nil { - if r.policy.IsRespRetry(ctx, resp, cRI) { - // user specified resp to do retry - continue - } + if !r.isRetryResult(ctx, cRI, resp, err, r.policy) { break - } else { - if i == retryTimes { - // stop retry then wrap error - err = kerrors.ErrRetry.WithCause(err) - } else if !r.isRetryErr(ctx, err, cRI) { - // not timeout or user specified error won't do retry - break - } } } recordRetryInfo(cRI, callTimes, callCosts.String()) @@ -168,32 +148,21 @@ func (r *failureRetryer) UpdatePolicy(rp Policy) (err error) { r.Unlock() return nil } - var errMsg string if rp.FailurePolicy == nil || rp.Type != FailureType { - errMsg = "FailurePolicy is nil or retry type not match, cannot do update in failureRetryer" - err = errors.New(errMsg) + err = errors.New("FailurePolicy is nil or retry type not match, cannot do update in failureRetryer") } - rt := rp.FailurePolicy.StopPolicy.MaxRetryTimes - if errMsg == "" && (rt < 0 || rt > maxFailureRetryTimes) { - errMsg = fmt.Sprintf("invalid failure MaxRetryTimes[%d]", rt) - err = errors.New(errMsg) - } - if errMsg == "" { - if e := checkCBErrorRate(&rp.FailurePolicy.StopPolicy.CBPolicy); e != nil { - rp.FailurePolicy.StopPolicy.CBPolicy.ErrorRate = defaultCBErrRate - errMsg = fmt.Sprintf("failureRetryer %s, use default %0.2f", e.Error(), defaultCBErrRate) - klog.Warnf(errMsg) - } + if err == nil { + err = checkStopPolicy(&rp.FailurePolicy.StopPolicy, maxFailureRetryTimes, r) } r.Lock() defer r.Unlock() r.enable = rp.Enable if err != nil { - r.errMsg = errMsg + r.errMsg = err.Error() return err } r.policy = rp.FailurePolicy - r.setSpecifiedResultRetryIfNeeded(r.specifiedResultRetry) + r.setSpecifiedResultRetryIfNeeded(r.specifiedResultRetry, r.policy) if bo, e := initBackOff(rp.FailurePolicy.BackOffPolicy); e != nil { r.errMsg = fmt.Sprintf("failureRetryer update BackOffPolicy failed, err=%s", e.Error()) klog.Warnf(r.errMsg) @@ -205,7 +174,7 @@ func (r *failureRetryer) UpdatePolicy(rp Policy) (err error) { // AppendErrMsgIfNeeded implements the Retryer interface. func (r *failureRetryer) AppendErrMsgIfNeeded(ctx context.Context, err error, ri rpcinfo.RPCInfo, msg string) { - if r.isRetryErr(ctx, err, ri) { + if r.isRetryErr(ctx, err, ri, r.policy) { // Add additional reason when retry is not applied. appendErrMsg(err, msg) } @@ -216,7 +185,53 @@ func (r *failureRetryer) Prepare(ctx context.Context, prevRI, retryRI rpcinfo.RP handleRetryInstance(r.policy.RetrySameNode, prevRI, retryRI) } -func (r *failureRetryer) isRetryErr(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { +// Type implements the Retryer interface. +func (r *failureRetryer) Type() Type { + return FailureType +} + +// Dump implements the Retryer interface. +func (r *failureRetryer) Dump() map[string]interface{} { + r.RLock() + defer r.RUnlock() + dm := make(map[string]interface{}) + dm["enable"] = r.enable + dm["failure_retry"] = r.policy + if r.policy != nil { + dm["specified_result_retry"] = r.dumpSpecifiedResultRetry(*r.policy) + } + if r.errMsg != "" { + dm["err_msg"] = r.errMsg + } + return dm +} + +type failureCommon struct { + backOff BackOff + specifiedResultRetry *ShouldResultRetry + cbContainer *cbContainer +} + +func (f *failureCommon) setSpecifiedResultRetryIfNeeded(rr *ShouldResultRetry, fp *FailurePolicy) { + if rr != nil { + // save the object specified by client.WithSpecifiedResultRetry(..) + f.specifiedResultRetry = rr + } + if fp != nil { + if f.specifiedResultRetry != nil { + // The priority of client.WithSpecifiedResultRetry(..) is higher, so always update it + // NOTE: client.WithSpecifiedResultRetry(..) will always reject a nil object + fp.ShouldResultRetry = f.specifiedResultRetry + } + + // even though rr passed from this func is nil, + // the Policy may also have ShouldResultRetry from client.WithFailureRetry or callopt.WithRetryPolicy. + // convertResultRetry is used to convert 'ErrorRetry and RespRetry' to 'ErrorRetryWithCtx and RespRetryWithCtx' + fp.convertResultRetry() + } +} + +func (r *failureCommon) isRetryErr(ctx context.Context, err error, ri rpcinfo.RPCInfo, fp *FailurePolicy) bool { if err == nil { return false } @@ -225,15 +240,53 @@ func (r *failureRetryer) isRetryErr(ctx context.Context, err error, ri rpcinfo.R // But CircuitBreak has been checked in ShouldRetry, it doesn't need to filter ServiceCircuitBreak. // If there are some other specified errors that cannot be retried, it should be filtered here. - if r.policy.IsRetryForTimeout() && kerrors.IsTimeoutError(err) { + if fp.isRetryForTimeout() && kerrors.IsTimeoutError(err) { + return true + } + if fp.isErrorRetry(ctx, err, ri) { return true } - if r.policy.IsErrorRetry(ctx, err, ri) { + return false +} + +func (r *failureCommon) shouldRetry(ctx context.Context, callTimes int, req interface{}, cbKey string, fp *FailurePolicy) (string, bool) { + if stop, msg := circuitBreakerStop(ctx, fp.StopPolicy, r.cbContainer, req, cbKey); stop { + return msg, false + } + if stop, msg := ddlStop(ctx, fp.StopPolicy); stop { + return msg, false + } + r.backOff.Wait(callTimes) + return "", true +} + +// isRetryResult to check if the result need to do retry +// Version Change Note: +// < v0.11.0 if the last result still failed, then wrap the error as RetryErr +// >= v0.11.0 don't wrap RetryErr. +// Consideration: Wrap as RetryErr will be reflected as a retry error from monitoring, which is not friendly for troubleshooting +func (r *failureCommon) isRetryResult(ctx context.Context, cRI rpcinfo.RPCInfo, resp interface{}, err error, fp *FailurePolicy) bool { + if err == nil { + if fp.isRespRetry(ctx, resp, cRI) { + // user specified resp to do retry + return true + } + } else if r.isRetryErr(ctx, err, cRI, fp) { return true } return false } +func (r *failureCommon) dumpSpecifiedResultRetry(fp FailurePolicy) map[string]bool { + return map[string]bool{ + "error_retry": fp.isErrorRetryWithCtxNonNil(), + "resp_retry": fp.isRespRetryWithCtxNonNil(), + // keep it for some versions to confirm the correctness when troubleshooting + "old_error_retry": fp.isErrorRetryNonNil(), + "old_resp_retry": fp.isRespRetryNonNil(), + } +} + func initBackOff(policy *BackOffPolicy) (bo BackOff, err error) { bo = NoneBackOff if policy == nil { @@ -269,48 +322,9 @@ func initBackOff(policy *BackOffPolicy) (bo BackOff, err error) { return } -// Type implements the Retryer interface. -func (r *failureRetryer) Type() Type { - return FailureType -} - -// Dump implements the Retryer interface. -func (r *failureRetryer) Dump() map[string]interface{} { - r.RLock() - defer r.RUnlock() - dm := make(map[string]interface{}) - dm["enable"] = r.enable - dm["failure_retry"] = r.policy - if r.policy != nil { - dm["specified_result_retry"] = map[string]bool{ - "error_retry": r.policy.IsErrorRetryWithCtxNonNil(), - "resp_retry": r.policy.IsRespRetryWithCtxNonNil(), - // keep it for some versions to confirm the correctness when troubleshooting - "old_error_retry": r.policy.IsErrorRetryNonNil(), - "old_resp_retry": r.policy.IsRespRetryNonNil(), - } - } - if r.errMsg != "" { - dm["errMsg"] = r.errMsg - } - return dm -} - -func (r *failureRetryer) setSpecifiedResultRetryIfNeeded(rr *ShouldResultRetry) { - if rr != nil { - // save the object specified by client.WithSpecifiedResultRetry(..) - r.specifiedResultRetry = rr - } - if r.policy != nil { - if r.specifiedResultRetry != nil { - // The priority of client.WithSpecifiedResultRetry(..) is higher, so always update it - // NOTE: client.WithSpecifiedResultRetry(..) will always reject a nil object - r.policy.ShouldResultRetry = r.specifiedResultRetry - } - - // even though rr passed from this func is nil, - // the Policy may also have ShouldResultRetry from client.WithFailureRetry or callopt.WithRetryPolicy. - // convertResultRetry is used to convert 'ErrorRetry and RespRetry' to 'ErrorRetryWithCtx and RespRetryWithCtx' - r.policy.ConvertResultRetry() +func isExceedMaxDuration(ctx context.Context, start time.Time, maxDuration time.Duration, callTimes int32) (bool, error) { + if maxDuration > 0 && time.Since(start) > maxDuration { + return true, makeRetryErr(ctx, fmt.Sprintf("exceed max duration[%v]", maxDuration), callTimes) } + return false, nil } diff --git a/pkg/retry/failure_test.go b/pkg/retry/failure_test.go index fe1bf87bd1..376342442e 100644 --- a/pkg/retry/failure_test.go +++ b/pkg/retry/failure_test.go @@ -17,10 +17,16 @@ package retry import ( + "context" + "sync/atomic" "testing" "time" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/stats" ) func BenchmarkRandomBackOff_Wait(b *testing.B) { @@ -91,3 +97,597 @@ func TestNoneBackOff_String(t *testing.T) { msg := "NoneBackOff" test.Assert(t, bk.String() == msg) } + +// test FailurePolicy call +func TestFailurePolicyCall(t *testing.T) { + // call while rpc timeout + ctx := context.Background() + rc := NewRetryContainer() + failurePolicy := NewFailurePolicy() + failurePolicy.BackOffPolicy.BackOffType = FixedBackOffType + failurePolicy.BackOffPolicy.CfgItems = map[BackOffCfgKey]float64{ + FixMSBackOffCfgKey: 100.0, + } + failurePolicy.StopPolicy.MaxDurationMS = 100 + err := rc.Init(map[string]Policy{Wildcard: { + Enable: true, + Type: 0, + FailurePolicy: failurePolicy, + }}, nil) + test.Assert(t, err == nil, err) + ri := genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + return ri, nil, kerrors.ErrRPCTimeout + }, ri, nil) + test.Assert(t, err != nil, err) + test.Assert(t, !ok) + + // call normal + failurePolicy.StopPolicy.MaxDurationMS = 0 + err = rc.Init(map[string]Policy{Wildcard: { + Enable: true, + Type: 0, + FailurePolicy: failurePolicy, + }}, nil) + test.Assert(t, err == nil, err) + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + return ri, nil, nil + }, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, ok) +} + +// test retry with one time policy +func TestRetryWithOneTimePolicy(t *testing.T) { + // call while rpc timeout and exceed MaxDurationMS cause BackOffPolicy is wait fix 100ms, it is invalid config + failurePolicy := NewFailurePolicy() + failurePolicy.BackOffPolicy.BackOffType = FixedBackOffType + failurePolicy.BackOffPolicy.CfgItems = map[BackOffCfgKey]float64{ + FixMSBackOffCfgKey: 100.0, + } + failurePolicy.StopPolicy.MaxDurationMS = 100 + p := Policy{ + Enable: true, + Type: 0, + FailurePolicy: failurePolicy, + } + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + _, ok, err := NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + return ri, nil, kerrors.ErrRPCTimeout + }, ri, nil) + test.Assert(t, err != nil, err) + test.Assert(t, !ok) + + // call no MaxDurationMS limit, the retry will success + failurePolicy.StopPolicy.MaxDurationMS = 0 + var callTimes int32 + ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), genRPCInfo()) + _, ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + if atomic.LoadInt32(&callTimes) == 0 { + atomic.AddInt32(&callTimes, 1) + return ri, nil, kerrors.ErrRPCTimeout + } + return ri, nil, nil + }, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + + // call backup request + p = BuildBackupRequest(NewBackupPolicy(10)) + ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), genRPCInfo()) + callTimes = 0 + _, ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + if atomic.LoadInt32(&callTimes) == 0 || atomic.LoadInt32(&callTimes) == 1 { + atomic.AddInt32(&callTimes, 1) + time.Sleep(time.Millisecond * 100) + } + return ri, nil, nil + }, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + test.Assert(t, atomic.LoadInt32(&callTimes) == 2) +} + +// test specified error to retry +func TestSpecifiedErrorRetry(t *testing.T) { + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + var transErrCode int32 = 1000 + + // case1: specified method retry with error + t.Run("case1", func(t *testing.T) { + rc := NewRetryContainer() + checkResultRetry := false + shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == transErrCode { + checkResultRetry = true + return true + } + } + return false + }} + err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, checkResultRetry) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + }) + + // case2: specified method retry with error, but use backup request config cannot be effective + t.Run("case2", func(t *testing.T) { + shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == transErrCode { + return true + } + } + return false + }} + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(10))}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err != nil, err) + test.Assert(t, !ok) + }) + + // case3: specified method retry with error, but method not match + t.Run("case3", func(t *testing.T) { + shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() != method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == transErrCode { + return true + } + } + return false + }} + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err != nil) + test.Assert(t, !ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) + }) + + // case4: all error retry + t.Run("case4", func(t *testing.T) { + rc := NewRetryContainer() + p := BuildFailurePolicy(NewFailurePolicyWithResultRetry(AllErrorRetry())) + ri = genRPCInfo() + ri, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + }) +} + +// test specified resp to retry +func TestSpecifiedRespRetry(t *testing.T) { + retryResult := &mockResult{} + retryResp := mockResp{ + code: 500, + msg: "retry", + } + noRetryResp := mockResp{ + code: 0, + msg: "noretry", + } + var callTimes int32 + retryWithResp := func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + newVal := atomic.AddInt32(&callTimes, 1) + if newVal == 1 { + retryResult.setResult(retryResp) + return genRPCInfo(), retryResult, nil + } else { + retryResult.setResult(noRetryResp) + return genRPCInfoWithRemoteTag(remoteTags), retryResult, nil + } + } + ctx := context.Background() + ri := genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + rc := NewRetryContainer() + // case1: specified method retry with resp + shouldResultRetry := &ShouldResultRetry{RespRetry: func(resp interface{}, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if r, ok := resp.(*mockResult); ok && r.getResult() == retryResp { + return true + } + } + return false + }} + err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, retryResult.getResult() == noRetryResp, retryResult) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + + // case2 specified method retry with resp, but use backup request config cannot be effective + atomic.StoreInt32(&callTimes, 0) + rc = NewRetryContainer() + err = rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(100))}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, retryResult.getResult() == retryResp, retryResp) + test.Assert(t, !ok) + + // case3: specified method retry with resp, but method not match + atomic.StoreInt32(&callTimes, 0) + shouldResultRetry = &ShouldResultRetry{RespRetry: func(resp interface{}, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() != method { + if r, ok := resp.(*mockResult); ok && r.getResult() == retryResp { + return true + } + } + return false + }} + rc = NewRetryContainer() + err = rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, retryResult.getResult() == retryResp, retryResult) + test.Assert(t, ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) +} + +// test specified error to retry with ErrorRetryWithCtx +func TestSpecifiedErrorRetryWithCtx(t *testing.T) { + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + var transErrCode int32 = 1000 + + // case1: specified method retry with error + t.Run("case1", func(t *testing.T) { + rc := NewRetryContainer() + shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == transErrCode { + return true + } + } + return false + }} + err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + }) + + // case2: specified method retry with error, but use backup request config cannot be effective + t.Run("case2", func(t *testing.T) { + shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == transErrCode { + return true + } + } + return false + }} + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(10))}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err != nil, err) + test.Assert(t, !ok) + }) + + // case3: specified method retry with error, but method not match + t.Run("case3", func(t *testing.T) { + shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + return ri.To().Method() != method + }} + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err != nil) + test.Assert(t, !ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) + }) + + // case4: all error retry + t.Run("case4", func(t *testing.T) { + rc := NewRetryContainer() + p := BuildFailurePolicy(NewFailurePolicyWithResultRetry(AllErrorRetry())) + ri = genRPCInfo() + ri, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + }) + + // case5: specified method retry with error, only ctx has some info then retry + ctxKeyVal := "ctxKeyVal" + t.Run("case5", func(t *testing.T) { + shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method && ctx.Value(ctxKeyVal) == ctxKeyVal { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { + return true + } + } + return false + }} + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + ctx = context.WithValue(ctx, ctxKeyVal, ctxKeyVal) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + }) +} + +// test specified error to retry, but has both old and new policy, the new one will be effective +func TestSpecifiedErrorRetryHasOldAndNew(t *testing.T) { + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + // case1: ErrorRetryWithCtx will retry, but ErrorRetry not retry, the expect result is do retry + t.Run("case1", func(t *testing.T) { + rc := NewRetryContainer() + shouldResultRetry := &ShouldResultRetry{ + ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + return true + }, + ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { + return false + }, + } + err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, 1000), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + }) + + // case2: ErrorRetryWithCtx not retry, but ErrorRetry retry, the expect result is that not do retry + t.Run("case2", func(t *testing.T) { + shouldResultRetry := &ShouldResultRetry{ + ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + return false + }, + ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { + return true + }, + } + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, 1000), ri, nil) + test.Assert(t, err != nil) + test.Assert(t, !ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) + }) +} + +// test specified resp to retry with ErrorRetryWithCtx +func TestSpecifiedRespRetryWithCtx(t *testing.T) { + retryResult := &mockResult{} + retryResp := mockResp{ + code: 500, + msg: "retry", + } + noRetryResp := mockResp{ + code: 0, + msg: "noretry", + } + var callTimes int32 + retryWithResp := func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + newVal := atomic.AddInt32(&callTimes, 1) + if newVal == 1 { + retryResult.setResult(retryResp) + return genRPCInfo(), retryResult, nil + } else { + retryResult.setResult(noRetryResp) + return genRPCInfoWithRemoteTag(remoteTags), retryResult, nil + } + } + ctx := context.Background() + ri := genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + rc := NewRetryContainer() + // case1: specified method retry with resp + shouldResultRetry := &ShouldResultRetry{RespRetryWithCtx: func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if r, ok := resp.(*mockResult); ok && r.getResult() == retryResp { + return true + } + } + return false + }} + err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, retryResult.getResult() == noRetryResp, retryResult) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + + // case2 specified method retry with resp, but use backup request config cannot be effective + atomic.StoreInt32(&callTimes, 0) + rc = NewRetryContainer() + err = rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(100))}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, retryResult.getResult() == retryResp, retryResp) + test.Assert(t, !ok) + + // case3: specified method retry with resp, but method not match + atomic.StoreInt32(&callTimes, 0) + shouldResultRetry = &ShouldResultRetry{RespRetryWithCtx: func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() != method { + if r, ok := resp.(*mockResult); ok && r.getResult() == retryResp { + return true + } + } + return false + }} + rc = NewRetryContainer() + err = rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, retryResult.getResult() == retryResp, retryResult) + test.Assert(t, ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) + + // case4: specified method retry with resp, only ctx has some info then retry + ctxKeyVal := "ctxKeyVal" + atomic.StoreInt32(&callTimes, 0) + shouldResultRetry2 := &ShouldResultRetry{RespRetryWithCtx: func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method && ctx.Value(ctxKeyVal) == ctxKeyVal { + if r, ok := resp.(*mockResult); ok && r.getResult() == retryResp { + return true + } + } + return false + }} + err = rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry2) + test.Assert(t, err == nil, err) + ctx = context.WithValue(ctx, ctxKeyVal, ctxKeyVal) + ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, retryResult.getResult() == noRetryResp, retryResult) + test.Assert(t, !ok) + v, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) +} + +func TestResultRetryWithPolicyChange(t *testing.T) { + rc := NewRetryContainer() + shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { + return true + } + } + return false + }} + err := rc.Init(nil, shouldResultRetry) + test.Assert(t, err == nil, err) + + // case 1: first time trigger NotifyPolicyChange, the `initRetryer` will be executed, check if the ShouldResultRetry is not nil + rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy())) + r := rc.getRetryer(context.Background(), genRPCInfo()) + fr, ok := r.(*failureRetryer) + test.Assert(t, ok) + test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry) + + // case 2: second time trigger NotifyPolicyChange, the `UpdatePolicy` will be executed, check if the ShouldResultRetry is not nil + rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy())) + r = rc.getRetryer(context.Background(), genRPCInfo()) + fr, ok = r.(*failureRetryer) + test.Assert(t, ok) + test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry) +} + +func TestResultRetryWithCtxWhenPolicyChange(t *testing.T) { + rc := NewRetryContainer() + shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { + return true + } + } + return false + }} + err := rc.Init(nil, shouldResultRetry) + test.Assert(t, err == nil, err) + + // case 1: first time trigger NotifyPolicyChange, the `initRetryer` will be executed, check if the ShouldResultRetry is not nil + rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy())) + r := rc.getRetryer(context.Background(), genRPCInfo()) + fr, ok := r.(*failureRetryer) + test.Assert(t, ok) + test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry) + + // case 2: second time trigger NotifyPolicyChange, the `UpdatePolicy` will be executed, check if the ShouldResultRetry is not nil + rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy())) + r = rc.getRetryer(context.Background(), genRPCInfo()) + fr, ok = r.(*failureRetryer) + test.Assert(t, ok) + test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry) +} + +func TestFailureRetryWithRPCInfo(t *testing.T) { + // failure retry + ctx := context.Background() + rc := NewRetryContainer() + + ri := genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + rpcinfo.Record(ctx, ri, stats.RPCStart, nil) + + // call with retry policy + var callTimes int32 + policy := BuildFailurePolicy(NewFailurePolicy()) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &policy, retryCall(&callTimes, ri, false), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + test.Assert(t, ri.Stats().GetEvent(stats.RPCStart).Status() == stats.StatusInfo) + test.Assert(t, ri.Stats().GetEvent(stats.RPCFinish).Status() == stats.StatusInfo) + test.Assert(t, ri.To().Address().String() == "10.20.30.40:8888") + test.Assert(t, atomic.LoadInt32(&callTimes) == 2) +} + +var retryWithTransError = func(callTimes, transErrCode int32) RPCCallFunc { + // fails for the first call if callTimes is initialized to 0 + return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + if atomic.AddInt32(&callTimes, 1) == 1 { + // first call retry TransErr with specified errCode + return genRPCInfo(), nil, remote.NewTransErrorWithMsg(transErrCode, "mock") + } else { + return genRPCInfoWithRemoteTag(remoteTags), nil, nil + } + } +} diff --git a/pkg/retry/mixed.go b/pkg/retry/mixed.go new file mode 100644 index 0000000000..103c7df0b4 --- /dev/null +++ b/pkg/retry/mixed.go @@ -0,0 +1,82 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retry + +import "fmt" + +const maxMixRetryTimes = 3 + +// NewMixedPolicy init default mixed retry policy +func NewMixedPolicy(delayMS uint32) *MixedPolicy { + if delayMS == 0 { + panic("invalid backup request delay duration in MixedPolicy") + } + p := &MixedPolicy{ + RetryDelayMS: delayMS, + FailurePolicy: FailurePolicy{ + StopPolicy: StopPolicy{ + MaxRetryTimes: defaultBackupRetryTimes, + DisableChainStop: false, + CBPolicy: CBPolicy{ + ErrorRate: defaultCBErrRate, + }, + }, + BackOffPolicy: &BackOffPolicy{BackOffType: NoneBackOffType}, + }, + } + return p +} + +// NewMixedPolicyWithResultRetry init failure retry policy with ShouldResultRetry +func NewMixedPolicyWithResultRetry(delayMS uint32, rr *ShouldResultRetry) *MixedPolicy { + fp := NewMixedPolicy(delayMS) + fp.ShouldResultRetry = rr + return fp +} + +// String is used to print human readable debug info. +func (p *MixedPolicy) String() string { + return fmt.Sprintf("{RetryDelayMS:%+v StopPolicy:%+v BackOffPolicy:%+v RetrySameNode:%+v "+ + "ShouldResultRetry:{ErrorRetry:%t, RespRetry:%t}}", p.RetryDelayMS, p.StopPolicy, p.BackOffPolicy, p.RetrySameNode, p.isErrorRetryWithCtxNonNil(), p.isRespRetryWithCtxNonNil()) +} + +// Equals to check if MixedPolicy is equal +func (p *MixedPolicy) Equals(np *MixedPolicy) bool { + if p == nil { + return np == nil + } + if np == nil { + return false + } + if p.RetryDelayMS != np.RetryDelayMS { + return false + } + if !p.FailurePolicy.Equals(&np.FailurePolicy) { + return false + } + return true +} + +func (p *MixedPolicy) DeepCopy() *MixedPolicy { + if p == nil { + return nil + } + return &MixedPolicy{ + RetryDelayMS: p.RetryDelayMS, + FailurePolicy: *p.FailurePolicy.DeepCopy(), + } +} diff --git a/pkg/retry/mixed_retryer.go b/pkg/retry/mixed_retryer.go new file mode 100644 index 0000000000..496f7061cd --- /dev/null +++ b/pkg/retry/mixed_retryer.go @@ -0,0 +1,270 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retry + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/kitex/pkg/circuitbreak" + "github.com/cloudwego/kitex/pkg/gofunc" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/utils" +) + +func newMixedRetryer(policy Policy, r *ShouldResultRetry, cbC *cbContainer) (Retryer, error) { + mr := &mixedRetryer{failureCommon: &failureCommon{specifiedResultRetry: r, cbContainer: cbC}} + if err := mr.UpdatePolicy(policy); err != nil { + return nil, fmt.Errorf("newMixedRetryer failed, err=%w", err) + } + return mr, nil +} + +type mixedRetryer struct { + enable bool + *failureCommon + policy *MixedPolicy + retryDelay time.Duration + sync.RWMutex + errMsg string +} + +// ShouldRetry to check if retry request can be called, it is checked in retryer.Do. +// If not satisfy will return the reason message +// Actually, the ShouldRetry logic is same with failureRetryer, because +func (r *mixedRetryer) ShouldRetry(ctx context.Context, err error, callTimes int, req interface{}, cbKey string) (string, bool) { + r.RLock() + defer r.RUnlock() + if !r.enable { + return "", false + } + return r.shouldRetry(ctx, callTimes, req, cbKey, &r.policy.FailurePolicy) +} + +// AllowRetry implements the Retryer interface. +func (r *mixedRetryer) AllowRetry(ctx context.Context) (string, bool) { + r.RLock() + defer r.RUnlock() + if !r.enable || r.policy.StopPolicy.MaxRetryTimes == 0 { + return "", false + } + if stop, msg := chainStop(ctx, r.policy.StopPolicy); stop { + return msg, false + } + if stop, msg := ddlStop(ctx, r.policy.StopPolicy); stop { + return msg, false + } + return "", true +} + +// Do implement the Retryer interface. +func (r *mixedRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpcinfo.RPCInfo, req interface{}) (lastRI rpcinfo.RPCInfo, recycleRI bool, err error) { + r.RLock() + var maxDuration time.Duration + if r.policy.StopPolicy.MaxDurationMS > 0 { + maxDuration = time.Duration(r.policy.StopPolicy.MaxDurationMS) * time.Millisecond + } + retryTimes := r.policy.StopPolicy.MaxRetryTimes + retryDelay := r.retryDelay + r.RUnlock() + + var callTimes int32 + var callCosts utils.StringBuilder + callCosts.RawStringBuilder().Grow(32) + + var recordCostDoing int32 = 0 + var abort int32 = 0 + doneCount := 0 + finishedErrCount := 0 + + // notice: buff num of chan is very important here, it cannot less than call times, or the below chan receive will block + callDone := make(chan *resultWrapper, retryTimes+1) + var nonFinishedErrRes *resultWrapper + timer := time.NewTimer(retryDelay) + cbKey, _ := r.cbContainer.cbCtl.GetKey(ctx, req) + defer func() { + if panicInfo := recover(); panicInfo != nil { + err = panicToErr(ctx, panicInfo, firstRI) + } + timer.Stop() + }() + startTime := time.Now() + // include first call, max loop is retryTimes + 1 + doCall := true + for callCount := 0; ; { + if doCall { + doCall = false + if callCount > 0 { + if ret, e := isExceedMaxDuration(ctx, startTime, maxDuration, atomic.LoadInt32(&callTimes)); ret { + return firstRI, false, e + } + } + callCount++ + gofunc.GoFunc(ctx, func() { + if atomic.LoadInt32(&abort) == 1 { + return + } + var ( + e error + cRI rpcinfo.RPCInfo + resp interface{} + ) + defer func() { + if panicInfo := recover(); panicInfo != nil { + e = panicToErr(ctx, panicInfo, firstRI) + } + callDone <- &resultWrapper{cRI, resp, e} + }() + ct := atomic.AddInt32(&callTimes, 1) + callStart := time.Now() + if r.cbContainer.enablePercentageLimit { + // record stat before call since requests may be slow, making the limiter more accurate + recordRetryStat(cbKey, r.cbContainer.cbPanel, ct) + } + cRI, resp, e = rpcCall(ctx, r) + recordCost(ct, callStart, &recordCostDoing, &callCosts, &abort, e) + if !r.cbContainer.enablePercentageLimit && r.cbContainer.cbStat { + circuitbreak.RecordStat(ctx, req, nil, e, cbKey, r.cbContainer.cbCtl, r.cbContainer.cbPanel) + } + }) + } + select { + case <-timer.C: + // backup retry + if _, ok := r.ShouldRetry(ctx, nil, callCount, req, cbKey); ok && callCount < retryTimes+1 { + doCall = true + timer.Reset(retryDelay) + } + case res := <-callDone: + // result retry + if respOp, ok := ctx.Value(CtxRespOp).(*int32); ok { + // must set as OpNo, or the new resp cannot be decoded + atomic.StoreInt32(respOp, OpNo) + } + doneCount++ + isFinishErr := res.err != nil && errors.Is(res.err, kerrors.ErrRPCFinish) + if nonFinishedErrRes == nil || !isFinishErr { + nonFinishedErrRes = res + } + if doneCount < retryTimes+1 { + if isFinishErr { + // There will be only one request (goroutine) pass the `checkRPCState`, others will skip decoding + // and return `ErrRPCFinish`, to avoid concurrent write to response and save the cost of decoding. + // We can safely ignore this error and wait for the response of the passed goroutine. + if finishedErrCount++; finishedErrCount >= retryTimes+1 { + // But if all requests return this error, it must be a bug, preventive panic to avoid dead loop + panic(errUnexpectedFinish) + } + continue + } + if callCount < retryTimes+1 { + if msg, ok := r.ShouldRetry(ctx, nil, callCount, req, cbKey); ok { + if r.isRetryResult(ctx, res.ri, res.resp, res.err, &r.policy.FailurePolicy) { + doCall = true + timer.Reset(retryDelay) + continue + } + } else if msg != "" { + appendMsg := fmt.Sprintf("retried %d, %s", callCount-1, msg) + appendErrMsg(res.err, appendMsg) + } + } else if r.isRetryResult(ctx, res.ri, res.resp, res.err, &r.policy.FailurePolicy) { + continue + } + } + atomic.StoreInt32(&abort, 1) + recordRetryInfo(nonFinishedErrRes.ri, atomic.LoadInt32(&callTimes), callCosts.String()) + return nonFinishedErrRes.ri, false, nonFinishedErrRes.err + } + } +} + +// UpdatePolicy implements the Retryer interface. +func (r *mixedRetryer) UpdatePolicy(rp Policy) (err error) { + if !rp.Enable { + r.Lock() + r.enable = rp.Enable + r.Unlock() + return nil + } + if rp.MixedPolicy == nil || rp.Type != MixedType { + err = errors.New("MixedPolicy is nil or retry type not match, cannot do update in mixedRetryer") + } + if err == nil && rp.MixedPolicy.RetryDelayMS == 0 { + err = errors.New("invalid retry delay duration in mixedRetryer") + } + if err == nil { + err = checkStopPolicy(&rp.MixedPolicy.StopPolicy, maxMixRetryTimes, r) + } + r.Lock() + defer r.Unlock() + r.enable = rp.Enable + if err != nil { + r.errMsg = err.Error() + return err + } + r.policy = rp.MixedPolicy + r.retryDelay = time.Duration(rp.MixedPolicy.RetryDelayMS) * time.Millisecond + r.setSpecifiedResultRetryIfNeeded(r.specifiedResultRetry, &r.policy.FailurePolicy) + if bo, e := initBackOff(rp.MixedPolicy.BackOffPolicy); e != nil { + r.errMsg = fmt.Sprintf("mixedRetryer update BackOffPolicy failed, err=%s", e.Error()) + klog.Warnf("KITEX: %s", r.errMsg) + } else { + r.backOff = bo + } + return nil +} + +// AppendErrMsgIfNeeded implements the Retryer interface. +func (r *mixedRetryer) AppendErrMsgIfNeeded(ctx context.Context, err error, ri rpcinfo.RPCInfo, msg string) { + if r.isRetryErr(ctx, err, ri, &r.policy.FailurePolicy) { + // Add additional reason when retry is not applied. + appendErrMsg(err, msg) + } +} + +// Prepare implements the Retryer interface. +func (r *mixedRetryer) Prepare(ctx context.Context, prevRI, retryRI rpcinfo.RPCInfo) { + handleRetryInstance(r.policy.RetrySameNode, prevRI, retryRI) +} + +// Type implements the Retryer interface. +func (r *mixedRetryer) Type() Type { + return MixedType +} + +// Dump implements the Retryer interface. +func (r *mixedRetryer) Dump() map[string]interface{} { + r.RLock() + defer r.RUnlock() + dm := make(map[string]interface{}) + dm["enable"] = r.enable + dm["mixed_retry"] = r.policy + if r.policy != nil { + dm["specified_result_retry"] = r.dumpSpecifiedResultRetry(r.policy.FailurePolicy) + } + if r.errMsg != "" { + dm["err_msg"] = r.errMsg + } + return dm +} diff --git a/pkg/retry/mixed_test.go b/pkg/retry/mixed_test.go new file mode 100644 index 0000000000..c8312e107f --- /dev/null +++ b/pkg/retry/mixed_test.go @@ -0,0 +1,625 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retry + +import ( + "context" + "errors" + "math" + "sync/atomic" + "testing" + "time" + + "github.com/bytedance/sonic" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +// test new MixedPolicy +func TestMixedRetryPolicy(t *testing.T) { + mp := NewMixedPolicy(100) + + // case 1 + mp.WithMaxRetryTimes(3) + jsonRet, err := sonic.MarshalString(mp) + test.Assert(t, err == nil, err) + var mp2 MixedPolicy + err = sonic.UnmarshalString(jsonRet, &mp2) + test.Assert(t, err == nil, err) + test.Assert(t, mp.Equals(&mp2)) + test.Assert(t, mp2.FailurePolicy.StopPolicy.MaxRetryTimes == 3) + + // case 2 + mp.WithMaxRetryTimes(2) + mp.WithRetrySameNode() + mp.WithFixedBackOff(10) + jsonRet, err = sonic.MarshalString(mp) + test.Assert(t, err == nil, err) + var mp3 MixedPolicy + err = sonic.UnmarshalString(jsonRet, &mp3) + test.Assert(t, err == nil, err) + test.Assert(t, mp.Equals(&mp3), mp3) + test.Assert(t, mp3.FailurePolicy.StopPolicy.MaxRetryTimes == 2) + test.Assert(t, mp3.FailurePolicy.BackOffPolicy.BackOffType == FixedBackOffType) + + // case 3 + mp.WithRandomBackOff(10, 20) + jsonRet, err = sonic.MarshalString(mp) + test.Assert(t, err == nil, err) + var mp4 MixedPolicy + err = sonic.UnmarshalString(jsonRet, &mp4) + test.Assert(t, err == nil, err) + test.Assert(t, mp.Equals(&mp4), mp4) + test.Assert(t, mp4.FailurePolicy.StopPolicy.MaxRetryTimes == 2) + test.Assert(t, mp4.FailurePolicy.BackOffPolicy.BackOffType == RandomBackOffType) + + // case 4 + mp.WithRetryBreaker(0.2) + mp.WithDDLStop() + mp.WithMaxDurationMS(100) + jsonRet, err = sonic.MarshalString(mp) + test.Assert(t, err == nil, err) + var mp5 MixedPolicy + err = sonic.UnmarshalString(jsonRet, &mp5) + test.Assert(t, err == nil, err) + test.Assert(t, mp.Equals(&mp5), mp5) + test.Assert(t, mp5.FailurePolicy.StopPolicy.DDLStop) + test.Assert(t, mp5.FailurePolicy.StopPolicy.MaxDurationMS == 100) + test.Assert(t, mp5.FailurePolicy.StopPolicy.CBPolicy.ErrorRate == 0.2) + + // case 5 + mp = &MixedPolicy{ + RetryDelayMS: 20, + FailurePolicy: FailurePolicy{ + StopPolicy: StopPolicy{ + MaxRetryTimes: 2, + DisableChainStop: false, + CBPolicy: CBPolicy{ + ErrorRate: defaultCBErrRate, + }, + }, + Extra: "{}", + }, + } + jsonRet, err = sonic.MarshalString(mp) + test.Assert(t, err == nil, err) + var mp6 MixedPolicy + err = sonic.UnmarshalString(jsonRet, &mp6) + test.Assert(t, err == nil, err) + test.Assert(t, mp6.BackOffPolicy == nil) + test.Assert(t, mp.Equals(&mp6), mp6) + test.Assert(t, mp6.FailurePolicy.StopPolicy.MaxRetryTimes == 2) + test.Assert(t, !mp6.FailurePolicy.StopPolicy.DisableChainStop) + test.Assert(t, mp6.FailurePolicy.StopPolicy.CBPolicy.ErrorRate == defaultCBErrRate) + + // case 6 + mp.DisableChainRetryStop() + jsonRet, err = sonic.MarshalString(mp) + test.Assert(t, err == nil, err) + var mp7 MixedPolicy + err = sonic.UnmarshalString(jsonRet, &mp7) + test.Assert(t, err == nil, err) + test.Assert(t, mp7.BackOffPolicy == nil) + test.Assert(t, mp.Equals(&mp7), mp7) + test.Assert(t, mp7.FailurePolicy.StopPolicy.DisableChainStop) + test.Assert(t, mp.String() == "{RetryDelayMS:20 StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true "+ + "DDLStop:false CBPolicy:{ErrorRate:0.1}} BackOffPolicy: RetrySameNode:false ShouldResultRetry:{ErrorRetry:false, RespRetry:false}}", mp) + + // case 7 + mp.WithSpecifiedResultRetry(&ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + return false + }}) + test.Assert(t, mp.String() == "{RetryDelayMS:20 StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true "+ + "DDLStop:false CBPolicy:{ErrorRate:0.1}} BackOffPolicy: RetrySameNode:false ShouldResultRetry:{ErrorRetry:true, RespRetry:false}}", mp) + jsonRet, err = sonic.MarshalString(mp) + test.Assert(t, err == nil, err) + var fp9 MixedPolicy + err = sonic.UnmarshalString(jsonRet, &fp9) + test.Assert(t, err == nil, err) + test.Assert(t, mp.Equals(&fp9), fp9) + test.Assert(t, fp9.ShouldResultRetry == nil) +} + +func TestNewMixedPolicy(t *testing.T) { + mp0 := NewMixedPolicy(100) + mp1 := NewMixedPolicy(100) + test.Assert(t, mp0.Equals(mp1)) + + mp1 = NewMixedPolicy(20) + test.Assert(t, !mp0.Equals(mp1)) + + mp1 = mp0.DeepCopy() + test.Assert(t, mp0.Equals(mp1)) + + mp1 = mp0.DeepCopy() + mp1.WithMaxRetryTimes(3) + test.Assert(t, !mp0.Equals(mp1)) + + mp1 = mp0.DeepCopy() + mp1.WithFixedBackOff(10) + test.Assert(t, !mp0.Equals(mp1)) + + mp1 = mp0.DeepCopy() + mp1.WithRetryBreaker(0.2) + test.Assert(t, !mp0.Equals(mp1)) + + mp1 = nil + test.Assert(t, !mp0.Equals(mp1)) + + mp0 = nil + test.Assert(t, mp0.Equals(mp1)) + + test.Panic(t, func() { NewMixedPolicy(0) }) +} + +// test MixedRetry call +func TestMixedRetry(t *testing.T) { + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + var transErrCode int32 = 1001 + shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == transErrCode { + return true + } + } + return false + }} + + // case1: specified method retry with error + t.Run("specified method retry with error", func(t *testing.T) { + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{Wildcard: BuildMixedPolicy(NewMixedPolicy(100))}, shouldResultRetry) + test.Assert(t, err == nil, err) + ri, ok, err := rc.WithRetryIfNeeded(ctx, nil, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + + ri, ok, err = rc.WithRetryIfNeeded(ctx, nil, retryWithTransError(0, 1002), ri, nil) + test.Assert(t, err != nil) + test.Assert(t, !ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) + }) + + // case2: specified method retry with error, but method not match + t.Run("specified method retry with error, but method not match", func(t *testing.T) { + rr := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() != method { + if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { + return true + } + } + return false + }} + rc := NewRetryContainer() + err := rc.Init(map[string]Policy{method: BuildMixedPolicy(NewMixedPolicy(100))}, rr) + test.Assert(t, err == nil, err) + ri = genRPCInfo() + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err != nil) + test.Assert(t, !ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) + }) + + // case3: all error retry + t.Run("all error retry", func(t *testing.T) { + rc := NewRetryContainer() + p := BuildMixedPolicy(NewMixedPolicyWithResultRetry(100, AllErrorRetry())) + ri = genRPCInfo() + ri, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTransError(0, transErrCode), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) + }) + + // case4: RPCFinishErr + t.Run("RPCFinishErr", func(t *testing.T) { + mockErr := errors.New("mock") + retryWithRPCFinishErr := func(callCount int32) RPCCallFunc { + // fails for the first call if callTimes is initialized to 0 + return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + time.Sleep(50 * time.Millisecond) + ct := atomic.AddInt32(&callCount, 1) + if ct == 1 || ct == 2 { + // first call retry TransErr with specified errCode + return genRPCInfo(), nil, mockErr + } else { + return genRPCInfo(), nil, kerrors.ErrRPCFinish + } + } + } + + rc := NewRetryContainer() + mp := NewMixedPolicyWithResultRetry(10, AllErrorRetry()) + mp.WithMaxRetryTimes(3) + p := BuildMixedPolicy(mp) + ri = genRPCInfo() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithRPCFinishErr(0), ri, nil) + test.Assert(t, err != nil, err) + test.Assert(t, err == mockErr, err) + test.Assert(t, !ok) + }) +} + +// Assuming the first request returns at 300ms, the second request costs 150ms +// Configuration: Timeout=200ms、MaxRetryTimes=2 BackupDelay=100ms +// - Mixed Retry: Success, cost 250ms +// - Failure Retry: Success, cost 350ms +// - Backup Retry: Failure, cost 200ms +func TestMockCase1WithDiffRetry(t *testing.T) { + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + retryWithTimeout := func(ri rpcinfo.RPCInfo, callTimes int32, resp *mockResult) RPCCallFunc { + return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + ct := atomic.AddInt32(&callTimes, 1) + resp.setCallTimes(ct) + if ct == 1 { + // first call retry timeout + time.Sleep(200 * time.Millisecond) + return ri, nil, kerrors.ErrRPCTimeout.WithCause(errors.New("mock")) + } else { + time.Sleep(150 * time.Millisecond) + return ri, resp, nil + } + } + } + // mixed retry will success, latency is lowest + t.Run("mixed retry", func(t *testing.T) { + rc := NewRetryContainer() + mp := NewMixedPolicy(100) + mp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildMixedPolicy(mp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) // 100+150 = 250 + test.Assert(t, err == nil, err) + test.Assert(t, ret.getCallTimes() == 3, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-250.0) < 50.0, cost.Milliseconds()) + }) + + // failure retry will success, but latency is more than mixed retry + t.Run("failure retry", func(t *testing.T) { + rc := NewRetryContainer() + fp := NewFailurePolicy() + fp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildFailurePolicy(fp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err == nil, err) + test.Assert(t, ret.callTimes == 2, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-350.0) < 50.0, cost.Milliseconds()) + }) + + // backup request will failure + t.Run("backup request", func(t *testing.T) { + rc := NewRetryContainer() + bp := NewBackupPolicy(100) + bp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildBackupRequest(bp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err != nil, err) + test.Assert(t, errors.Is(err, kerrors.ErrRPCTimeout)) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-200.0) < 50.0, cost.Milliseconds()) + }) +} + +// Assuming the first request returns at 300ms, the second request cost 150ms +// Configuration: Timeout=300ms、MaxRetryTimes=2 BackupDelay=100ms +// - Mixed Retry: Success, cost 250ms (>timeout, same with Backup Retry) +// - Failure Retry: Success, cost 350ms +// - Backup Retry: Failure, cost 200ms (same with Mixed Retry) +func TestMockCase2WithDiffRetry(t *testing.T) { + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + retryWithTimeout := func(ri rpcinfo.RPCInfo, callTimes int32, resp *mockResult) RPCCallFunc { + return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + ct := atomic.AddInt32(&callTimes, 1) + resp.setCallTimes(ct) + if ct == 1 { + // first call retry timeout + time.Sleep(300 * time.Millisecond) + return ri, nil, kerrors.ErrRPCTimeout.WithCause(errors.New("mock")) + } else { + time.Sleep(150 * time.Millisecond) + return ri, resp, nil + } + } + } + // mixed retry will success, latency is lowest + t.Run("mixed retry", func(t *testing.T) { + rc := NewRetryContainer() + mp := NewMixedPolicy(100) + mp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildMixedPolicy(mp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) // 100+150 = 250 + test.Assert(t, err == nil, err) + test.Assert(t, ret.getCallTimes() == 3, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-250.0) < 50.0, cost.Milliseconds()) + }) + + // failure retry will success, but latency is more than mixed retry + t.Run("failure retry", func(t *testing.T) { + rc := NewRetryContainer() + fp := NewFailurePolicy() + fp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildFailurePolicy(fp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err == nil, err) + test.Assert(t, ret.getCallTimes() == 2, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-450.0) < 50.0, cost.Milliseconds()) + }) + + // backup request will failure + t.Run("backup request", func(t *testing.T) { + rc := NewRetryContainer() + bp := NewBackupPolicy(100) + bp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildBackupRequest(bp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err == nil, err) + test.Assert(t, ret.getCallTimes() == 3, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-250.0) < 50.0, cost.Milliseconds()) + }) +} + +// Assuming all request timeout +// Configuration: Timeout=100ms、MaxRetryTimes=2 BackupDelay=100ms +// - Mixed Retry: Failure, cost 200ms +// - Failure Retry: Failure, cost 300ms +// - Backup Retry: Failure, cost 100ms (max cost is timeout) +func TestMockCase3WithDiffRetry(t *testing.T) { + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + retryWithTimeout := func(ri rpcinfo.RPCInfo, callTimes int32, resp *mockResult) RPCCallFunc { + return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + ct := atomic.AddInt32(&callTimes, 1) + resp.setCallTimes(ct) + time.Sleep(100 * time.Millisecond) + return ri, nil, kerrors.ErrRPCTimeout.WithCause(errors.New("mock")) + } + } + // mixed retry will success, cost is least + t.Run("mixed retry", func(t *testing.T) { + rc := NewRetryContainer() + mp := NewMixedPolicy(100) + mp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildMixedPolicy(mp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) // 100+(100,100) = 200 + test.Assert(t, err != nil, err) + test.Assert(t, errors.Is(err, kerrors.ErrRPCTimeout)) + test.Assert(t, ret.getCallTimes() == 3, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-200.0) < 50.0, cost.Milliseconds()) + }) + + // failure retry will success, but cost is more than mixed retry + t.Run("failure retry", func(t *testing.T) { + rc := NewRetryContainer() + fp := NewFailurePolicy() + fp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildFailurePolicy(fp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err != nil, err) + test.Assert(t, errors.Is(err, kerrors.ErrRPCTimeout)) + test.Assert(t, ret.getCallTimes() == 3, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-300.0) < 50.0, cost.Milliseconds()) + }) + + // backup request will failure + t.Run("backup request", func(t *testing.T) { + rc := NewRetryContainer() + bp := NewBackupPolicy(100) + bp.WithMaxRetryTimes(2) // max call times is 3 + p := BuildBackupRequest(bp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTimeout(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err != nil, err) + test.Assert(t, errors.Is(err, kerrors.ErrRPCTimeout)) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-100.0) < 50.0, cost.Milliseconds()) + }) +} + +// Assuming BizStatus=11111/11112 needs to be retried, +// +// the first reply is BizStatus=11111, it costs 250ms, +// the second reply is BizStatus=11112, it costs 250ms, +// the third reply is BizStatus=0, it costs 250ms, +// +// Configuration: MaxRetryTimes=3 BackupDelay=100ms +// - Mixed Retry: Success, cost 450ms, two backup retry, and one failure retry +// - Failure Retry: Success, cost 750ms +// - Backup Retry: Biz Error, cost 250ms +func TestMockCase4WithDiffRetry(t *testing.T) { + bizStatusCode0, bizStatusCode1, bizStatusCode2 := 0, 11111, 11112 + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + retryWithResp := func(ri rpcinfo.RPCInfo, callTimes int32, resp *mockResult) RPCCallFunc { + return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + ct := atomic.AddInt32(&callTimes, 1) + resp.setCallTimes(ct) + time.Sleep(250 * time.Millisecond) + switch ct { + case 1: + resp.setResult(mockResp{code: bizStatusCode1}) + return ri, resp, nil + case 2: + resp.setResult(mockResp{code: bizStatusCode2}) + return ri, resp, nil + case 3: + resp.setResult(mockResp{code: bizStatusCode0}) + return ri, resp, nil + } + return ri, nil, errors.New("mock error") + } + } + resultRetry := &ShouldResultRetry{RespRetryWithCtx: func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { + bizCode := resp.(*mockResult).getResult().(mockResp).code + if bizCode == bizStatusCode1 || bizCode == bizStatusCode2 { + return true + } + return false + }} + // mixed retry will success, cost is least + t.Run("mixed retry", func(t *testing.T) { + rc := NewRetryContainer() + mp := NewMixedPolicy(100) + mp.WithMaxRetryTimes(3) // max call times is 4 + mp.WithSpecifiedResultRetry(resultRetry) + p := BuildMixedPolicy(mp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithResp(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err == nil, err) + test.Assert(t, ret.getCallTimes() == 4, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-450.0) < 50.0, cost.Milliseconds()) + }) + + // failure retry will success, but cost is more than mixed retry + t.Run("failure retry", func(t *testing.T) { + rc := NewRetryContainer() + fp := NewFailurePolicy() + fp.WithMaxRetryTimes(3) // max call times is 4 + fp.WithSpecifiedResultRetry(resultRetry) + p := BuildFailurePolicy(fp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithResp(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err == nil, err) + test.Assert(t, ret.getCallTimes() == 3, ret.callTimes) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-750.0) < 50.0, cost.Milliseconds()) + }) + + // backup request will failure + t.Run("backup request", func(t *testing.T) { + rc := NewRetryContainer() + bp := NewBackupPolicy(100) + bp.WithMaxRetryTimes(2) // backup max retry times is 2 + p := BuildBackupRequest(bp) + ri = genRPCInfo() + ret := &mockResult{} + start := time.Now() + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithResp(ri, 0, ret), ri, nil) + cost := time.Since(start) + test.Assert(t, err == nil, err) + test.Assert(t, ret.getResult().(mockResp).code == bizStatusCode1) + test.Assert(t, !ok) + test.Assert(t, math.Abs(float64(cost.Milliseconds())-250.0) < 50.0, cost.Milliseconds()) + }) +} + +func BenchmarkMixedRetry(b *testing.B) { + bizStatusCode0, bizStatusCode1, bizStatusCode2 := 0, 11111, 11112 + ri := genRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + retryWithResp := func(ri rpcinfo.RPCInfo, callTimes int32, resp *mockResult) RPCCallFunc { + return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + ct := atomic.AddInt32(&callTimes, 1) + resp.setCallTimes(ct) + switch ct { + case 1: + resp.setResult(mockResp{code: bizStatusCode1}) + return ri, resp, nil + case 2: + resp.setResult(mockResp{code: bizStatusCode2}) + return ri, resp, nil + case 3: + resp.setResult(mockResp{code: bizStatusCode0}) + return ri, resp, nil + } + return ri, nil, errors.New("mock error") + } + } + resultRetry := &ShouldResultRetry{RespRetryWithCtx: func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { + bizCode := resp.(*mockResult).getResult().(mockResp).code + if bizCode == bizStatusCode1 || bizCode == bizStatusCode2 { + return true + } + return false + }} + rc := NewRetryContainer() + mp := NewMixedPolicy(100) + mp.WithMaxRetryTimes(3) // max call times is 4 + mp.WithSpecifiedResultRetry(resultRetry) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + p := BuildMixedPolicy(mp) + ri = genRPCInfo() + ret := &mockResult{} + _, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithResp(ri, 0, ret), ri, nil) + test.Assert(b, err == nil, err) + test.Assert(b, !ok) + } + }) +} diff --git a/pkg/retry/policy.go b/pkg/retry/policy.go index c3d2a06e77..2161739583 100644 --- a/pkg/retry/policy.go +++ b/pkg/retry/policy.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/rpcinfo" ) @@ -30,6 +31,7 @@ type Type int const ( FailureType Type = iota BackupType + MixedType ) // String prints human readable information. @@ -39,6 +41,8 @@ func (t Type) String() string { return "Failure" case BackupType: return "Backup" + case MixedType: + return "Mixed" } return "" } @@ -59,6 +63,14 @@ func BuildBackupRequest(p *BackupPolicy) Policy { return Policy{Enable: true, Type: BackupType, BackupPolicy: p} } +// BuildMixedPolicy is used to build Policy with *MixedPolicy +func BuildMixedPolicy(p *MixedPolicy) Policy { + if p == nil { + return Policy{} + } + return Policy{Enable: true, Type: MixedType, MixedPolicy: p} +} + // Policy contains all retry policies // DON'T FORGET to update Equals() and DeepCopy() if you add new fields type Policy struct { @@ -68,6 +80,7 @@ type Policy struct { // notice: only one retry policy can be enabled, which one depend on Policy.Type FailurePolicy *FailurePolicy `json:"failure_policy,omitempty"` BackupPolicy *BackupPolicy `json:"backup_policy,omitempty"` + MixedPolicy *MixedPolicy `json:"mixed_policy,omitempty"` } func (p *Policy) DeepCopy() *Policy { @@ -79,6 +92,7 @@ func (p *Policy) DeepCopy() *Policy { Type: p.Type, FailurePolicy: p.FailurePolicy.DeepCopy(), BackupPolicy: p.BackupPolicy.DeepCopy(), + MixedPolicy: p.MixedPolicy.DeepCopy(), } } @@ -104,6 +118,13 @@ type BackupPolicy struct { RetrySameNode bool `json:"retry_same_node"` } +// MixedPolicy for failure retry +// DON'T FORGET to update Equals() and DeepCopy() if you add new fields +type MixedPolicy struct { + RetryDelayMS uint32 `json:"retry_delay_ms"` + FailurePolicy +} + // StopPolicy is a group policies to decide when stop retry type StopPolicy struct { MaxRetryTimes int `json:"max_retry_times"` @@ -185,135 +206,6 @@ func (p Policy) Equals(np Policy) bool { return true } -// Equals to check if FailurePolicy is equal -func (p *FailurePolicy) Equals(np *FailurePolicy) bool { - if p == nil { - return np == nil - } - if np == nil { - return false - } - if p.StopPolicy != np.StopPolicy { - return false - } - if !p.BackOffPolicy.Equals(np.BackOffPolicy) { - return false - } - if p.RetrySameNode != np.RetrySameNode { - return false - } - if p.Extra != np.Extra { - return false - } - // don't need to check `ShouldResultRetry`, ShouldResultRetry is only setup by option - // in remote config case will always return false if check it - return true -} - -func (p *FailurePolicy) DeepCopy() *FailurePolicy { - if p == nil { - return nil - } - return &FailurePolicy{ - StopPolicy: p.StopPolicy, - BackOffPolicy: p.BackOffPolicy.DeepCopy(), - RetrySameNode: p.RetrySameNode, - ShouldResultRetry: p.ShouldResultRetry, // don't need DeepCopy - Extra: p.Extra, - } -} - -// IsRespRetryWithCtxNonNil is used to check if RespRetryWithCtx is nil. -func (p *FailurePolicy) IsRespRetryWithCtxNonNil() bool { - return p.ShouldResultRetry != nil && p.ShouldResultRetry.RespRetryWithCtx != nil -} - -// IsErrorRetryWithCtxNonNil is used to check if ErrorRetryWithCtx is nil -func (p *FailurePolicy) IsErrorRetryWithCtxNonNil() bool { - return p.ShouldResultRetry != nil && p.ShouldResultRetry.ErrorRetryWithCtx != nil -} - -// IsRespRetryNonNil is used to check if RespRetry is nil. -// Deprecated: please use IsRespRetryWithCtxNonNil instead of IsRespRetryNonNil. -func (p *FailurePolicy) IsRespRetryNonNil() bool { - return p.ShouldResultRetry != nil && p.ShouldResultRetry.RespRetry != nil -} - -// IsErrorRetryNonNil is used to check if ErrorRetry is nil. -// Deprecated: please use IsErrorRetryWithCtxNonNil instead of IsErrorRetryNonNil. -func (p *FailurePolicy) IsErrorRetryNonNil() bool { - return p.ShouldResultRetry != nil && p.ShouldResultRetry.ErrorRetry != nil -} - -// IsRetryForTimeout is used to check if timeout error need to retry -func (p *FailurePolicy) IsRetryForTimeout() bool { - return p.ShouldResultRetry == nil || !p.ShouldResultRetry.NotRetryForTimeout -} - -// IsRespRetry is used to check if the resp need to do retry. -func (p *FailurePolicy) IsRespRetry(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { - // note: actually, it is better to check IsRespRetry to ignore the bad cases, - // but IsRespRetry is a deprecated definition and here will be executed for every call, depends on ConvertResultRetry to ensure the compatibility - return p.IsRespRetryWithCtxNonNil() && p.ShouldResultRetry.RespRetryWithCtx(ctx, resp, ri) -} - -// IsErrorRetry is used to check if the error need to do retry. -func (p *FailurePolicy) IsErrorRetry(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - // note: actually, it is better to check IsErrorRetry to ignore the bad cases, - // but IsErrorRetry is a deprecated definition and here will be executed for every call, depends on ConvertResultRetry to ensure the compatibility - return p.IsErrorRetryWithCtxNonNil() && p.ShouldResultRetry.ErrorRetryWithCtx(ctx, err, ri) -} - -// ConvertResultRetry is used to convert 'ErrorRetry and RespRetry' to 'ErrorRetryWithCtx and RespRetryWithCtx' -func (p *FailurePolicy) ConvertResultRetry() { - if p == nil || p.ShouldResultRetry == nil { - return - } - rr := p.ShouldResultRetry - if rr.ErrorRetry != nil && rr.ErrorRetryWithCtx == nil { - rr.ErrorRetryWithCtx = func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - return rr.ErrorRetry(err, ri) - } - } - if rr.RespRetry != nil && rr.RespRetryWithCtx == nil { - rr.RespRetryWithCtx = func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { - return rr.RespRetry(resp, ri) - } - } -} - -// Equals to check if BackupPolicy is equal -func (p *BackupPolicy) Equals(np *BackupPolicy) bool { - if p == nil { - return np == nil - } - if np == nil { - return false - } - if p.RetryDelayMS != np.RetryDelayMS { - return false - } - if p.StopPolicy != np.StopPolicy { - return false - } - if p.RetrySameNode != np.RetrySameNode { - return false - } - - return true -} - -func (p *BackupPolicy) DeepCopy() *BackupPolicy { - if p == nil { - return nil - } - return &BackupPolicy{ - RetryDelayMS: p.RetryDelayMS, - StopPolicy: p.StopPolicy, // not a pointer, will copy the value here - RetrySameNode: p.RetrySameNode, - } -} - // Equals to check if BackOffPolicy is equal. func (p *BackOffPolicy) Equals(np *BackOffPolicy) bool { if p == nil { @@ -368,3 +260,16 @@ func checkCBErrorRate(p *CBPolicy) error { } return nil } + +func checkStopPolicy(sp *StopPolicy, maxRetryTimes int, retryer Retryer) error { + rt := sp.MaxRetryTimes + // 0 is valid, it means stop retry + if rt < 0 || rt > maxRetryTimes { + return fmt.Errorf("invalid MaxRetryTimes[%d]", rt) + } + if e := checkCBErrorRate(&sp.CBPolicy); e != nil { + sp.CBPolicy.ErrorRate = defaultCBErrRate + klog.Warnf("KITEX: %s retry - %s, use default %0.2f", retryer.Type(), e.Error(), defaultCBErrRate) + } + return nil +} diff --git a/pkg/retry/policy_test.go b/pkg/retry/policy_test.go index 156f877b29..fd9211ec10 100644 --- a/pkg/retry/policy_test.go +++ b/pkg/retry/policy_test.go @@ -21,7 +21,7 @@ import ( "reflect" "testing" - jsoniter "github.com/json-iterator/go" + "github.com/bytedance/sonic" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -29,8 +29,7 @@ import ( "github.com/cloudwego/kitex/pkg/stats" ) -var ( - jsoni = jsoniter.ConfigCompatibleWithStandardLibrary +const ( method = "test" ) @@ -40,11 +39,11 @@ func TestFailureRetryPolicy(t *testing.T) { // case 1 fp.WithMaxRetryTimes(3) - jsonRet, err := jsoni.MarshalToString(fp) + jsonRet, err := sonic.MarshalString(fp) test.Assert(t, err == nil, err) var fp2 FailurePolicy - err = jsoni.UnmarshalFromString(jsonRet, &fp2) + err = sonic.UnmarshalString(jsonRet, &fp2) test.Assert(t, err == nil, err) test.Assert(t, fp.Equals(&fp2)) @@ -52,23 +51,23 @@ func TestFailureRetryPolicy(t *testing.T) { fp.WithMaxRetryTimes(2) fp.WithRetrySameNode() fp.WithFixedBackOff(10) - jsonRet, err = jsoni.MarshalToString(fp) + jsonRet, err = sonic.MarshalString(fp) test.Assert(t, err == nil, err) // case 3 var fp3 FailurePolicy - err = jsoni.UnmarshalFromString(jsonRet, &fp3) + err = sonic.UnmarshalString(jsonRet, &fp3) test.Assert(t, err == nil, err) test.Assert(t, fp.Equals(&fp3), fp3) // case 4 fp.WithRetrySameNode() fp.WithRandomBackOff(10, 20) - jsonRet, err = jsoni.MarshalToString(fp) + jsonRet, err = sonic.MarshalString(fp) test.Assert(t, err == nil, err) var fp4 FailurePolicy - err = jsoni.UnmarshalFromString(jsonRet, &fp4) + err = sonic.UnmarshalString(jsonRet, &fp4) test.Assert(t, err == nil, err) test.Assert(t, fp.Equals(&fp4), fp4) @@ -76,11 +75,11 @@ func TestFailureRetryPolicy(t *testing.T) { fp.WithRetryBreaker(0.1) fp.WithDDLStop() fp.WithMaxDurationMS(100) - jsonRet, err = jsoni.MarshalToString(fp) + jsonRet, err = sonic.MarshalString(fp) test.Assert(t, err == nil, err) var fp5 FailurePolicy - err = jsoni.UnmarshalFromString(jsonRet, &fp5) + err = sonic.UnmarshalString(jsonRet, &fp5) test.Assert(t, err == nil, err) test.Assert(t, fp.Equals(&fp5), fp5) @@ -95,20 +94,20 @@ func TestFailureRetryPolicy(t *testing.T) { }, Extra: "{}", } - jsonRet, err = jsoni.MarshalToString(fp) + jsonRet, err = sonic.MarshalString(fp) test.Assert(t, err == nil, err) var fp6 FailurePolicy - err = jsoni.UnmarshalFromString(jsonRet, &fp6) + err = sonic.UnmarshalString(jsonRet, &fp6) test.Assert(t, err == nil, err) test.Assert(t, fp6.BackOffPolicy == nil) test.Assert(t, fp.Equals(&fp6), fp6) // case 7 fp.DisableChainRetryStop() - jsonRet, err = jsoni.MarshalToString(fp) + jsonRet, err = sonic.MarshalString(fp) test.Assert(t, err == nil, err) var fp7 FailurePolicy - err = jsoni.UnmarshalFromString(jsonRet, &fp7) + err = sonic.UnmarshalString(jsonRet, &fp7) test.Assert(t, err == nil, err) test.Assert(t, fp7.BackOffPolicy == nil) test.Assert(t, fp.Equals(&fp7), fp7) @@ -121,12 +120,13 @@ func TestFailureRetryPolicy(t *testing.T) { fp.WithSpecifiedResultRetry(&ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { return false }}) + fp.convertResultRetry() test.Assert(t, fp.String() == "{StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:true "+ "DDLStop:false CBPolicy:{ErrorRate:0.1}} BackOffPolicy: RetrySameNode:false ShouldResultRetry:{ErrorRetry:true, RespRetry:false}}", fp) - jsonRet, err = jsoni.MarshalToString(fp) + jsonRet, err = sonic.MarshalString(fp) test.Assert(t, err == nil, err) var fp9 FailurePolicy - err = jsoni.UnmarshalFromString(jsonRet, &fp9) + err = sonic.UnmarshalString(jsonRet, &fp9) test.Assert(t, err == nil, err) test.Assert(t, fp.Equals(&fp9), fp9) test.Assert(t, fp9.ShouldResultRetry == nil) @@ -139,13 +139,14 @@ func TestFailureRetryPolicyWithResultRetry(t *testing.T) { }, ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { return false }}) + fp.convertResultRetry() test.Assert(t, fp.String() == "{StopPolicy:{MaxRetryTimes:2 MaxDurationMS:0 DisableChainStop:false DDLStop:false "+ "CBPolicy:{ErrorRate:0.1}} BackOffPolicy:&{BackOffType:none CfgItems:map[]} RetrySameNode:false ShouldResultRetry:{ErrorRetry:true, RespRetry:true}}", fp) - jsonRet, err := jsoni.MarshalToString(fp) + jsonRet, err := sonic.MarshalString(fp) test.Assert(t, err == nil, err) var fp10 FailurePolicy - err = jsoni.UnmarshalFromString(jsonRet, &fp10) + err = sonic.UnmarshalString(jsonRet, &fp10) test.Assert(t, err == nil, err) test.Assert(t, fp.Equals(&fp10), fp10) test.Assert(t, fp10.ShouldResultRetry == nil) @@ -167,21 +168,21 @@ func TestBackupRequest(t *testing.T) { // case 1 bp.WithMaxRetryTimes(2) - jsonRet, err := jsoni.MarshalToString(bp) + jsonRet, err := sonic.MarshalString(bp) test.Assert(t, err == nil, err) var bp2 BackupPolicy - err = jsoni.UnmarshalFromString(jsonRet, &bp2) + err = sonic.UnmarshalString(jsonRet, &bp2) test.Assert(t, err == nil, err) test.Assert(t, bp.Equals(&bp2)) // case 2 bp.DisableChainRetryStop() - jsonRet, err = jsoni.MarshalToString(bp) + jsonRet, err = sonic.MarshalString(bp) test.Assert(t, err == nil, err) var bp3 BackupPolicy - err = jsoni.UnmarshalFromString(jsonRet, &bp3) + err = sonic.UnmarshalString(jsonRet, &bp3) test.Assert(t, err == nil, err) test.Assert(t, bp.Equals(&bp3)) } @@ -194,11 +195,11 @@ func TestRetryPolicyBothNotNil(t *testing.T) { BackupPolicy: NewBackupPolicy(20), } ctx := context.Background() - jsonRet, err := jsoni.MarshalToString(p) + jsonRet, err := sonic.MarshalString(p) test.Assert(t, err == nil, err) var p2 Policy - err = jsoni.UnmarshalFromString(jsonRet, &p2) + err = sonic.UnmarshalString(jsonRet, &p2) test.Assert(t, err == nil, err) test.Assert(t, p2.Enable == true) test.Assert(t, p.Equals(p2)) @@ -221,11 +222,11 @@ func TestRetryPolicyBothNotNil(t *testing.T) { // test new policy both nil func TestRetryPolicyBothNil(t *testing.T) { p := Policy{} - jsonRet, err := jsoni.MarshalToString(p) + jsonRet, err := sonic.MarshalString(p) test.Assert(t, err == nil, err) var p2 Policy - err = jsoni.UnmarshalFromString(jsonRet, &p2) + err = sonic.UnmarshalString(jsonRet, &p2) test.Assert(t, err == nil, err) test.Assert(t, p.Equals(p2)) @@ -245,7 +246,7 @@ func TestRetryPolicyFailure(t *testing.T) { } jsonRet := `{"enable":true,"type":0,"failure_policy":{"stop_policy":{"max_retry_times":2,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1,"min_sample":200}},"backoff_policy":{"backoff_type":"none"},"retry_same_node":false}}` var p2 Policy - err := jsoni.UnmarshalFromString(jsonRet, &p2) + err := sonic.UnmarshalString(jsonRet, &p2) test.Assert(t, err == nil, err) test.Assert(t, p2.Enable) test.Assert(t, p.Equals(p2)) @@ -269,11 +270,11 @@ func TestRetryPolicyFailure(t *testing.T) { Enable: true, FailurePolicy: fp, } - jsonRet, err = jsoni.MarshalToString(p) + jsonRet, err = sonic.MarshalString(p) test.Assert(t, err == nil, err) var p3 Policy - err = jsoni.UnmarshalFromString(jsonRet, &p3) + err = sonic.UnmarshalString(jsonRet, &p3) test.Assert(t, err == nil, err) test.Assert(t, p.Equals(p3)) @@ -312,62 +313,62 @@ func TestPolicyNotEqual(t *testing.T) { RetrySameNode: false, }, } - jsonRet, err := jsoni.MarshalToString(policy) + jsonRet, err := sonic.MarshalString(policy) test.Assert(t, err == nil, err) // case1 enable not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.Enable = false test.Assert(t, !p.Equals(policy)) // case2 type not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.Type = BackupType test.Assert(t, !p.Equals(policy)) // case3 failurePolicy not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.FailurePolicy = nil test.Assert(t, !p.Equals(policy)) test.Assert(t, !policy.Equals(p)) // case4 failurePolicy stopPolicy not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.FailurePolicy.StopPolicy.MaxRetryTimes = 2 test.Assert(t, !p.Equals(policy)) // case5 failurePolicy backOffPolicy not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.FailurePolicy.BackOffPolicy = nil test.Assert(t, !p.Equals(policy)) test.Assert(t, !policy.Equals(p)) // case6 failurePolicy backOffPolicy backOffType not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.FailurePolicy.BackOffPolicy.BackOffType = RandomBackOffType test.Assert(t, !p.Equals(policy)) // case7 failurePolicy backOffPolicy len(cfgItems) not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.FailurePolicy.BackOffPolicy.CfgItems[MinMSBackOffCfgKey] = 100 test.Assert(t, !p.Equals(policy)) // case8 failurePolicy backOffPolicy cfgItems not equal p = Policy{} - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.FailurePolicy.BackOffPolicy.CfgItems[FixMSBackOffCfgKey] = 101 test.Assert(t, !p.Equals(policy)) // case9 failurePolicy retrySameNode not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.FailurePolicy.RetrySameNode = true test.Assert(t, !p.Equals(policy)) @@ -390,31 +391,31 @@ func TestPolicyNotEqual(t *testing.T) { RetrySameNode: false, }, } - jsonRet, err = jsoni.MarshalToString(policy) + jsonRet, err = sonic.MarshalString(policy) test.Assert(t, err == nil, err) // case10 backupPolicy not equal p = Policy{} - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.BackupPolicy = nil test.Assert(t, !p.Equals(policy)) test.Assert(t, !policy.Equals(p)) // case11 backupPolicy retryDelayMS not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.BackupPolicy.RetryDelayMS = 2 test.Assert(t, !p.Equals(policy)) // case12 backupPolicy stopPolicy not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.BackupPolicy.StopPolicy.MaxRetryTimes = 3 test.Assert(t, !p.Equals(policy)) // case13 backupPolicy retrySameNode not equal - err = jsoni.UnmarshalFromString(jsonRet, &p) + err = sonic.UnmarshalString(jsonRet, &p) test.Assert(t, err == nil, err) p.BackupPolicy.RetrySameNode = true test.Assert(t, !p.Equals(policy)) @@ -432,7 +433,7 @@ func TestPolicyNotRetryForTimeout(t *testing.T) { }, }} // case 1: ShouldResultRetry is nil, retry for timeout - test.Assert(t, fp.IsRetryForTimeout()) + test.Assert(t, fp.isRetryForTimeout()) // case 2: ShouldResultRetry is not nil, NotRetryForTimeout is false, retry for timeout fp.ShouldResultRetry = &ShouldResultRetry{ @@ -442,7 +443,7 @@ func TestPolicyNotRetryForTimeout(t *testing.T) { // case 3: ShouldResultRetry is not nil, NotRetryForTimeout is true, not retry for timeout fp.ShouldResultRetry.NotRetryForTimeout = true - test.Assert(t, !fp.IsRetryForTimeout()) + test.Assert(t, !fp.isRetryForTimeout()) } func genRPCInfo() rpcinfo.RPCInfo { @@ -747,3 +748,30 @@ func TestPolicy_DeepCopy(t *testing.T) { }) } } + +func TestCheckStopPolicy(t *testing.T) { + mp := NewMixedPolicy(100) + err := checkStopPolicy(&mp.StopPolicy, maxMixRetryTimes, &mixedRetryer{}) + test.Assert(t, err == nil, err) + + mp.StopPolicy.MaxRetryTimes = -1 + err = checkStopPolicy(&mp.StopPolicy, maxMixRetryTimes, &mixedRetryer{}) + test.Assert(t, err != nil, err) + test.Assert(t, err.Error() == "invalid MaxRetryTimes[-1]") + + mp.StopPolicy.MaxRetryTimes = 5 + err = checkStopPolicy(&mp.StopPolicy, maxMixRetryTimes, &mixedRetryer{}) + test.Assert(t, err != nil, err) + test.Assert(t, err.Error() == "invalid MaxRetryTimes[5]") + mp.StopPolicy.MaxRetryTimes = maxMixRetryTimes + + mp.StopPolicy.CBPolicy.ErrorRate = 0.5 + err = checkStopPolicy(&mp.StopPolicy, maxMixRetryTimes, &mixedRetryer{}) + test.Assert(t, err == nil, err) + test.Assert(t, mp.StopPolicy.CBPolicy.ErrorRate == defaultCBErrRate) + + mp.StopPolicy.CBPolicy.ErrorRate = -0.1 + err = checkStopPolicy(&mp.StopPolicy, maxMixRetryTimes, &mixedRetryer{}) + test.Assert(t, err == nil, err) + test.Assert(t, mp.StopPolicy.CBPolicy.ErrorRate == defaultCBErrRate) +} diff --git a/pkg/retry/retryer.go b/pkg/retry/retryer.go index 1f0b8b1507..6b485a0ffd 100644 --- a/pkg/retry/retryer.go +++ b/pkg/retry/retryer.go @@ -43,10 +43,6 @@ type Retryer interface { // If not satisfy won't execute Retryer.Do and return the reason message // Execute anyway for the first time regardless of able to retry. AllowRetry(ctx context.Context) (msg string, ok bool) - - // ShouldRetry to check if retry request can be called, it is checked in retryer.Do. - // If not satisfy will return the reason message - ShouldRetry(ctx context.Context, err error, callTimes int, req interface{}, cbKey string) (msg string, ok bool) UpdatePolicy(policy Policy) error // Retry policy execute func. recycleRI is to decide if the firstRI can be recycled. @@ -369,9 +365,12 @@ func (rc *Container) WithRetryIfNeeded(ctx context.Context, callOptRetry *Policy // NewRetryer build a retryer with policy func NewRetryer(p Policy, r *ShouldResultRetry, cbC *cbContainer) (retryer Retryer, err error) { // just one retry policy can be enabled at same time - if p.Type == BackupType { + switch p.Type { + case MixedType: + retryer, err = newMixedRetryer(p, r, cbC) + case BackupType: retryer, err = newBackupRetryer(p, cbC) - } else { + default: retryer, err = newFailureRetryer(p, r, cbC) } return @@ -437,8 +436,11 @@ func (rc *Container) updateRetryer(rr *ShouldResultRetry) { rc.shouldResultRetry = rr if rc.shouldResultRetry != nil { rc.retryerMap.Range(func(key, value interface{}) bool { - if fr, ok := value.(*failureRetryer); ok { - fr.setSpecifiedResultRetryIfNeeded(rc.shouldResultRetry) + switch r := value.(type) { + case *failureRetryer: + r.setSpecifiedResultRetryIfNeeded(rc.shouldResultRetry, r.policy) + case *mixedRetryer: + r.setSpecifiedResultRetryIfNeeded(rc.shouldResultRetry, &r.policy.FailurePolicy) } return true }) diff --git a/pkg/retry/retryer_test.go b/pkg/retry/retryer_test.go index 0a13991b68..3d028c12d8 100644 --- a/pkg/retry/retryer_test.go +++ b/pkg/retry/retryer_test.go @@ -24,6 +24,8 @@ import ( "testing" "time" + "github.com/bytedance/sonic" + "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/circuitbreak" "github.com/cloudwego/kitex/pkg/discovery" @@ -38,6 +40,7 @@ var ( remoteTagKey = "k" remoteTagValue = "v" remoteTags = map[string]string{remoteTagKey: remoteTagValue} + sonici = sonic.Config{SortMapKeys: true}.Froze() ) // test new retry container @@ -112,8 +115,8 @@ func TestNewRetryContainer(t *testing.T) { RetryDelayMS: 0, }, }) - msg = "new retryer[test-Backup] failed, err=newBackupRetryer failed, err=invalid backup request delay duration or retryTimes, at " - test.Assert(t, rc.msg[:len(msg)] == msg) + msg = "new retryer[test-Backup] failed, err=newBackupRetryer failed, err=invalid retry delay duration in backupRetryer, at " + test.Assert(t, rc.msg[:len(msg)] == msg, rc.msg) // backupPolicy cBPolicy config invalid rc.NotifyPolicyChange(method, Policy{ @@ -142,8 +145,8 @@ func TestNewRetryContainer(t *testing.T) { }, }, }) - msg = "new retryer[test-Failure] failed, err=newfailureRetryer failed, err=invalid failure MaxRetryTimes[6], at " - test.Assert(t, rc.msg[:len(msg)] == msg) + msg = "new retryer[test-Failure] failed, err=newfailureRetryer failed, err=invalid MaxRetryTimes[6], at " + test.Assert(t, rc.msg[:len(msg)] == msg, rc.msg) // failurePolicy cBPolicy config invalid rc = NewRetryContainer() @@ -278,591 +281,121 @@ func TestNewRetryContainer(t *testing.T) { // test container dump func TestContainer_Dump(t *testing.T) { // test backupPolicy dump - rc := NewRetryContainerWithCB(nil, nil) - methodPolicies := map[string]Policy{ - method: { - Enable: true, - Type: 1, - BackupPolicy: NewBackupPolicy(20), - }, - } - rc.InitWithPolicies(methodPolicies) - err := rc.Init(map[string]Policy{Wildcard: { - Enable: true, - Type: 1, - BackupPolicy: NewBackupPolicy(20), - }}, nil) - test.Assert(t, err == nil, err) - rcDump, ok := rc.Dump().(map[string]interface{}) - test.Assert(t, ok) - hasCodeCfg, err := jsoni.MarshalToString(rcDump["has_code_cfg"]) - test.Assert(t, err == nil, err) - test.Assert(t, hasCodeCfg == "true", hasCodeCfg) - testStr, err := jsoni.MarshalToString(rcDump["test"]) - msg := `{"backupRequest":{"retry_delay_ms":20,"stop_policy":{"max_retry_times":1,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"retry_same_node":false},"enable":true}` - test.Assert(t, err == nil, err) - test.Assert(t, testStr == msg) - - // test failurePolicy dump - rc = NewRetryContainerWithCB(nil, nil) - methodPolicies = map[string]Policy{ - method: { - Enable: true, - Type: FailureType, - FailurePolicy: NewFailurePolicy(), - }, - } - rc.InitWithPolicies(methodPolicies) - err = rc.Init(map[string]Policy{Wildcard: { - Enable: true, - Type: FailureType, - FailurePolicy: NewFailurePolicy(), - }}, nil) - test.Assert(t, err == nil, err) - rcDump, ok = rc.Dump().(map[string]interface{}) - test.Assert(t, ok) - hasCodeCfg, err = jsoni.MarshalToString(rcDump["has_code_cfg"]) - test.Assert(t, err == nil, err) - test.Assert(t, hasCodeCfg == "true") - testStr, err = jsoni.MarshalToString(rcDump["test"]) - msg = `{"enable":true,"failure_retry":{"stop_policy":{"max_retry_times":2,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"backoff_policy":{"backoff_type":"none"},"retry_same_node":false,"extra":""},"specified_result_retry":{"error_retry":false,"old_error_retry":false,"old_resp_retry":false,"resp_retry":false}}` - test.Assert(t, err == nil, err) - test.Assert(t, testStr == msg, testStr) -} - -// test FailurePolicy call -func TestFailurePolicyCall(t *testing.T) { - // call while rpc timeout - ctx := context.Background() - rc := NewRetryContainer() - failurePolicy := NewFailurePolicy() - failurePolicy.BackOffPolicy.BackOffType = FixedBackOffType - failurePolicy.BackOffPolicy.CfgItems = map[BackOffCfgKey]float64{ - FixMSBackOffCfgKey: 100.0, - } - failurePolicy.StopPolicy.MaxDurationMS = 100 - err := rc.Init(map[string]Policy{Wildcard: { - Enable: true, - Type: 0, - FailurePolicy: failurePolicy, - }}, nil) - test.Assert(t, err == nil, err) - ri := genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - return ri, nil, kerrors.ErrRPCTimeout - }, ri, nil) - test.Assert(t, err != nil, err) - test.Assert(t, !ok) - - // call normal - failurePolicy.StopPolicy.MaxDurationMS = 0 - err = rc.Init(map[string]Policy{Wildcard: { - Enable: true, - Type: 0, - FailurePolicy: failurePolicy, - }}, nil) - test.Assert(t, err == nil, err) - _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - return ri, nil, nil - }, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, ok) -} - -// test retry with one time policy -func TestRetryWithOneTimePolicy(t *testing.T) { - // call while rpc timeout and exceed MaxDurationMS cause BackOffPolicy is wait fix 100ms, it is invalid config - failurePolicy := NewFailurePolicy() - failurePolicy.BackOffPolicy.BackOffType = FixedBackOffType - failurePolicy.BackOffPolicy.CfgItems = map[BackOffCfgKey]float64{ - FixMSBackOffCfgKey: 100.0, - } - failurePolicy.StopPolicy.MaxDurationMS = 100 - p := Policy{ - Enable: true, - Type: 0, - FailurePolicy: failurePolicy, - } - ri := genRPCInfo() - ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - _, ok, err := NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - return ri, nil, kerrors.ErrRPCTimeout - }, ri, nil) - test.Assert(t, err != nil, err) - test.Assert(t, !ok) - - // call no MaxDurationMS limit, the retry will success - failurePolicy.StopPolicy.MaxDurationMS = 0 - var callTimes int32 - ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), genRPCInfo()) - _, ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - if atomic.LoadInt32(&callTimes) == 0 { - atomic.AddInt32(&callTimes, 1) - return ri, nil, kerrors.ErrRPCTimeout - } - return ri, nil, nil - }, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, !ok) - - // call backup request - p = BuildBackupRequest(NewBackupPolicy(10)) - ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), genRPCInfo()) - callTimes = 0 - _, ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - if atomic.LoadInt32(&callTimes) == 0 || atomic.LoadInt32(&callTimes) == 1 { - atomic.AddInt32(&callTimes, 1) - time.Sleep(time.Millisecond * 100) - } - return ri, nil, nil - }, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, !ok) - test.Assert(t, atomic.LoadInt32(&callTimes) == 2) -} - -// test specified error to retry -func TestSpecifiedErrorRetry(t *testing.T) { - retryWithTransError := func(callTimes int32) RPCCallFunc { - // fails for the first call if callTimes is initialized to 0 - return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - newVal := atomic.AddInt32(&callTimes, 1) - if newVal == 1 { - return genRPCInfo(), nil, remote.NewTransErrorWithMsg(1000, "mock") - } else { - return genRPCInfoWithRemoteTag(remoteTags), nil, nil - } + t.Run("backupPolicy dump", func(t *testing.T) { + rc := NewRetryContainerWithCB(nil, nil) + methodPolicies := map[string]Policy{ + method: { + Enable: true, + Type: BackupType, + BackupPolicy: NewBackupPolicy(20), + }, } - } - ri := genRPCInfo() - ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - - // case1: specified method retry with error - t.Run("case1", func(t *testing.T) { - rc := NewRetryContainer() - shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method { - if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { - return true - } - } - return false - }} - err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) + err := rc.Init(methodPolicies, nil) test.Assert(t, err == nil, err) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) + rcDump, ok := rc.Dump().(map[string]interface{}) test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) - }) - - // case2: specified method retry with error, but use backup request config cannot be effective - t.Run("case2", func(t *testing.T) { - shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method { - if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { - return true - } - } - return false - }} - rc := NewRetryContainer() - err := rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(10))}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri = genRPCInfo() - _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) - test.Assert(t, err != nil, err) - test.Assert(t, !ok) - }) - - // case3: specified method retry with error, but method not match - t.Run("case3", func(t *testing.T) { - shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() != method { - if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { - return true - } - } - return false - }} - rc := NewRetryContainer() - err := rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + hasCodeCfg := rcDump["has_code_cfg"].(bool) + test.Assert(t, hasCodeCfg) + testStr, err := sonici.MarshalToString(rcDump[method]) + msg := `{"backup_request":{"retry_delay_ms":20,"stop_policy":{"max_retry_times":1,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"retry_same_node":false},"enable":true}` test.Assert(t, err == nil, err) - ri = genRPCInfo() - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) - test.Assert(t, err != nil) - test.Assert(t, !ok) - _, ok = ri.To().Tag(remoteTagKey) - test.Assert(t, !ok) + test.Assert(t, testStr == msg) }) - // case4: all error retry - t.Run("case4", func(t *testing.T) { - rc := NewRetryContainer() - p := BuildFailurePolicy(NewFailurePolicyWithResultRetry(AllErrorRetry())) - ri = genRPCInfo() - ri, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTransError(0), ri, nil) + // test backupPolicy dump + t.Run("backupPolicy dump without code_cfg", func(t *testing.T) { + rc := NewRetryContainerWithCB(nil, nil) + policy := Policy{ + Enable: true, + Type: 1, + BackupPolicy: NewBackupPolicy(20), + } + err := rc.Init(nil, nil) + rc.NotifyPolicyChange(method, policy) test.Assert(t, err == nil, err) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) + rcDump, ok := rc.Dump().(map[string]interface{}) test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) + hasCodeCfg := rcDump["has_code_cfg"].(bool) + test.Assert(t, !hasCodeCfg) + testStr, err := sonici.MarshalToString(rcDump[method]) + msg := `{"backup_request":{"retry_delay_ms":20,"stop_policy":{"max_retry_times":1,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"retry_same_node":false},"enable":true}` + test.Assert(t, err == nil, err) + test.Assert(t, testStr == msg) }) -} - -// test specified resp to retry -func TestSpecifiedRespRetry(t *testing.T) { - retryResult := &mockResult{} - retryResp := mockResp{ - code: 500, - msg: "retry", - } - noRetryResp := mockResp{ - code: 0, - msg: "noretry", - } - var callTimes int32 - retryWithResp := func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - newVal := atomic.AddInt32(&callTimes, 1) - if newVal == 1 { - retryResult.SetResult(retryResp) - return genRPCInfo(), retryResult, nil - } else { - retryResult.SetResult(noRetryResp) - return genRPCInfoWithRemoteTag(remoteTags), retryResult, nil - } - } - ctx := context.Background() - ri := genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - rc := NewRetryContainer() - // case1: specified method retry with resp - shouldResultRetry := &ShouldResultRetry{RespRetry: func(resp interface{}, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method { - if r, ok := resp.(*mockResult); ok && r.GetResult() == retryResp { - return true - } - } - return false - }} - err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, retryResult.GetResult() == noRetryResp, retryResult) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) - test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) - - // case2 specified method retry with resp, but use backup request config cannot be effective - atomic.StoreInt32(&callTimes, 0) - rc = NewRetryContainer() - err = rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(100))}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri = genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, retryResult.GetResult() == retryResp, retryResp) - test.Assert(t, !ok) - - // case3: specified method retry with resp, but method not match - atomic.StoreInt32(&callTimes, 0) - shouldResultRetry = &ShouldResultRetry{RespRetry: func(resp interface{}, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() != method { - if r, ok := resp.(*mockResult); ok && r.GetResult() == retryResp { - return true - } - } - return false - }} - rc = NewRetryContainer() - err = rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri = genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, retryResult.GetResult() == retryResp, retryResult) - test.Assert(t, ok) - _, ok = ri.To().Tag(remoteTagKey) - test.Assert(t, !ok) -} -// test specified error to retry with ErrorRetryWithCtx -func TestSpecifiedErrorRetryWithCtx(t *testing.T) { - retryWithTransError := func(callTimes int32) RPCCallFunc { - // fails for the first call if callTimes is initialized to 0 - return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - newVal := atomic.AddInt32(&callTimes, 1) - if newVal == 1 { - return genRPCInfo(), nil, remote.NewTransErrorWithMsg(1000, "mock") - } else { - return genRPCInfoWithRemoteTag(remoteTags), nil, nil - } + // test failurePolicy dump + t.Run("failurePolicy dump", func(t *testing.T) { + rc := NewRetryContainerWithCB(nil, nil) + methodPolicies := map[string]Policy{ + method: { + Enable: true, + Type: FailureType, + FailurePolicy: NewFailurePolicy(), + }, } - } - ri := genRPCInfo() - ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - - // case1: specified method retry with error - t.Run("case1", func(t *testing.T) { - rc := NewRetryContainer() - shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method { - if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { - return true - } - } - return false - }} - err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) + err := rc.Init(methodPolicies, nil) test.Assert(t, err == nil, err) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) + rcDump, ok := rc.Dump().(map[string]interface{}) test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) - }) - - // case2: specified method retry with error, but use backup request config cannot be effective - t.Run("case2", func(t *testing.T) { - shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method { - if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { - return true - } - } - return false - }} - rc := NewRetryContainer() - err := rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(10))}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri = genRPCInfo() - _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) - test.Assert(t, err != nil, err) - test.Assert(t, !ok) - }) - - // case3: specified method retry with error, but method not match - t.Run("case3", func(t *testing.T) { - shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - return ri.To().Method() != method - }} - rc := NewRetryContainer() - err := rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + hasCodeCfg := rcDump["has_code_cfg"].(bool) + test.Assert(t, hasCodeCfg) + testStr, err := sonici.MarshalToString(rcDump[method]) + msg := `{"enable":true,"failure_retry":{"stop_policy":{"max_retry_times":2,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"backoff_policy":{"backoff_type":"none"},"retry_same_node":false,"extra":""},"specified_result_retry":{"error_retry":false,"old_error_retry":false,"old_resp_retry":false,"resp_retry":false}}` test.Assert(t, err == nil, err) - ri = genRPCInfo() - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) - test.Assert(t, err != nil) - test.Assert(t, !ok) - _, ok = ri.To().Tag(remoteTagKey) - test.Assert(t, !ok) + test.Assert(t, testStr == msg, testStr) }) - // case4: all error retry - t.Run("case4", func(t *testing.T) { - rc := NewRetryContainer() - p := BuildFailurePolicy(NewFailurePolicyWithResultRetry(AllErrorRetry())) - ri = genRPCInfo() - ri, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTransError(0), ri, nil) + // test mixedPolicy dump + t.Run("mixedPolicy dump", func(t *testing.T) { + rc := NewRetryContainerWithCB(nil, nil) + policy := Policy{ + Enable: true, + Type: MixedType, + MixedPolicy: NewMixedPolicy(20), + } + err := rc.Init(nil, nil) test.Assert(t, err == nil, err) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) + rc.NotifyPolicyChange(method, policy) + rcDump, ok := rc.Dump().(map[string]interface{}) test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) + hasCodeCfg := rcDump["has_code_cfg"].(bool) + test.Assert(t, !hasCodeCfg) + testStr, err := sonici.MarshalToString(rcDump[method]) + msg := `{"enable":true,"mixed_retry":{"retry_delay_ms":20,"stop_policy":{"max_retry_times":1,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"backoff_policy":{"backoff_type":"none"},"retry_same_node":false,"extra":""},"specified_result_retry":{"error_retry":false,"old_error_retry":false,"old_resp_retry":false,"resp_retry":false}}` + test.Assert(t, err == nil, err) + test.Assert(t, testStr == msg, testStr) }) - // case5: specified method retry with error, only ctx has some info then retry - ctxKeyVal := "ctxKeyVal" - t.Run("case5", func(t *testing.T) { - shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method && ctx.Value(ctxKeyVal) == ctxKeyVal { + // test mixedPolicy dump + t.Run("mixedPolicy dump with customized retry", func(t *testing.T) { + rc := NewRetryContainerWithCB(nil, nil) + policy := Policy{ + Enable: true, + Type: MixedType, + MixedPolicy: NewMixedPolicy(20), + } + rr := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { + if ri.To().Method() == method { if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { return true } } return false }} - rc := NewRetryContainer() - err := rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + err := rc.Init(nil, rr) test.Assert(t, err == nil, err) - ri = genRPCInfo() - ctx = context.WithValue(ctx, ctxKeyVal, ctxKeyVal) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) - test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) - }) -} - -// test specified error to retry, but has both old and new policy, the new one will be effective -func TestSpecifiedErrorRetryHasOldAndNew(t *testing.T) { - retryWithTransError := func(callTimes int32) RPCCallFunc { - // fails for the first call if callTimes is initialized to 0 - return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - newVal := atomic.AddInt32(&callTimes, 1) - if newVal == 1 { - return genRPCInfo(), nil, remote.NewTransErrorWithMsg(1000, "mock") - } else { - return genRPCInfoWithRemoteTag(remoteTags), nil, nil - } - } - } - ri := genRPCInfo() - ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - - // case1: ErrorRetryWithCtx will retry, but ErrorRetry not retry, the expect result is do retry - t.Run("case1", func(t *testing.T) { - rc := NewRetryContainer() - shouldResultRetry := &ShouldResultRetry{ - ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - return true - }, - ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { - return false - }, - } - err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) + rc.NotifyPolicyChange(method, policy) + rcDump, ok := rc.Dump().(map[string]interface{}) test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) - }) - - // case2: ErrorRetryWithCtx not retry, but ErrorRetry retry, the expect result is that not do retry - t.Run("case2", func(t *testing.T) { - shouldResultRetry := &ShouldResultRetry{ - ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - return false - }, - ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { - return true - }, - } - rc := NewRetryContainer() - err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) + hasCodeCfg := rcDump["has_code_cfg"].(bool) + test.Assert(t, !hasCodeCfg) + testStr, err := sonici.MarshalToString(rcDump[method]) + msg := `{"enable":true,"mixed_retry":{"retry_delay_ms":20,"stop_policy":{"max_retry_times":1,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"backoff_policy":{"backoff_type":"none"},"retry_same_node":false,"extra":""},"specified_result_retry":{"error_retry":true,"old_error_retry":false,"old_resp_retry":false,"resp_retry":false}}` test.Assert(t, err == nil, err) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil) - test.Assert(t, err != nil) - test.Assert(t, !ok) - _, ok = ri.To().Tag(remoteTagKey) - test.Assert(t, !ok) + test.Assert(t, testStr == msg, testStr) }) } -// test specified resp to retry with ErrorRetryWithCtx -func TestSpecifiedRespRetryWithCtx(t *testing.T) { - retryResult := &mockResult{} - retryResp := mockResp{ - code: 500, - msg: "retry", - } - noRetryResp := mockResp{ - code: 0, - msg: "noretry", - } - var callTimes int32 - retryWithResp := func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - newVal := atomic.AddInt32(&callTimes, 1) - if newVal == 1 { - retryResult.SetResult(retryResp) - return genRPCInfo(), retryResult, nil - } else { - retryResult.SetResult(noRetryResp) - return genRPCInfoWithRemoteTag(remoteTags), retryResult, nil - } - } - ctx := context.Background() - ri := genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - rc := NewRetryContainer() - // case1: specified method retry with resp - shouldResultRetry := &ShouldResultRetry{RespRetryWithCtx: func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method { - if r, ok := resp.(*mockResult); ok && r.GetResult() == retryResp { - return true - } - } - return false - }} - err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, retryResult.GetResult() == noRetryResp, retryResult) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) - test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) - - // case2 specified method retry with resp, but use backup request config cannot be effective - atomic.StoreInt32(&callTimes, 0) - rc = NewRetryContainer() - err = rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(100))}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri = genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, retryResult.GetResult() == retryResp, retryResp) - test.Assert(t, !ok) - - // case3: specified method retry with resp, but method not match - atomic.StoreInt32(&callTimes, 0) - shouldResultRetry = &ShouldResultRetry{RespRetryWithCtx: func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() != method { - if r, ok := resp.(*mockResult); ok && r.GetResult() == retryResp { - return true - } - } - return false - }} - rc = NewRetryContainer() - err = rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) - test.Assert(t, err == nil, err) - ri = genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, retryResult.GetResult() == retryResp, retryResult) - test.Assert(t, ok) - _, ok = ri.To().Tag(remoteTagKey) - test.Assert(t, !ok) - - // case4: specified method retry with resp, only ctx has some info then retry - ctxKeyVal := "ctxKeyVal" - atomic.StoreInt32(&callTimes, 0) - shouldResultRetry2 := &ShouldResultRetry{RespRetryWithCtx: func(ctx context.Context, resp interface{}, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method && ctx.Value(ctxKeyVal) == ctxKeyVal { - if r, ok := resp.(*mockResult); ok && r.GetResult() == retryResp { - return true - } - } - return false - }} - err = rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry2) - test.Assert(t, err == nil, err) - ctx = context.WithValue(ctx, ctxKeyVal, ctxKeyVal) - ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, retryResult.GetResult() == noRetryResp, retryResult) - test.Assert(t, !ok) - v, ok = ri.To().Tag(remoteTagKey) - test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) -} - // test different method use different retry policy func TestDifferentMethodConfig(t *testing.T) { var callTimes int32 @@ -931,103 +464,6 @@ func TestDifferentMethodConfig(t *testing.T) { test.Assert(t, ok) } -func TestResultRetryWithPolicyChange(t *testing.T) { - rc := NewRetryContainer() - shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method { - if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { - return true - } - } - return false - }} - err := rc.Init(nil, shouldResultRetry) - test.Assert(t, err == nil, err) - - // case 1: first time trigger NotifyPolicyChange, the `initRetryer` will be executed, check if the ShouldResultRetry is not nil - rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy())) - r := rc.getRetryer(context.Background(), genRPCInfo()) - fr, ok := r.(*failureRetryer) - test.Assert(t, ok) - test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry) - - // case 2: second time trigger NotifyPolicyChange, the `UpdatePolicy` will be executed, check if the ShouldResultRetry is not nil - rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy())) - r = rc.getRetryer(context.Background(), genRPCInfo()) - fr, ok = r.(*failureRetryer) - test.Assert(t, ok) - test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry) -} - -func TestResultRetryWithCtxWhenPolicyChange(t *testing.T) { - rc := NewRetryContainer() - shouldResultRetry := &ShouldResultRetry{ErrorRetryWithCtx: func(ctx context.Context, err error, ri rpcinfo.RPCInfo) bool { - if ri.To().Method() == method { - if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 { - return true - } - } - return false - }} - err := rc.Init(nil, shouldResultRetry) - test.Assert(t, err == nil, err) - - // case 1: first time trigger NotifyPolicyChange, the `initRetryer` will be executed, check if the ShouldResultRetry is not nil - rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy())) - r := rc.getRetryer(context.Background(), genRPCInfo()) - fr, ok := r.(*failureRetryer) - test.Assert(t, ok) - test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry) - - // case 2: second time trigger NotifyPolicyChange, the `UpdatePolicy` will be executed, check if the ShouldResultRetry is not nil - rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy())) - r = rc.getRetryer(context.Background(), genRPCInfo()) - fr, ok = r.(*failureRetryer) - test.Assert(t, ok) - test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry) -} - -// test BackupPolicy call while rpcTime > delayTime -func TestBackupPolicyCall(t *testing.T) { - ctx := context.Background() - rc := NewRetryContainer() - err := rc.Init(map[string]Policy{Wildcard: { - Enable: true, - Type: 1, - BackupPolicy: &BackupPolicy{ - RetryDelayMS: 30, - StopPolicy: StopPolicy{ - MaxRetryTimes: 2, - DisableChainStop: false, - CBPolicy: CBPolicy{ - ErrorRate: defaultCBErrRate, - }, - }, - }, - }}, nil) - test.Assert(t, err == nil, err) - - callTimes := int32(0) - firstRI := genRPCInfo() - secondRI := genRPCInfoWithRemoteTag(remoteTags) - ctx = rpcinfo.NewCtxWithRPCInfo(ctx, firstRI) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { - atomic.AddInt32(&callTimes, 1) - if atomic.LoadInt32(&callTimes) == 1 { - // mock timeout for the first request and get the response of the backup request. - time.Sleep(time.Millisecond * 50) - return firstRI, nil, nil - } - return secondRI, nil, nil - }, firstRI, nil) - test.Assert(t, err == nil, err) - test.Assert(t, atomic.LoadInt32(&callTimes) == 2) - test.Assert(t, !ok) - v, ok := ri.To().Tag(remoteTagKey) - test.Assert(t, ok) - test.Assert(t, v == remoteTagValue) -} - // test policy noRetry call func TestPolicyNoRetryCall(t *testing.T) { ctx := context.Background() @@ -1138,50 +574,9 @@ func retryCall(callTimes *int32, firstRI rpcinfo.RPCInfo, backup bool) RPCCallFu } } -func TestFailureRetryWithRPCInfo(t *testing.T) { - // failure retry - ctx := context.Background() - rc := NewRetryContainer() - - ri := genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - rpcinfo.Record(ctx, ri, stats.RPCStart, nil) - - // call with retry policy - var callTimes int32 - policy := BuildFailurePolicy(NewFailurePolicy()) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &policy, retryCall(&callTimes, ri, false), ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, !ok) - test.Assert(t, ri.Stats().GetEvent(stats.RPCStart).Status() == stats.StatusInfo) - test.Assert(t, ri.Stats().GetEvent(stats.RPCFinish).Status() == stats.StatusInfo) - test.Assert(t, ri.To().Address().String() == "10.20.30.40:8888") - test.Assert(t, atomic.LoadInt32(&callTimes) == 2) -} - -func TestBackupRetryWithRPCInfo(t *testing.T) { - // backup retry - ctx := context.Background() - rc := NewRetryContainer() - - ri := genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - rpcinfo.Record(ctx, ri, stats.RPCStart, nil) - - // call with retry policy - var callTimes int32 - policy := BuildBackupRequest(NewBackupPolicy(10)) - ri, ok, err := rc.WithRetryIfNeeded(ctx, &policy, retryCall(&callTimes, ri, true), ri, nil) - test.Assert(t, err == nil, err) - test.Assert(t, !ok) - test.Assert(t, ri.Stats().GetEvent(stats.RPCStart).Status() == stats.StatusInfo) - test.Assert(t, ri.Stats().GetEvent(stats.RPCFinish).Status() == stats.StatusInfo) - test.Assert(t, ri.To().Address().String() == "10.20.30.40:8888") - test.Assert(t, atomic.LoadInt32(&callTimes) == 2) -} - type mockResult struct { - result mockResp + result mockResp + callTimes int32 sync.RWMutex } @@ -1190,18 +585,30 @@ type mockResp struct { msg string } -func (r *mockResult) GetResult() interface{} { +func (r *mockResult) getResult() interface{} { r.RLock() defer r.RUnlock() return r.result } -func (r *mockResult) SetResult(ret mockResp) { +func (r *mockResult) setResult(ret mockResp) { r.Lock() defer r.Unlock() r.result = ret } +func (r *mockResult) setCallTimes(ct int32) { + r.Lock() + defer r.Unlock() + r.callTimes = ct +} + +func (r *mockResult) getCallTimes() int32 { + r.RLock() + defer r.RUnlock() + return r.callTimes +} + func TestNewRetryContainerWithOptions(t *testing.T) { t.Run("no_option", func(t *testing.T) { rc := NewRetryContainer() diff --git a/pkg/retry/util.go b/pkg/retry/util.go index 3d00717153..8c4731e312 100644 --- a/pkg/retry/util.go +++ b/pkg/retry/util.go @@ -127,7 +127,7 @@ func makeRetryErr(ctx context.Context, msg string, callTimes int32) error { ri := rpcinfo.GetRPCInfo(ctx) to := ri.To() - errMsg := fmt.Sprintf("retry[%d] failed, %s, to=%s, method=%s", callTimes-1, msg, to.ServiceName(), to.Method()) + errMsg := fmt.Sprintf("retry[%d] failed, %s, toService=%s, method=%s", callTimes-1, msg, to.ServiceName(), to.Method()) target := to.Address() if target != nil { errMsg = fmt.Sprintf("%s, remote=%s", errMsg, target.String()) From dd77fc2fe62c48f45ee5f44b7eb5a72fbf2ebde8 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Wed, 28 Aug 2024 21:12:08 +0800 Subject: [PATCH 59/70] chore: upgrade go version to solve scenario test issue (#1517) --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 481f69f6af..0479da47a7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.20' + go-version: '1.21' - name: Scenario Tests run: | cd .. From 92282c76bb34b97a6cec62dd9307b8ca05935a5d Mon Sep 17 00:00:00 2001 From: Jayant Date: Thu, 29 Aug 2024 15:18:01 +0800 Subject: [PATCH 60/70] refactor: thrift and generic codec uses bufiox interface for encoding and decoding (#1512) --- go.mod | 2 +- go.sum | 4 +- internal/generic/thrift/binary.go | 4 +- internal/generic/thrift/http.go | 37 +- internal/generic/thrift/http_fallback.go | 4 +- .../generic/thrift/http_go116plus_amd64.go | 31 +- internal/generic/thrift/http_pb.go | 19 +- internal/generic/thrift/json.go | 35 +- internal/generic/thrift/json_fallback.go | 7 +- .../generic/thrift/json_go116plus_amd64.go | 51 +- internal/generic/thrift/read.go | 36 +- internal/generic/thrift/read_test.go | 448 ++-------------- internal/generic/thrift/struct.go | 19 +- internal/generic/thrift/thrift.go | 6 +- internal/generic/thrift/write.go | 232 +++++---- internal/generic/thrift/write_test.go | 276 +++------- internal/mocks/conn.go | 13 + internal/mocks/generic/thrift.go | 22 +- internal/mocks/remote/bytebuf.go | 52 +- internal/mocks/update.sh | 2 +- pkg/generic/binarythrift_codec.go | 9 +- pkg/generic/binarythrift_codec_test.go | 33 +- pkg/generic/generic_service.go | 10 +- pkg/generic/generic_service_test.go | 32 +- pkg/protocol/bthrift/apache/apache.go | 28 +- .../bthrift/apache/binary_protocol.go | 370 +++---------- pkg/protocol/bthrift/binary.go | 490 +++--------------- pkg/protocol/bthrift/compat.go | 70 --- pkg/protocol/bthrift/utils.go | 38 -- pkg/remote/bufiox2buffer.go | 61 +++ pkg/remote/bytebuf.go | 6 +- pkg/remote/codec/default_codec.go | 14 +- pkg/remote/codec/default_codec_test.go | 2 +- pkg/remote/codec/header_codec_test.go | 20 +- pkg/remote/codec/thrift/deprecated.go | 26 +- pkg/remote/codec/thrift/deprecated_test.go | 76 +-- pkg/remote/codec/thrift/thrift.go | 24 +- pkg/remote/codec/thrift/thrift_data.go | 25 +- pkg/remote/codec/thrift/thrift_data_test.go | 43 +- pkg/remote/codec/thrift/thrift_frugal.go | 5 +- pkg/remote/codec/thrift/thrift_frugal_test.go | 32 +- pkg/remote/codec/thrift/thrift_test.go | 152 +++--- pkg/remote/default_bytebuf.go | 14 +- pkg/remote/default_bytebuf_test.go | 10 +- pkg/remote/trans/gonet/bytebuffer.go | 14 +- pkg/remote/trans/gonet/bytebuffer_test.go | 10 +- pkg/remote/trans/invoke/message_test.go | 3 +- pkg/remote/trans/netpoll/bytebuf.go | 16 +- pkg/remote/trans/netpoll/bytebuf_test.go | 10 +- .../trans/netpoll/http_client_handler_test.go | 3 +- pkg/remote/trans/nphttp2/buffer.go | 4 +- .../trans/nphttp2/server_handler_test.go | 16 +- pkg/utils/thrift.go | 22 +- pkg/utils/thrift_test.go | 118 ----- server/invoke.go | 48 +- server/invoke_test.go | 9 +- 56 files changed, 987 insertions(+), 2176 deletions(-) delete mode 100644 pkg/protocol/bthrift/compat.go delete mode 100644 pkg/protocol/bthrift/utils.go create mode 100644 pkg/remote/bufiox2buffer.go delete mode 100644 pkg/utils/thrift_test.go diff --git a/go.mod b/go.mod index fcfc4bf970..9df77e662e 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cloudwego/dynamicgo v0.3.0 github.com/cloudwego/fastpb v0.0.5 github.com/cloudwego/frugal v0.2.0 - github.com/cloudwego/gopkg v0.1.1-0.20240816085453-9fbe8155005d + github.com/cloudwego/gopkg v0.1.1-0.20240829032745-024f019d8487 github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index 3ab2eee836..8a9617da8b 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZ github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= -github.com/cloudwego/gopkg v0.1.1-0.20240816085453-9fbe8155005d h1:QBV/89XA0Mwlk6LQgLIDIf1vDMWSn9O2Xx1lJX7PRGI= -github.com/cloudwego/gopkg v0.1.1-0.20240816085453-9fbe8155005d/go.mod h1:32yKw2zkpTMtuX6amJR0EMK79f0vGPr67UcArCOlZLU= +github.com/cloudwego/gopkg v0.1.1-0.20240829032745-024f019d8487 h1:JmCA5LJYdhLY8/TfngV/DXBtu8IsLBuo0tu+dfN5iQk= +github.com/cloudwego/gopkg v0.1.1-0.20240829032745-024f019d8487/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= diff --git a/internal/generic/thrift/binary.go b/internal/generic/thrift/binary.go index d4aa2e3b2b..a30f4676e3 100644 --- a/internal/generic/thrift/binary.go +++ b/internal/generic/thrift/binary.go @@ -18,8 +18,8 @@ package thrift import ( "context" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift/base" ) @@ -30,6 +30,6 @@ func NewWriteBinary() *WriteBinary { return &WriteBinary{} } -func (w *WriteBinary) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { +func (w *WriteBinary) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { return nil } diff --git a/internal/generic/thrift/http.go b/internal/generic/thrift/http.go index fd31bfcbce..9f72931004 100644 --- a/internal/generic/thrift/http.go +++ b/internal/generic/thrift/http.go @@ -19,19 +19,17 @@ package thrift import ( "context" "fmt" - "io" "github.com/bytedance/gopkg/lang/dirtmake" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" dthrift "github.com/cloudwego/dynamicgo/thrift" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" jsoniter "github.com/json-iterator/go" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) type HTTPReaderWriter struct { @@ -80,7 +78,7 @@ func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.O } // originalWrite ... -func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out io.Writer, msg interface{}, requestBase *base.Base) error { +func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out bufiox.Writer, msg interface{}, requestBase *base.Base) error { req := msg.(*descriptor.HTTPRequest) if req.Body == nil && len(req.RawBody) != 0 { if err := customJson.Unmarshal(req.RawBody, &req.Body); err != nil { @@ -94,11 +92,9 @@ func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out io.Writer, msg if !fn.HasRequestBase { requestBase = nil } - binaryWriter := thrift.NewBinaryWriter() - if err = wrapStructWriter(ctx, req, binaryWriter, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64}); err != nil { - return err - } - _, err = out.Write(binaryWriter.Bytes()) + bw := thrift.NewBufferWriter(out) + err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64}) + bw.Recycle() return err } @@ -137,18 +133,15 @@ func (r *ReadHTTPResponse) SetDynamicGo(convOpts *conv.Options) { } // Read ... -func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { - buffer, ok := in.(remote.ByteBuffer) - if !ok { - return nil, perrors.NewProtocolErrorWithMsg("io.Reader should be ByteBuffer") - } - binaryReader := thrift.NewBinaryReader(buffer) - +func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in bufiox.Reader) (interface{}, error) { // fallback logic if !r.dynamicgoEnabled || dataLen == 0 { - return r.originalRead(ctx, method, binaryReader) + return r.originalRead(ctx, method, in) } + binaryReader := thrift.NewBufferReader(in) + defer binaryReader.Recycle() + // dynamicgo logic // TODO: support exception field _, id, err := binaryReader.ReadFieldBegin() @@ -156,7 +149,9 @@ func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient boo return nil, err } bProt := &thrift.BinaryProtocol{} - transBuf, err := buffer.ReadBinary(dataLen - bProt.FieldBeginLength()) + l := dataLen - bProt.FieldBeginLength() + transBuf := dirtmake.Bytes(l, l) + _, err = in.ReadBinary(transBuf) if err != nil { return nil, err } @@ -186,13 +181,15 @@ func (r *ReadHTTPResponse) Read(ctx context.Context, method string, isClient boo return resp, nil } -func (r *ReadHTTPResponse) originalRead(ctx context.Context, method string, in *thrift.BinaryReader) (interface{}, error) { +func (r *ReadHTTPResponse) originalRead(ctx context.Context, method string, in bufiox.Reader) (interface{}, error) { fnDsc, err := r.svc.LookupFunctionByMethod(method) if err != nil { return nil, err } fDsc := fnDsc.Response - resp, err := skipStructReader(ctx, in, fDsc, &readerOption{forJSON: true, http: true, binaryWithBase64: r.base64Binary}) + br := thrift.NewBufferReader(in) + defer br.Recycle() + resp, err := skipStructReader(ctx, br, fDsc, &readerOption{forJSON: true, http: true, binaryWithBase64: r.base64Binary}) if r.useRawBodyForHTTPResp { if httpResp, ok := resp.(*descriptor.HTTPResponse); ok && httpResp.Body != nil { rawBody, err := customJson.Marshal(httpResp.Body) diff --git a/internal/generic/thrift/http_fallback.go b/internal/generic/thrift/http_fallback.go index 4fa510bbed..b6ea961a03 100644 --- a/internal/generic/thrift/http_fallback.go +++ b/internal/generic/thrift/http_fallback.go @@ -21,12 +21,12 @@ package thrift import ( "context" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift/base" ) // Write ... -func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { +func (w *WriteHTTPRequest) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { return w.originalWrite(ctx, out, msg, requestBase) } diff --git a/internal/generic/thrift/http_go116plus_amd64.go b/internal/generic/thrift/http_go116plus_amd64.go index c825890d84..94a04edfe2 100644 --- a/internal/generic/thrift/http_go116plus_amd64.go +++ b/internal/generic/thrift/http_go116plus_amd64.go @@ -22,13 +22,13 @@ package thrift import ( "context" "fmt" - "io" "unsafe" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/j2t" dbase "github.com/cloudwego/dynamicgo/thrift/base" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" @@ -36,12 +36,15 @@ import ( ) // Write ... -func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { +func (w *WriteHTTPRequest) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { // fallback logic if !w.dynamicgoEnabled { return w.originalWrite(ctx, out, msg, requestBase) } + binaryWriter := thrift.NewBufferWriter(out) + defer binaryWriter.Recycle() + // dynamicgo logic req := msg.(*descriptor.HTTPRequest) @@ -63,29 +66,25 @@ func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, msg interfa cv = j2t.NewBinaryConv(w.convOpts) } - binaryWriter := thrift.NewBinaryWriter() - ctx = context.WithValue(ctx, conv.CtxKeyHTTPRequest, req) body := req.GetBody() dbuf := mcache.Malloc(len(body))[0:0] defer mcache.Free(dbuf) for _, field := range dynamicgoTypeDsc.Struct().Fields() { - binaryWriter.WriteFieldBegin(thrift.TType(field.Type().Type()), int16(field.ID())) - + if err := binaryWriter.WriteFieldBegin(thrift.TType(field.Type().Type()), int16(field.ID())); err != nil { + return err + } // json []byte to thrift []byte if err := cv.DoInto(ctx, field.Type(), body, &dbuf); err != nil { return err } + if wb, err := out.Malloc(len(dbuf)); err != nil { + return err + } else { + copy(wb, dbuf) + } + dbuf = dbuf[:0] } - if _, err := out.Write(binaryWriter.Bytes()); err != nil { - return err - } - if _, err := out.Write(dbuf); err != nil { - return err - } - binaryWriter.Reset() - binaryWriter.WriteFieldStop() - _, err := out.Write(binaryWriter.Bytes()) - return err + return binaryWriter.WriteFieldStop() } diff --git a/internal/generic/thrift/http_pb.go b/internal/generic/thrift/http_pb.go index 8e3e4c221c..869400b1de 100644 --- a/internal/generic/thrift/http_pb.go +++ b/internal/generic/thrift/http_pb.go @@ -20,8 +20,8 @@ import ( "context" "errors" "fmt" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/jhump/protoreflect/desc" @@ -55,7 +55,7 @@ func NewWriteHTTPPbRequest(svc *descriptor.ServiceDescriptor, pbSvc *desc.Servic } // Write ... -func (w *WriteHTTPPbRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { +func (w *WriteHTTPPbRequest) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { req := msg.(*descriptor.HTTPRequest) fn, err := w.svc.Router.Lookup(req) if err != nil { @@ -77,11 +77,9 @@ func (w *WriteHTTPPbRequest) Write(ctx context.Context, out io.Writer, msg inter } req.GeneralBody = pbMsg - binaryWriter := thrift.NewBinaryWriter() - if err = wrapStructWriter(ctx, req, binaryWriter, fn.Request, &writerOption{requestBase: requestBase}); err != nil { - return err - } - _, err = out.Write(binaryWriter.Bytes()) + binaryWriter := thrift.NewBufferWriter(out) + err = wrapStructWriter(ctx, req, binaryWriter, fn.Request, &writerOption{requestBase: requestBase}) + binaryWriter.Recycle() return err } @@ -100,7 +98,7 @@ func NewReadHTTPPbResponse(svc *descriptor.ServiceDescriptor, pbSvc proto.Servic } // Read ... -func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { +func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, isClient bool, dataLen int, in bufiox.Reader) (interface{}, error) { fnDsc, err := r.svc.LookupFunctionByMethod(method) if err != nil { return nil, err @@ -111,5 +109,8 @@ func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, isClient b return nil, errors.New("pb method not found") } - return skipStructReader(ctx, thrift.NewBinaryReader(in), fDsc, &readerOption{pbDsc: mt.GetOutputType(), http: true}) + br := thrift.NewBufferReader(in) + resp, err := skipStructReader(ctx, br, fDsc, &readerOption{pbDsc: mt.GetOutputType(), http: true}) + br.Recycle() + return resp, err } diff --git a/internal/generic/thrift/json.go b/internal/generic/thrift/json.go index d4eb7bf082..4ba24e3896 100644 --- a/internal/generic/thrift/json.go +++ b/internal/generic/thrift/json.go @@ -19,20 +19,19 @@ package thrift import ( "context" "fmt" - "io" "strconv" "github.com/bytedance/gopkg/lang/dirtmake" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" dthrift "github.com/cloudwego/dynamicgo/thrift" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" jsoniter "github.com/json-iterator/go" "github.com/tidwall/gjson" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" "github.com/cloudwego/kitex/pkg/utils" ) @@ -83,7 +82,7 @@ func (m *WriteJSON) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) m.dynamicgoEnabled = true } -func (m *WriteJSON) originalWrite(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { +func (m *WriteJSON) originalWrite(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { fnDsc, err := m.svcDsc.LookupFunctionByMethod(method) if err != nil { return fmt.Errorf("missing method: %s in service: %s", method, m.svcDsc.Name) @@ -98,14 +97,14 @@ func (m *WriteJSON) originalWrite(ctx context.Context, out io.Writer, msg interf requestBase = nil } - binaryWriter := thrift.NewBinaryWriter() + bw := thrift.NewBufferWriter(out) + defer bw.Recycle() // msg is void or nil if _, ok := msg.(descriptor.Void); ok || msg == nil { - if err = wrapStructWriter(ctx, msg, binaryWriter, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil { + if err = wrapStructWriter(ctx, msg, bw, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil { return err } - _, err = out.Write(binaryWriter.Bytes()) return err } @@ -125,10 +124,9 @@ func (m *WriteJSON) originalWrite(ctx context.Context, out io.Writer, msg interf Index: 0, } } - if err = wrapJSONWriter(ctx, &body, binaryWriter, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil { + if err = wrapJSONWriter(ctx, &body, bw, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil { return err } - _, err = out.Write(binaryWriter.Bytes()) return err } @@ -166,16 +164,10 @@ func (m *ReadJSON) SetDynamicGo(convOpts, convOptsWithException *conv.Options) { } // Read read data from in thrift.TProtocol and convert to json string -func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { - buffer, ok := in.(remote.ByteBuffer) - if !ok { - return nil, perrors.NewProtocolErrorWithMsg("io.Reader should be ByteBuffer") - } - binaryReader := thrift.NewBinaryReader(in) - +func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataLen int, in bufiox.Reader) (interface{}, error) { // fallback logic if !m.dynamicgoEnabled || dataLen <= 0 { - return m.originalRead(ctx, method, isClient, binaryReader) + return m.originalRead(ctx, method, isClient, in) } fnDsc := m.svc.DynamicGoDsc.Functions()[method] @@ -189,12 +181,13 @@ func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataL var resp interface{} if tyDsc.Struct().Fields()[0].Type().Type() == dthrift.VOID { - if _, err := buffer.ReadBinary(voidWholeLen); err != nil { + if err := in.Skip(voidWholeLen); err != nil { return nil, err } resp = descriptor.Void{} } else { - transBuff, err := buffer.ReadBinary(dataLen) + transBuff := dirtmake.Bytes(dataLen, dataLen) + _, err := in.ReadBinary(transBuff) if err != nil { return nil, err } @@ -225,7 +218,7 @@ func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataL return resp, nil } -func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient bool, in *thrift.BinaryReader) (interface{}, error) { +func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient bool, in bufiox.Reader) (interface{}, error) { fnDsc, err := m.svc.LookupFunctionByMethod(method) if err != nil { return nil, err @@ -234,7 +227,9 @@ func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient boo if !isClient { fDsc = fnDsc.Request } - resp, err := skipStructReader(ctx, in, fDsc, &readerOption{forJSON: true, throwException: true, binaryWithBase64: m.binaryWithBase64}) + br := thrift.NewBufferReader(in) + defer br.Recycle() + resp, err := skipStructReader(ctx, br, fDsc, &readerOption{forJSON: true, throwException: true, binaryWithBase64: m.binaryWithBase64}) if err != nil { return nil, err } diff --git a/internal/generic/thrift/json_fallback.go b/internal/generic/thrift/json_fallback.go index e1500fa853..5bc204f2c9 100644 --- a/internal/generic/thrift/json_fallback.go +++ b/internal/generic/thrift/json_fallback.go @@ -21,12 +21,13 @@ package thrift import ( "context" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift/base" ) // Write write json string to out thrift.TProtocol -func (m *WriteJSON) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { - return m.originalWrite(ctx, out, msg, method, isClient, requestBase) +func (m *WriteJSON) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { + err := m.originalWrite(ctx, out, msg, method, isClient, requestBase) + return err } diff --git a/internal/generic/thrift/json_go116plus_amd64.go b/internal/generic/thrift/json_go116plus_amd64.go index d321ea8afc..42192a863c 100644 --- a/internal/generic/thrift/json_go116plus_amd64.go +++ b/internal/generic/thrift/json_go116plus_amd64.go @@ -22,7 +22,6 @@ package thrift import ( "context" "fmt" - "io" "unsafe" "github.com/bytedance/gopkg/lang/mcache" @@ -30,6 +29,7 @@ import ( "github.com/cloudwego/dynamicgo/conv/j2t" dthrift "github.com/cloudwego/dynamicgo/thrift" dbase "github.com/cloudwego/dynamicgo/thrift/base" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" @@ -39,7 +39,7 @@ import ( ) // Write write json string to out thrift.TProtocol -func (m *WriteJSON) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { +func (m *WriteJSON) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { // fallback logic if !m.dynamicgoEnabled { return m.originalWrite(ctx, out, msg, method, isClient, requestBase) @@ -68,16 +68,9 @@ func (m *WriteJSON) Write(ctx context.Context, out io.Writer, msg interface{}, m cv = j2t.NewBinaryConv(m.convOpts) } - binaryWriter := thrift.NewBinaryWriter() - // msg is void or nil if _, ok := msg.(descriptor.Void); ok || msg == nil { - if err := m.writeFields(ctx, out, dynamicgoTypeDsc, nil, nil, isClient); err != nil { - return err - } - binaryWriter.WriteFieldStop() - _, err := out.Write(binaryWriter.Bytes()) - return err + return m.writeFields(ctx, out, dynamicgoTypeDsc, nil, nil, isClient) } // msg is string @@ -87,22 +80,17 @@ func (m *WriteJSON) Write(ctx context.Context, out io.Writer, msg interface{}, m } transBuff := utils.StringToSliceByte(s) - if err := m.writeFields(ctx, out, dynamicgoTypeDsc, &cv, transBuff, isClient); err != nil { - return err - } - binaryWriter.WriteFieldStop() - _, err := out.Write(binaryWriter.Bytes()) - return err + return m.writeFields(ctx, out, dynamicgoTypeDsc, &cv, transBuff, isClient) } type MsgType int -func (m *WriteJSON) writeFields(ctx context.Context, out io.Writer, dynamicgoTypeDsc *dthrift.TypeDescriptor, cv *j2t.BinaryConv, transBuff []byte, isClient bool) error { +func (m *WriteJSON) writeFields(ctx context.Context, out bufiox.Writer, dynamicgoTypeDsc *dthrift.TypeDescriptor, cv *j2t.BinaryConv, transBuff []byte, isClient bool) error { dbuf := mcache.Malloc(len(transBuff))[0:0] defer mcache.Free(dbuf) - binaryWriter := thrift.NewBinaryWriter() - + bw := thrift.NewBufferWriter(out) + defer bw.Recycle() for _, field := range dynamicgoTypeDsc.Struct().Fields() { // Exception field if !isClient && field.ID() != 0 { @@ -111,23 +99,28 @@ func (m *WriteJSON) writeFields(ctx context.Context, out io.Writer, dynamicgoTyp continue } - binaryWriter.WriteFieldBegin(thrift.TType(field.Type().Type()), int16(field.ID())) - // if the field type is void, just write void and return - if field.Type().Type() == dthrift.VOID { - binaryWriter.WriteFieldStop() - _, err := out.Write(binaryWriter.Bytes()) + if err := bw.WriteFieldBegin(thrift.TType(field.Type().Type()), int16(field.ID())); err != nil { return err + } + // if the field type is void, break + if field.Type().Type() == dthrift.VOID { + if err := bw.WriteFieldStop(); err != nil { + return err + } + break } else { // encode using dynamicgo // json []byte to thrift []byte if err := cv.DoInto(ctx, field.Type(), transBuff, &dbuf); err != nil { return err } + if wb, err := out.Malloc(len(dbuf)); err != nil { + return err + } else { + copy(wb, dbuf) + } + dbuf = dbuf[:0] } } - if _, err := out.Write(binaryWriter.Bytes()); err != nil { - return err - } - _, err := out.Write(dbuf) - return err + return bw.WriteFieldStop() } diff --git a/internal/generic/thrift/read.go b/internal/generic/thrift/read.go index 4b4bb3378c..4742bebd13 100644 --- a/internal/generic/thrift/read.go +++ b/internal/generic/thrift/read.go @@ -48,7 +48,7 @@ type readerOption struct { pbDsc proto.MessageDescriptor } -type reader func(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) +type reader func(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) type fieldSetter func(field *descriptor.FieldDescriptor, val interface{}) error @@ -109,7 +109,7 @@ func nextReader(tt descriptor.Type, t *descriptor.TypeDescriptor, opt *readerOpt } // TODO(marina.sakai): Optimize generic reader -func skipStructReader(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func skipStructReader(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { var v interface{} for { fieldType, fieldID, err := in.ReadFieldBegin() @@ -152,20 +152,20 @@ func skipStructReader(ctx context.Context, in *thrift.BinaryReader, t *descripto return v, nil } -func readVoid(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readVoid(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { _, err := readStruct(ctx, in, t, opt) return descriptor.Void{}, err } -func readDouble(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readDouble(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadDouble() } -func readBool(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readBool(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadBool() } -func readByte(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readByte(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { res, err := in.ReadByte() if err != nil { return nil, err @@ -176,7 +176,7 @@ func readByte(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDe return res, nil } -func readInt16(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readInt16(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { res, err := in.ReadI16() if err != nil { return nil, err @@ -187,19 +187,19 @@ func readInt16(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeD return res, nil } -func readInt32(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readInt32(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadI32() } -func readInt64(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readInt64(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadI64() } -func readString(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readString(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { return in.ReadString() } -func readBinary(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readBinary(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { bytes, err := in.ReadBinary() if err != nil { return "", err @@ -207,7 +207,7 @@ func readBinary(ctx context.Context, in *thrift.BinaryReader, t *descriptor.Type return bytes, nil } -func readBase64Binary(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readBase64Binary(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { bytes, err := in.ReadBinary() if err != nil { return "", err @@ -215,7 +215,7 @@ func readBase64Binary(ctx context.Context, in *thrift.BinaryReader, t *descripto return base64.StdEncoding.EncodeToString(bytes), nil } -func readList(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readList(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { elemType, length, err := in.ReadListBegin() if err != nil { return nil, err @@ -236,14 +236,14 @@ func readList(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDe return l, nil } -func readMap(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readMap(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { if opt != nil && opt.forJSON { return readStringMap(ctx, in, t, opt) } return readInterfaceMap(ctx, in, t, opt) } -func readInterfaceMap(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readInterfaceMap(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { keyType, elemType, length, err := in.ReadMapBegin() if err != nil { return nil, err @@ -280,7 +280,7 @@ func readInterfaceMap(ctx context.Context, in *thrift.BinaryReader, t *descripto return m, nil } -func readStringMap(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readStringMap(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { keyType, elemType, length, err := in.ReadMapBegin() if err != nil { return nil, err @@ -313,7 +313,7 @@ func readStringMap(ctx context.Context, in *thrift.BinaryReader, t *descriptor.T return m, nil } -func readStruct(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readStruct(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { var fs fieldSetter var st interface{} if opt == nil || opt.pbDsc == nil { @@ -411,7 +411,7 @@ func readStruct(ctx context.Context, in *thrift.BinaryReader, t *descriptor.Type } } -func readHTTPResponse(ctx context.Context, in *thrift.BinaryReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { +func readHTTPResponse(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { var resp *descriptor.HTTPResponse if opt == nil || opt.pbDsc == nil { if opt == nil { diff --git a/internal/generic/thrift/read_test.go b/internal/generic/thrift/read_test.go index 2a77e120ee..b338d1dca3 100644 --- a/internal/generic/thrift/read_test.go +++ b/internal/generic/thrift/read_test.go @@ -23,13 +23,13 @@ import ( "reflect" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/jhump/protoreflect/desc/protoparse" "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic/descriptor" - "github.com/cloudwego/kitex/pkg/remote" ) var ( @@ -66,284 +66,12 @@ func Test_nextReader(t *testing.T) { } } -func Test_readVoid(t *testing.T) { +func TestReadSimple(t *testing.T) { type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"void", args{t: &descriptor.TypeDescriptor{Type: descriptor.VOID, Struct: &descriptor.StructDescriptor{}}}, descriptor.Void{}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeVoid(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeVoid() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readVoid(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readVoid() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readVoid() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readDouble(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readDouble", args{t: &descriptor.TypeDescriptor{Type: descriptor.DOUBLE}}, 1.0, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeFloat64(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeFloat64() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readDouble(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readDouble() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readDouble() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readBool(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readBool", args{t: &descriptor.TypeDescriptor{Type: descriptor.BOOL}}, true, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeBool(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeBool() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readBool(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readBool() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readBool() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readByte(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readByte", args{t: &descriptor.TypeDescriptor{Type: descriptor.BYTE}, opt: &readerOption{}}, int8(1), false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeInt8(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeInt8() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readByte(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readByte() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readByte() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readInt16(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readInt16", args{t: &descriptor.TypeDescriptor{Type: descriptor.I16}, opt: &readerOption{}}, int16(1), false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeInt16(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeInt16() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readInt16(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readInt16() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readInt16() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readInt32(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readInt32", args{t: &descriptor.TypeDescriptor{Type: descriptor.I32}}, int32(1), false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeInt32(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeInt32() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readInt32(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readInt32() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readInt32() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readInt64(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readInt64", args{t: &descriptor.TypeDescriptor{Type: descriptor.I64}}, int64(1), false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeInt64(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeInt64() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readInt64(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readInt64() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readInt64() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readString(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readString", args{t: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, stringInput, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeString(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeString() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readString(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readString() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readString() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readBinary64String(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption + t *descriptor.TypeDescriptor + writeFunc func(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error + readFunc func(ctx context.Context, in *thrift.BufferReader, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) + opt *readerOption } tests := []struct { name string @@ -351,153 +79,60 @@ func Test_readBinary64String(t *testing.T) { want interface{} wantErr bool }{ - // TODO: Add test cases. - {"readBase64Binary", args{t: &descriptor.TypeDescriptor{Name: "binary", Type: descriptor.STRING}}, base64.StdEncoding.EncodeToString(binaryInput), false}, // read base64 string from binary field - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeBase64Binary(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeBase64Binary() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readBase64Binary(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readString() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readString() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readBinary(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readBinary", args{t: &descriptor.TypeDescriptor{Name: "binary", Type: descriptor.STRING}}, binaryInput, false}, // read base64 string from binary field - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeBinary(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeBinary() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readBinary(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readString() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readString() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readList(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - want interface{} - wantErr bool - }{ - // TODO: Add test cases. - {"readList", args{t: &descriptor.TypeDescriptor{Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}}}, []interface{}{stringInput, stringInput, stringInput}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := writeList(context.Background(), tt.want, w, tt.args.t, &writerOption{}) - if err != nil { - t.Errorf("writeList() error = %v", err) - } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readList(context.Background(), in, tt.args.t, tt.args.opt) - if (err != nil) != tt.wantErr { - t.Errorf("readList() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readList() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_readMap(t *testing.T) { - type args struct { - t *descriptor.TypeDescriptor - opt *readerOption - } - tests := []struct { - name string - args args - writer func(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error - want interface{} - wantErr bool - }{ - // TODO: Add test cases. + {"void", args{readFunc: readVoid, writeFunc: writeVoid, t: &descriptor.TypeDescriptor{Type: descriptor.VOID, Struct: &descriptor.StructDescriptor{}}}, descriptor.Void{}, false}, + {"readDouble", args{readFunc: readDouble, writeFunc: writeFloat64, t: &descriptor.TypeDescriptor{Type: descriptor.DOUBLE}}, 1.0, false}, + {"readBool", args{readFunc: readBool, writeFunc: writeBool, t: &descriptor.TypeDescriptor{Type: descriptor.BOOL}}, true, false}, + {"readByte", args{readFunc: readByte, writeFunc: writeInt8, t: &descriptor.TypeDescriptor{Type: descriptor.BYTE}, opt: &readerOption{}}, int8(1), false}, + {"readInt16", args{readFunc: readInt16, writeFunc: writeInt16, t: &descriptor.TypeDescriptor{Type: descriptor.I16}, opt: &readerOption{}}, int16(1), false}, + {"readInt32", args{readFunc: readInt32, writeFunc: writeInt32, t: &descriptor.TypeDescriptor{Type: descriptor.I32}}, int32(1), false}, + {"readInt64", args{readFunc: readInt64, writeFunc: writeInt64, t: &descriptor.TypeDescriptor{Type: descriptor.I64}}, int64(1), false}, + {"readString", args{readFunc: readString, writeFunc: writeString, t: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, stringInput, false}, + {"readBase64Binary", args{readFunc: readBase64Binary, writeFunc: writeBase64Binary, t: &descriptor.TypeDescriptor{Name: "binary", Type: descriptor.STRING}}, base64.StdEncoding.EncodeToString(binaryInput), false}, // read base64 string from binary field + {"readBinary", args{readFunc: readBinary, writeFunc: writeBinary, t: &descriptor.TypeDescriptor{Name: "binary", Type: descriptor.STRING}}, binaryInput, false}, // read base64 string from binary field + {"readList", args{readFunc: readList, writeFunc: writeList, t: &descriptor.TypeDescriptor{Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}}}, []interface{}{stringInput, stringInput, stringInput}, false}, { "readMap", - args{t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{}}, - writeInterfaceMap, + args{readFunc: readMap, writeFunc: writeInterfaceMap, t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{}}, map[interface{}]interface{}{"hello": "world"}, false, }, { "readJsonMap", - args{t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{forJSON: true}}, - writeStringMap, + args{readFunc: readMap, writeFunc: writeStringMap, t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{forJSON: true}}, map[string]interface{}{"hello": "world"}, false, }, { "readJsonMapWithInt16Key", - args{t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.I16}, Elem: &descriptor.TypeDescriptor{Type: descriptor.BOOL}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{forJSON: true}}, - writeStringMap, + args{readFunc: readMap, writeFunc: writeStringMap, t: &descriptor.TypeDescriptor{Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.I16}, Elem: &descriptor.TypeDescriptor{Type: descriptor.BOOL}, Struct: &descriptor.StructDescriptor{}}, opt: &readerOption{forJSON: true}}, map[string]interface{}{"16": false}, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() - err := tt.writer(context.Background(), tt.want, w, tt.args.t, &writerOption{}) + var bs []byte + bw := bufiox.NewBytesWriter(&bs) + w := thrift.NewBufferWriter(bw) + err := tt.args.writeFunc(context.Background(), tt.want, w, tt.args.t, &writerOption{}) if err != nil { - t.Errorf("writeInterfaceMap() error = %v", err) + t.Errorf("writeFloat64() error = %v", err) } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) - got, err := readMap(context.Background(), in, tt.args.t, tt.args.opt) + _ = bw.Flush() + in := thrift.NewBufferReader(bufiox.NewBytesReader(bs)) + got, err := tt.args.readFunc(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { - t.Errorf("readMap() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("readVoid() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("readMap() = %v, want %v", got, tt.want) + t.Errorf("readVoid() = %v, want %v", got, tt.want) } }) } } -func Test_readStruct(t *testing.T) { +func TestReadStruct(t *testing.T) { type args struct { t *descriptor.TypeDescriptor opt *readerOption @@ -645,12 +280,15 @@ func Test_readStruct(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() + var bs []byte + bw := bufiox.NewBytesWriter(&bs) + w := thrift.NewBufferWriter(bw) err := writeStruct(context.Background(), tt.input, w, tt.args.t, &writerOption{}) if err != nil { t.Errorf("writeStruct() error = %v", err) } - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + _ = bw.Flush() + in := thrift.NewBufferReader(bufiox.NewBytesReader(bs)) got, err := readStruct(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readStruct() error = %v, wantErr %v", err, tt.wantErr) @@ -695,11 +333,14 @@ func Test_readHTTPResponse(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() + var bs []byte + bw := bufiox.NewBytesWriter(&bs) + w := thrift.NewBufferWriter(bw) w.WriteFieldBegin(thrift.TType(descriptor.STRING), 1) w.WriteString("world") w.WriteFieldStop() - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + _ = bw.Flush() + in := thrift.NewBufferReader(bufiox.NewBytesReader(bs)) got, err := readHTTPResponse(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readHTTPResponse() error = %v, wantErr %v", err, tt.wantErr) @@ -752,11 +393,14 @@ func Test_readHTTPResponseWithPbBody(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - w := thrift.NewBinaryWriter() + var bs []byte + bw := bufiox.NewBytesWriter(&bs) + w := thrift.NewBufferWriter(bw) w.WriteFieldBegin(thrift.TType(descriptor.STRING), 1) w.WriteString("hello world") w.WriteFieldStop() - in := thrift.NewBinaryReader(remote.NewReaderBuffer(w.Bytes())) + _ = bw.Flush() + in := thrift.NewBufferReader(bufiox.NewBytesReader(bs)) got, err := readHTTPResponse(context.Background(), in, tt.args.t, tt.args.opt) if (err != nil) != tt.wantErr { t.Errorf("readHTTPResponse() error = %v, wantErr %v", err, tt.wantErr) diff --git a/internal/generic/thrift/struct.go b/internal/generic/thrift/struct.go index eaea4bfe08..e13ede046e 100644 --- a/internal/generic/thrift/struct.go +++ b/internal/generic/thrift/struct.go @@ -18,8 +18,8 @@ package thrift import ( "context" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" @@ -59,7 +59,7 @@ func (m *WriteStruct) SetBinaryWithBase64(enable bool) { } // Write ... -func (m *WriteStruct) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { +func (m *WriteStruct) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { fnDsc, err := m.svcDsc.LookupFunctionByMethod(method) if err != nil { return err @@ -73,11 +73,9 @@ func (m *WriteStruct) Write(ctx context.Context, out io.Writer, msg interface{}, if !hasRequestBase { requestBase = nil } - binaryWriter := thrift.NewBinaryWriter() - if err = wrapStructWriter(ctx, msg, binaryWriter, ty, &writerOption{requestBase: requestBase, binaryWithBase64: m.binaryWithBase64}); err != nil { - return err - } - _, err = out.Write(binaryWriter.Bytes()) + binaryWriter := thrift.NewBufferWriter(out) + err = wrapStructWriter(ctx, msg, binaryWriter, ty, &writerOption{requestBase: requestBase, binaryWithBase64: m.binaryWithBase64}) + binaryWriter.Recycle() return err } @@ -121,7 +119,7 @@ func (m *ReadStruct) SetSetFieldsForEmptyStruct(mode uint8) { } // Read ... -func (m *ReadStruct) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { +func (m *ReadStruct) Read(ctx context.Context, method string, isClient bool, dataLen int, in bufiox.Reader) (interface{}, error) { fnDsc, err := m.svc.LookupFunctionByMethod(method) if err != nil { return nil, err @@ -130,5 +128,8 @@ func (m *ReadStruct) Read(ctx context.Context, method string, isClient bool, dat if !isClient { fDsc = fnDsc.Request } - return skipStructReader(ctx, thrift.NewBinaryReader(in), fDsc, &readerOption{throwException: true, forJSON: m.forJSON, binaryWithBase64: m.binaryWithBase64, binaryWithByteSlice: m.binaryWithByteSlice, setFieldsForEmptyStruct: m.setFieldsForEmptyStruct}) + br := thrift.NewBufferReader(in) + str, err := skipStructReader(ctx, br, fDsc, &readerOption{throwException: true, forJSON: m.forJSON, binaryWithBase64: m.binaryWithBase64, binaryWithByteSlice: m.binaryWithByteSlice, setFieldsForEmptyStruct: m.setFieldsForEmptyStruct}) + br.Recycle() + return str, err } diff --git a/internal/generic/thrift/thrift.go b/internal/generic/thrift/thrift.go index 862691642c..6c72febfa5 100644 --- a/internal/generic/thrift/thrift.go +++ b/internal/generic/thrift/thrift.go @@ -19,8 +19,8 @@ package thrift import ( "context" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift/base" ) @@ -30,10 +30,10 @@ const ( // MessageReader read from thrift.TProtocol with method type MessageReader interface { - Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) + Read(ctx context.Context, method string, isClient bool, dataLen int, in bufiox.Reader) (interface{}, error) } // MessageWriter write to thrift.TProtocol type MessageWriter interface { - Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error + Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error } diff --git a/internal/generic/thrift/write.go b/internal/generic/thrift/write.go index bb2c27a9aa..04c849e9ba 100644 --- a/internal/generic/thrift/write.go +++ b/internal/generic/thrift/write.go @@ -37,7 +37,7 @@ type writerOption struct { binaryWithBase64 bool } -type writer func(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error +type writer func(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error type fieldGetter func(val interface{}, field *descriptor.FieldDescriptor) (interface{}, bool) @@ -207,42 +207,32 @@ func nextJSONWriter(data *gjson.Result, t *descriptor.TypeDescriptor, opt *write return v, fn, nil } -func writeEmptyValue(out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeEmptyValue(out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { switch t.Type { case descriptor.BOOL: - out.WriteBool(false) - return nil + return out.WriteBool(false) case descriptor.I08: - out.WriteByte(0) - return nil + return out.WriteByte(0) case descriptor.I16: - out.WriteI16(0) - return nil + return out.WriteI16(0) case descriptor.I32: - out.WriteI32(0) - return nil + return out.WriteI32(0) case descriptor.I64: - out.WriteI64(0) - return nil + return out.WriteI64(0) case descriptor.DOUBLE: - out.WriteDouble(0) - return nil + return out.WriteDouble(0) case descriptor.STRING: if t.Name == "binary" && opt.binaryWithBase64 { - out.WriteBinary([]byte{}) + return out.WriteBinary([]byte{}) } else { - out.WriteString("") + return out.WriteString("") } - return nil case descriptor.LIST, descriptor.SET: - out.WriteListBegin(t.Elem.Type.ToThriftTType(), 0) - return nil + return out.WriteListBegin(t.Elem.Type.ToThriftTType(), 0) case descriptor.MAP: - out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), 0) - return nil + return out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), 0) case descriptor.STRUCT: - out.WriteFieldStop() - return nil + return out.WriteFieldStop() case descriptor.VOID: return nil } @@ -250,7 +240,7 @@ func writeEmptyValue(out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt } // TODO(marina.sakai): Optimize generic struct writer -func wrapStructWriter(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func wrapStructWriter(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { for name, field := range t.Struct.FieldsByName { if field.IsException { // generic server ignore the exception, because no description for exception @@ -258,7 +248,9 @@ func wrapStructWriter(ctx context.Context, val interface{}, out *thrift.BinaryWr continue } if val != nil { - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } writer, err := nextWriter(val, field.Type, opt) if err != nil { return fmt.Errorf("nextWriter of field[%s] error %w", name, err) @@ -268,19 +260,23 @@ func wrapStructWriter(ctx context.Context, val interface{}, out *thrift.BinaryWr } } } - out.WriteFieldStop() + if err := out.WriteFieldStop(); err != nil { + return err + } return nil } // TODO(marina.sakai): Optimize generic json writer -func wrapJSONWriter(ctx context.Context, val *gjson.Result, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func wrapJSONWriter(ctx context.Context, val *gjson.Result, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { for name, field := range t.Struct.FieldsByName { if field.IsException { // generic server ignore the exception, because no description for exception // generic handler just return error continue } - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } v, writer, err := nextJSONWriter(val, field.Type, opt) if err != nil { return fmt.Errorf("nextJSONWriter of field[%s] error %w", name, err) @@ -289,20 +285,21 @@ func wrapJSONWriter(ctx context.Context, val *gjson.Result, out *thrift.BinaryWr return fmt.Errorf("writer of field[%s] error %w", name, err) } } - out.WriteFieldStop() + if err := out.WriteFieldStop(); err != nil { + return err + } return nil } -func writeVoid(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeVoid(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { return writeStruct(ctx, map[string]interface{}{}, out, t, opt) } -func writeBool(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { - out.WriteBool(val.(bool)) - return nil +func writeBool(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { + return out.WriteBool(val.(bool)) } -func writeInt8(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInt8(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { var i int8 switch val := val.(type) { case int8: @@ -315,22 +312,18 @@ func writeInt8(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t // compatible with lossless conversion switch t.Type { case descriptor.I08: - out.WriteByte(i) - return nil + return out.WriteByte(i) case descriptor.I16: - out.WriteI16(int16(i)) - return nil + return out.WriteI16(int16(i)) case descriptor.I32: - out.WriteI32(int32(i)) - return nil + return out.WriteI32(int32(i)) case descriptor.I64: - out.WriteI64(int64(i)) - return nil + return out.WriteI64(int64(i)) } return fmt.Errorf("need int type, but got: %s", t.Type) } -func writeInt16(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInt16(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { // compatible with lossless conversion i := val.(int16) switch t.Type { @@ -338,22 +331,18 @@ func writeInt16(ctx context.Context, val interface{}, out *thrift.BinaryWriter, if i&0xff != i { return fmt.Errorf("value is beyond range of i8: %v", i) } - out.WriteByte(int8(i)) - return nil + return out.WriteByte(int8(i)) case descriptor.I16: - out.WriteI16(i) - return nil + return out.WriteI16(i) case descriptor.I32: - out.WriteI32(int32(i)) - return nil + return out.WriteI32(int32(i)) case descriptor.I64: - out.WriteI64(int64(i)) - return nil + return out.WriteI64(int64(i)) } return fmt.Errorf("need int type, but got: %s", t.Type) } -func writeInt32(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInt32(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { // compatible with lossless conversion i := val.(int32) switch t.Type { @@ -361,25 +350,21 @@ func writeInt32(ctx context.Context, val interface{}, out *thrift.BinaryWriter, if i&0xff != i { return fmt.Errorf("value is beyond range of i8: %v", i) } - out.WriteByte(int8(i)) - return nil + return out.WriteByte(int8(i)) case descriptor.I16: if i&0xffff != i { return fmt.Errorf("value is beyond range of i16: %v", i) } - out.WriteI16(int16(i)) - return nil + return out.WriteI16(int16(i)) case descriptor.I32: - out.WriteI32(i) - return nil + return out.WriteI32(i) case descriptor.I64: - out.WriteI64(int64(i)) - return nil + return out.WriteI64(int64(i)) } return fmt.Errorf("need int type, but got: %s", t.Type) } -func writeInt64(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInt64(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { // compatible with lossless conversion i := val.(int64) switch t.Type { @@ -387,28 +372,24 @@ func writeInt64(ctx context.Context, val interface{}, out *thrift.BinaryWriter, if i&0xff != i { return fmt.Errorf("value is beyond range of i8: %v", i) } - out.WriteByte(int8(i)) - return nil + return out.WriteByte(int8(i)) case descriptor.I16: if i&0xffff != i { return fmt.Errorf("value is beyond range of i16: %v", i) } - out.WriteI16(int16(i)) - return nil + return out.WriteI16(int16(i)) case descriptor.I32: if i&0xffffffff != i { return fmt.Errorf("value is beyond range of i32: %v", i) } - out.WriteI32(int32(i)) - return nil + return out.WriteI32(int32(i)) case descriptor.I64: - out.WriteI64(i) - return nil + return out.WriteI64(i) } return fmt.Errorf("need int type, but got: %s", t.Type) } -func writeJSONNumber(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeJSONNumber(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { jn := val.(json.Number) switch t.Type { case descriptor.I08: @@ -445,7 +426,7 @@ func writeJSONNumber(ctx context.Context, val interface{}, out *thrift.BinaryWri return nil } -func writeJSONFloat64(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeJSONFloat64(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { i := val.(float64) switch t.Type { case descriptor.I08: @@ -462,44 +443,46 @@ func writeJSONFloat64(ctx context.Context, val interface{}, out *thrift.BinaryWr return fmt.Errorf("need number type, but got: %s", t.Type) } -func writeFloat64(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { - out.WriteDouble(val.(float64)) - return nil +func writeFloat64(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { + return out.WriteDouble(val.(float64)) } -func writeString(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { - out.WriteString(val.(string)) - return nil +func writeString(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { + return out.WriteString(val.(string)) } -func writeBase64Binary(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeBase64Binary(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { bytes, err := base64.StdEncoding.DecodeString(val.(string)) if err != nil { return err } - out.WriteBinary(bytes) - return nil + return out.WriteBinary(bytes) } -func writeBinary(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { - out.WriteBinary(val.([]byte)) - return nil +func writeBinary(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { + return out.WriteBinary(val.([]byte)) } -func writeBinaryList(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeBinaryList(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { l := val.([]byte) length := len(l) - out.WriteListBegin(t.Elem.Type.ToThriftTType(), length) + if err := out.WriteListBegin(t.Elem.Type.ToThriftTType(), length); err != nil { + return err + } for _, b := range l { - out.WriteByte(int8(b)) + if err := out.WriteByte(int8(b)); err != nil { + return err + } } return nil } -func writeList(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeList(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { l := val.([]interface{}) length := len(l) - out.WriteListBegin(t.Elem.Type.ToThriftTType(), length) + if err := out.WriteListBegin(t.Elem.Type.ToThriftTType(), length); err != nil { + return err + } if length == 0 { return nil } @@ -526,10 +509,12 @@ func writeList(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t return nil } -func writeJSONList(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeJSONList(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { l := val.([]gjson.Result) length := len(l) - out.WriteListBegin(t.Elem.Type.ToThriftTType(), length) + if err := out.WriteListBegin(t.Elem.Type.ToThriftTType(), length); err != nil { + return err + } if length == 0 { return nil } @@ -545,10 +530,12 @@ func writeJSONList(ctx context.Context, val interface{}, out *thrift.BinaryWrite return nil } -func writeInterfaceMap(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeInterfaceMap(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { m := val.(map[interface{}]interface{}) length := len(m) - out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length) + if err := out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length); err != nil { + return err + } if length == 0 { return nil } @@ -584,10 +571,12 @@ func writeInterfaceMap(ctx context.Context, val interface{}, out *thrift.BinaryW return nil } -func writeStringMap(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeStringMap(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { m := val.(map[string]interface{}) length := len(m) - out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length) + if err := out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length); err != nil { + return err + } if length == 0 { return nil } @@ -627,10 +616,12 @@ func writeStringMap(ctx context.Context, val interface{}, out *thrift.BinaryWrit return nil } -func writeStringJSONMap(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeStringJSONMap(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { m := val.(map[string]gjson.Result) length := len(m) - out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length) + if err := out.WriteMapBegin(t.Key.Type.ToThriftTType(), t.Elem.Type.ToThriftTType(), length); err != nil { + return err + } if length == 0 { return nil } @@ -664,7 +655,7 @@ func writeStringJSONMap(ctx context.Context, val interface{}, out *thrift.Binary return nil } -func writeRequestBase(ctx context.Context, val interface{}, out *thrift.BinaryWriter, field *descriptor.FieldDescriptor, opt *writerOption) error { +func writeRequestBase(ctx context.Context, val interface{}, out *thrift.BufferWriter, field *descriptor.FieldDescriptor, opt *writerOption) error { if st, ok := val.(map[string]interface{}); ok { // copy from user's Extra if ext, ok := st["Extra"]; ok { @@ -698,18 +689,22 @@ func writeRequestBase(ctx context.Context, val interface{}, out *thrift.BinaryWr } } } - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } sz := opt.requestBase.BLength() buf := make([]byte, sz) opt.requestBase.FastWrite(buf) for _, b := range buf { - out.WriteByte(int8(b)) + if err := out.WriteByte(int8(b)); err != nil { + return err + } } return nil } // writeStruct iter with Descriptor, can check the field's required and others -func writeStruct(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeStruct(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { var fg fieldGetter switch val.(type) { case map[string]interface{}: @@ -732,7 +727,9 @@ func writeStruct(ctx context.Context, val interface{}, out *thrift.BinaryWriter, if elem == nil || !ok { if !field.Optional { // empty fields don't need value-mapping here, since writeEmptyValue decides zero value based on Thrift type - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } if err := writeEmptyValue(out, field.Type, opt); err != nil { return fmt.Errorf("field (%d/%s) error: %w", field.ID, name, err) } @@ -744,7 +741,9 @@ func writeStruct(ctx context.Context, val interface{}, out *thrift.BinaryWriter, return err } } - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } writer, err := nextWriter(elem, field.Type, opt) if err != nil { return fmt.Errorf("nextWriter of field[%s] error %w", name, err) @@ -755,11 +754,10 @@ func writeStruct(ctx context.Context, val interface{}, out *thrift.BinaryWriter, } } - out.WriteFieldStop() - return nil + return out.WriteFieldStop() } -func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { req := val.(*descriptor.HTTPRequest) defer func() { if req.Params != nil { @@ -780,7 +778,9 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BinaryWr if v == nil { if !field.Optional { - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } if err := writeEmptyValue(out, field.Type, opt); err != nil { return fmt.Errorf("field (%d/%s) error: %w", field.ID, name, err) } @@ -792,7 +792,9 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BinaryWr return err } } - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } writer, err := nextWriter(v, field.Type, opt) if err != nil { return fmt.Errorf("nextWriter of field[%s] error %w", name, err) @@ -803,11 +805,10 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BinaryWr } } - out.WriteFieldStop() - return nil + return out.WriteFieldStop() } -func writeJSON(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { +func writeJSON(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error { data := val.(*gjson.Result) for name, field := range t.Struct.FieldsByName { elem := data.Get(name) @@ -821,7 +822,9 @@ func writeJSON(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t if elem.Type == gjson.Null { if !field.Optional { - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } if err := writeEmptyValue(out, field.Type, opt); err != nil { return fmt.Errorf("field (%d/%s) error: %w", field.ID, name, err) } @@ -831,13 +834,14 @@ func writeJSON(ctx context.Context, val interface{}, out *thrift.BinaryWriter, t if err != nil { return fmt.Errorf("nextWriter of field[%s] error %w", name, err) } - out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)) + if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { + return err + } if err := writer(ctx, v, out, field.Type, opt); err != nil { return fmt.Errorf("writer of field[%s] error %w", name, err) } } } - out.WriteFieldStop() - return nil + return out.WriteFieldStop() } diff --git a/internal/generic/thrift/write_test.go b/internal/generic/thrift/write_test.go index 9545fe66c5..79a469bdca 100644 --- a/internal/generic/thrift/write_test.go +++ b/internal/generic/thrift/write_test.go @@ -22,6 +22,7 @@ import ( "encoding/json" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/jhump/protoreflect/desc/protoparse" @@ -36,7 +37,6 @@ func Test_nextWriter(t *testing.T) { // add some testcases type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -51,7 +51,6 @@ func Test_nextWriter(t *testing.T) { "nextWriteri8 Success", args{ val: int8(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -67,7 +66,6 @@ func Test_nextWriter(t *testing.T) { "nextWriteri16 Success", args{ val: int16(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -83,7 +81,6 @@ func Test_nextWriter(t *testing.T) { "nextWriteri32 Success", args{ val: int32(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -99,7 +96,6 @@ func Test_nextWriter(t *testing.T) { "nextWriteri64 Success", args{ val: int64(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -115,7 +111,6 @@ func Test_nextWriter(t *testing.T) { "nextWriterbool Success", args{ val: true, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.BOOL, Struct: &descriptor.StructDescriptor{}, @@ -131,7 +126,6 @@ func Test_nextWriter(t *testing.T) { "nextWriterdouble Success", args{ val: float64(1.0), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -147,7 +141,6 @@ func Test_nextWriter(t *testing.T) { "nextWriteri8 Failed", args{ val: 10000000, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -163,7 +156,6 @@ func Test_nextWriter(t *testing.T) { "nextWriteri16 Failed", args{ val: 10000000, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -179,7 +171,6 @@ func Test_nextWriter(t *testing.T) { "nextWriteri32 Failed", args{ val: 10000000, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -195,7 +186,6 @@ func Test_nextWriter(t *testing.T) { "nextWriteri64 Failed", args{ val: "10000000", - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -222,88 +212,16 @@ func Test_nextWriter(t *testing.T) { t.Error("nextWriter() error = nil, but writerfunc == nil") return } - if err := writerfunc(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writerfunc(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writerfunc() error = %v, wantErr %v", err, tt.wantErr) } }) } } -func Test_writeVoid(t *testing.T) { - type args struct { - val interface{} - out *thrift.BinaryWriter - t *descriptor.TypeDescriptor - opt *writerOption - } - - tests := []struct { - name string - args args - wantErr bool - }{ - // TODO: Add test cases. - { - "writeVoid", - args{ - val: 1, - out: thrift.NewBinaryWriter(), - t: &descriptor.TypeDescriptor{ - Type: descriptor.VOID, - Struct: &descriptor.StructDescriptor{}, - }, - }, - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := writeVoid(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { - t.Errorf("writeVoid() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func Test_writeBool(t *testing.T) { - type args struct { - val interface{} - out *thrift.BinaryWriter - t *descriptor.TypeDescriptor - opt *writerOption - } - tests := []struct { - name string - args args - wantErr bool - }{ - // TODO: Add test cases. - { - "writeBool", - args{ - val: true, - out: thrift.NewBinaryWriter(), - t: &descriptor.TypeDescriptor{ - Type: descriptor.BOOL, - Struct: &descriptor.StructDescriptor{}, - }, - }, - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := writeBool(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { - t.Errorf("writeBool() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - func Test_writeInt8(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -317,7 +235,6 @@ func Test_writeInt8(t *testing.T) { "writeInt8", args{ val: int8(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -329,7 +246,6 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 byte", args: args{ val: byte(128), // overflow - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -341,7 +257,6 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 error", args: args{ val: int16(2), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -353,7 +268,6 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 to i16", args: args{ val: int8(2), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -365,7 +279,6 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 to i32", args: args{ val: int8(2), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -377,7 +290,6 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 to i64", args: args{ val: int8(2), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -389,7 +301,6 @@ func Test_writeInt8(t *testing.T) { name: "writeInt8 to i64", args: args{ val: int8(2), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -400,7 +311,7 @@ func Test_writeInt8(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeInt8(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeInt8(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeInt8() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -410,7 +321,6 @@ func Test_writeInt8(t *testing.T) { func Test_writeJSONNumber(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -424,7 +334,6 @@ func Test_writeJSONNumber(t *testing.T) { "writeJSONNumber", args{ val: json.Number("1"), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -435,7 +344,7 @@ func Test_writeJSONNumber(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeJSONNumber(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeJSONNumber(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeJSONNumber() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -445,7 +354,6 @@ func Test_writeJSONNumber(t *testing.T) { func Test_writeJSONFloat64(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -459,7 +367,6 @@ func Test_writeJSONFloat64(t *testing.T) { "writeJSONFloat64", args{ val: 1.0, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -471,7 +378,6 @@ func Test_writeJSONFloat64(t *testing.T) { "writeJSONFloat64 bool Failed", args{ val: 1.0, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.BOOL, Struct: &descriptor.StructDescriptor{}, @@ -482,7 +388,7 @@ func Test_writeJSONFloat64(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeJSONFloat64(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeJSONFloat64(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeJSONFloat64() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -492,7 +398,6 @@ func Test_writeJSONFloat64(t *testing.T) { func Test_writeInt16(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -506,7 +411,6 @@ func Test_writeInt16(t *testing.T) { "writeInt16", args{ val: int16(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -518,7 +422,6 @@ func Test_writeInt16(t *testing.T) { "writeInt16toInt8 Success", args{ val: int16(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -530,7 +433,6 @@ func Test_writeInt16(t *testing.T) { "writeInt16toInt8 Failed", args{ val: int16(10000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -542,7 +444,6 @@ func Test_writeInt16(t *testing.T) { "writeInt16toInt32 Success", args{ val: int16(10000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -554,7 +455,6 @@ func Test_writeInt16(t *testing.T) { "writeInt16toInt64 Success", args{ val: int16(10000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -566,7 +466,6 @@ func Test_writeInt16(t *testing.T) { "writeInt16 Failed", args{ val: int16(10000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -577,7 +476,7 @@ func Test_writeInt16(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeInt16(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeInt16(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeInt16() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -587,7 +486,6 @@ func Test_writeInt16(t *testing.T) { func Test_writeInt32(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -602,7 +500,6 @@ func Test_writeInt32(t *testing.T) { "writeInt32 Success", args{ val: int32(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -614,7 +511,6 @@ func Test_writeInt32(t *testing.T) { "writeInt32 Failed", args{ val: int32(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -626,7 +522,6 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt8 Success", args{ val: int32(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -638,7 +533,6 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt8 Failed", args{ val: int32(100000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -650,7 +544,6 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt16 success", args{ val: int32(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -662,7 +555,6 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt16 Failed", args{ val: int32(100000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -674,7 +566,6 @@ func Test_writeInt32(t *testing.T) { "writeInt32ToInt64 Success", args{ val: int32(10000000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -685,7 +576,7 @@ func Test_writeInt32(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeInt32(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeInt32(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeInt32() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -695,7 +586,6 @@ func Test_writeInt32(t *testing.T) { func Test_writeInt64(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -709,7 +599,6 @@ func Test_writeInt64(t *testing.T) { "writeInt64 Success", args{ val: int64(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I64, Struct: &descriptor.StructDescriptor{}, @@ -721,7 +610,6 @@ func Test_writeInt64(t *testing.T) { "writeInt64 Failed", args{ val: int64(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -733,7 +621,6 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt8 Success", args{ val: int64(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -745,7 +632,6 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt8 failed", args{ val: int64(1000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I08, Struct: &descriptor.StructDescriptor{}, @@ -757,7 +643,6 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt16 Success", args{ val: int64(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -769,7 +654,6 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt16 failed", args{ val: int64(100000000000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I16, Struct: &descriptor.StructDescriptor{}, @@ -781,7 +665,6 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt32 Success", args{ val: int64(1), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -793,7 +676,6 @@ func Test_writeInt64(t *testing.T) { "writeInt64ToInt32 failed", args{ val: int64(100000000000), - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.I32, Struct: &descriptor.StructDescriptor{}, @@ -804,7 +686,7 @@ func Test_writeInt64(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeInt64(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeInt64(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeInt64() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -814,7 +696,6 @@ func Test_writeInt64(t *testing.T) { func Test_writeFloat64(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -828,7 +709,6 @@ func Test_writeFloat64(t *testing.T) { "writeFloat64", args{ val: 1.0, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.DOUBLE, Struct: &descriptor.StructDescriptor{}, @@ -839,7 +719,7 @@ func Test_writeFloat64(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeFloat64(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeFloat64(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeFloat64() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -849,7 +729,6 @@ func Test_writeFloat64(t *testing.T) { func Test_writeString(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -863,7 +742,6 @@ func Test_writeString(t *testing.T) { "writeString", args{ val: stringInput, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.STRING, Struct: &descriptor.StructDescriptor{}, @@ -874,7 +752,7 @@ func Test_writeString(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeString(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeString(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeString() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -884,7 +762,7 @@ func Test_writeString(t *testing.T) { func Test_writeBase64String(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -898,7 +776,7 @@ func Test_writeBase64String(t *testing.T) { "writeBase64Binary", // write to binary field with base64 string args{ val: base64.StdEncoding.EncodeToString(binaryInput), - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Name: "binary", Type: descriptor.STRING, @@ -910,7 +788,7 @@ func Test_writeBase64String(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeBase64Binary(context.Background(), tt.args.val, tt.args.out, tt.args.t, + if err := writeBase64Binary(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeString() error = %v, wantErr %v", err, tt.wantErr) } @@ -921,7 +799,7 @@ func Test_writeBase64String(t *testing.T) { func Test_writeBinary(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -935,7 +813,7 @@ func Test_writeBinary(t *testing.T) { "writeBinary", args{ val: []byte(stringInput), - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRING, Struct: &descriptor.StructDescriptor{}, @@ -946,7 +824,7 @@ func Test_writeBinary(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeBinary(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeBinary(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeBinary() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -992,13 +870,16 @@ func Test_writeBinaryList(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - binaryWriter := thrift.NewBinaryWriter() - if err := writeBinaryList(context.Background(), tt.args.val, binaryWriter, tt.args.t, + var bs []byte + bw := bufiox.NewBytesWriter(&bs) + w := thrift.NewBufferWriter(bw) + if err := writeBinaryList(context.Background(), tt.args.val, w, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeBinary() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr { - test.Assert(t, len(tt.args.val)+5 == len(binaryWriter.Bytes())) + bw.Flush() + test.Assert(t, len(tt.args.val)+5 == len(bs)) } }) } @@ -1007,7 +888,7 @@ func Test_writeBinaryList(t *testing.T) { func Test_writeList(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -1021,7 +902,7 @@ func Test_writeList(t *testing.T) { "writeList", args{ val: []interface{}{stringInput}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1034,7 +915,7 @@ func Test_writeList(t *testing.T) { "writeListWithNil", args{ val: []interface{}{stringInput, nil, stringInput}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1047,7 +928,7 @@ func Test_writeList(t *testing.T) { "writeListWithNilOnly", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1060,7 +941,7 @@ func Test_writeList(t *testing.T) { "writeListWithNextWriterError", args{ val: []interface{}{stringInput}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I08}, @@ -1072,7 +953,7 @@ func Test_writeList(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeList(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeList(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeList() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -1082,7 +963,7 @@ func Test_writeList(t *testing.T) { func Test_writeInterfaceMap(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -1096,7 +977,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMap", args{ val: map[interface{}]interface{}{"hello": "world"}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1110,7 +991,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMapWithNil", args{ val: map[interface{}]interface{}{"hello": "world", "hi": nil, "hey": "kitex"}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1124,7 +1005,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMapWithNilOnly", args{ val: map[interface{}]interface{}{"hello": nil}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1138,7 +1019,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMapWithElemNextWriterError", args{ val: map[interface{}]interface{}{"hello": "world"}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1152,7 +1033,7 @@ func Test_writeInterfaceMap(t *testing.T) { "writeInterfaceMapWithKeyWriterError", args{ val: map[interface{}]interface{}{"hello": "world"}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.I08}, @@ -1165,7 +1046,7 @@ func Test_writeInterfaceMap(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeInterfaceMap(context.Background(), tt.args.val, tt.args.out, tt.args.t, + if err := writeInterfaceMap(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeInterfaceMap() error = %v, wantErr %v", err, tt.wantErr) } @@ -1176,7 +1057,7 @@ func Test_writeInterfaceMap(t *testing.T) { func Test_writeStringMap(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -1190,7 +1071,7 @@ func Test_writeStringMap(t *testing.T) { "writeStringMap", args{ val: map[string]interface{}{"hello": "world"}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1204,7 +1085,7 @@ func Test_writeStringMap(t *testing.T) { "writeStringMapWithNil", args{ val: map[string]interface{}{"hello": "world", "hi": nil, "hey": "kitex"}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1218,7 +1099,7 @@ func Test_writeStringMap(t *testing.T) { "writeStringMapWithNilOnly", args{ val: map[string]interface{}{"hello": nil}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1232,7 +1113,7 @@ func Test_writeStringMap(t *testing.T) { "writeStringMapWithElemNextWriterError", args{ val: map[string]interface{}{"hello": "world"}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.MAP, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1245,7 +1126,7 @@ func Test_writeStringMap(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeStringMap(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeStringMap(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeStringMap() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -1255,7 +1136,7 @@ func Test_writeStringMap(t *testing.T) { func Test_writeStruct(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -1269,7 +1150,7 @@ func Test_writeStruct(t *testing.T) { "writeStruct", args{ val: map[string]interface{}{"hello": "world"}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1291,7 +1172,7 @@ func Test_writeStruct(t *testing.T) { "writeStructRequired", args{ val: map[string]interface{}{"hello": nil}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1313,7 +1194,7 @@ func Test_writeStruct(t *testing.T) { "writeStructOptional", args{ val: map[string]interface{}{}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1332,7 +1213,7 @@ func Test_writeStruct(t *testing.T) { "writeStructError", args{ val: map[string]interface{}{"strList": []interface{}{int64(123)}}, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1350,7 +1231,7 @@ func Test_writeStruct(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeStruct(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeStruct(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeStruct() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -1360,7 +1241,7 @@ func Test_writeStruct(t *testing.T) { func Test_writeHTTPRequest(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -1376,7 +1257,7 @@ func Test_writeHTTPRequest(t *testing.T) { val: &descriptor.HTTPRequest{ Body: map[string]interface{}{"hello": "world"}, }, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1402,7 +1283,7 @@ func Test_writeHTTPRequest(t *testing.T) { val: &descriptor.HTTPRequest{ Body: map[string]interface{}{"hello": nil}, }, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1429,7 +1310,7 @@ func Test_writeHTTPRequest(t *testing.T) { val: &descriptor.HTTPRequest{ Body: map[string]interface{}{"hello": nil}, }, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1456,7 +1337,7 @@ func Test_writeHTTPRequest(t *testing.T) { val: &descriptor.HTTPRequest{ Body: map[string]interface{}{}, }, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1480,7 +1361,7 @@ func Test_writeHTTPRequest(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeHTTPRequest(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeHTTPRequest(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeHTTPRequest() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -1490,7 +1371,7 @@ func Test_writeHTTPRequest(t *testing.T) { func Test_writeHTTPRequestWithPbBody(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -1534,15 +1415,15 @@ func Test_writeHTTPRequestWithPbBody(t *testing.T) { "writeStructSuccess", args{ val: req, - out: thrift.NewBinaryWriter(), - t: typeDescriptor, + + t: typeDescriptor, }, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeHTTPRequest(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeHTTPRequest(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeHTTPRequest() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -1584,7 +1465,6 @@ func Test_writeRequestBase(t *testing.T) { type args struct { ctx context.Context val interface{} - out *thrift.BinaryWriter field *descriptor.FieldDescriptor opt *writerOption } @@ -1599,7 +1479,7 @@ func Test_writeRequestBase(t *testing.T) { "writeStruct", args{ val: map[string]interface{}{"Extra": map[string]interface{}{"hello": "world"}}, - out: thrift.NewBinaryWriter(), + field: &descriptor.FieldDescriptor{ Name: "base", ID: 255, @@ -1612,7 +1492,7 @@ func Test_writeRequestBase(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeRequestBase(tt.args.ctx, tt.args.val, tt.args.out, tt.args.field, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeRequestBase(tt.args.ctx, tt.args.val, getBufferWriter(nil), tt.args.field, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeRequestBase() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -1622,7 +1502,7 @@ func Test_writeRequestBase(t *testing.T) { func Test_writeJSON(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -1638,7 +1518,7 @@ func Test_writeJSON(t *testing.T) { "writeJSON", args{ val: &data, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1660,7 +1540,7 @@ func Test_writeJSON(t *testing.T) { "writeJSONRequired", args{ val: &dataEmpty, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1682,7 +1562,7 @@ func Test_writeJSON(t *testing.T) { "writeJSONOptional", args{ val: &dataEmpty, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1700,7 +1580,7 @@ func Test_writeJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeJSON(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeJSON(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeJSON() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -1710,7 +1590,7 @@ func Test_writeJSON(t *testing.T) { func Test_writeJSONBase(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter + t *descriptor.TypeDescriptor opt *writerOption } @@ -1724,7 +1604,7 @@ func Test_writeJSONBase(t *testing.T) { "writeJSONBase", args{ val: &data, - out: thrift.NewBinaryWriter(), + t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1755,7 +1635,7 @@ func Test_writeJSONBase(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeJSON(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeJSON(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeJSON() error = %v, wantErr %v", err, tt.wantErr) } test.DeepEqual(t, tt.args.opt.requestBase.Extra, map[string]string{"hello": "world"}) @@ -1766,7 +1646,6 @@ func Test_writeJSONBase(t *testing.T) { func Test_getDefaultValueAndWriter(t *testing.T) { type args struct { val interface{} - out *thrift.BinaryWriter t *descriptor.TypeDescriptor opt *writerOption } @@ -1780,7 +1659,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "bool", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.BOOL}, @@ -1793,7 +1671,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "i08", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I08}, @@ -1806,7 +1683,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "i16", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I16}, @@ -1819,7 +1695,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "i32", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I32}, @@ -1832,7 +1707,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "i64", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.I64}, @@ -1845,7 +1719,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "double", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.DOUBLE}, @@ -1858,7 +1731,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "stringBinary", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), opt: &writerOption{ binaryWithBase64: true, }, @@ -1877,7 +1749,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "stringNonBinary", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, @@ -1890,7 +1761,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "list", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -1906,7 +1776,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "set", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -1922,7 +1791,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "map", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -1939,7 +1807,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "struct", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -1965,7 +1832,6 @@ func Test_getDefaultValueAndWriter(t *testing.T) { "void", args{ val: []interface{}{nil}, - out: thrift.NewBinaryWriter(), t: &descriptor.TypeDescriptor{ Type: descriptor.LIST, Elem: &descriptor.TypeDescriptor{ @@ -1980,9 +1846,15 @@ func Test_getDefaultValueAndWriter(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := writeList(context.Background(), tt.args.val, tt.args.out, tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { + if err := writeList(context.Background(), tt.args.val, getBufferWriter(nil), tt.args.t, tt.args.opt); (err != nil) != tt.wantErr { t.Errorf("writeList() error = %v, wantErr %v", err, tt.wantErr) } }) } } + +func getBufferWriter(bs []byte) *thrift.BufferWriter { + bw := bufiox.NewBytesWriter(&bs) + w := thrift.NewBufferWriter(bw) + return w +} diff --git a/internal/mocks/conn.go b/internal/mocks/conn.go index 86b3f8940e..f4d68713f4 100644 --- a/internal/mocks/conn.go +++ b/internal/mocks/conn.go @@ -17,6 +17,7 @@ package mocks import ( + bytes2 "bytes" "net" "time" ) @@ -98,3 +99,15 @@ func (m Conn) SetWriteDeadline(t time.Time) (e error) { } return } + +func NewIOConn() *Conn { + var bytes bytes2.Buffer + return &Conn{ + ReadFunc: func(b []byte) (n int, err error) { + return bytes.Read(b) + }, + WriteFunc: func(b []byte) (n int, err error) { + return bytes.Write(b) + }, + } +} diff --git a/internal/mocks/generic/thrift.go b/internal/mocks/generic/thrift.go index ae32354f04..6fa6a38efd 100644 --- a/internal/mocks/generic/thrift.go +++ b/internal/mocks/generic/thrift.go @@ -16,17 +16,17 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ../../pkg/generic/thrift/thrift.go +// Source: ../../internal/generic/thrift/thrift.go // Package generic is a generated GoMock package. package generic import ( context "context" - "io" reflect "reflect" - "github.com/cloudwego/gopkg/protocol/thrift/base" + bufiox "github.com/cloudwego/gopkg/bufiox" + base "github.com/cloudwego/gopkg/protocol/thrift/base" gomock "github.com/golang/mock/gomock" ) @@ -54,18 +54,18 @@ func (m *MockMessageReader) EXPECT() *MockMessageReaderMockRecorder { } // Read mocks base method. -func (m *MockMessageReader) Read(ctx context.Context, method string, isClient bool, dataLen int, in io.Reader) (interface{}, error) { +func (m *MockMessageReader) Read(ctx context.Context, method string, isClient bool, dataLen int, in bufiox.Reader) (interface{}, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", ctx, method, in) + ret := m.ctrl.Call(m, "Read", ctx, method, isClient, dataLen, in) ret0, _ := ret[0].(interface{}) ret1, _ := ret[1].(error) return ret0, ret1 } // Read indicates an expected call of Read. -func (mr *MockMessageReaderMockRecorder) Read(ctx, method, in interface{}) *gomock.Call { +func (mr *MockMessageReaderMockRecorder) Read(ctx, method, isClient, dataLen, in interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockMessageReader)(nil).Read), ctx, method, in) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockMessageReader)(nil).Read), ctx, method, isClient, dataLen, in) } // MockMessageWriter is a mock of MessageWriter interface. @@ -92,15 +92,15 @@ func (m *MockMessageWriter) EXPECT() *MockMessageWriterMockRecorder { } // Write mocks base method. -func (m *MockMessageWriter) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { +func (m *MockMessageWriter) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", ctx, out, msg, requestBase) + ret := m.ctrl.Call(m, "Write", ctx, out, msg, method, isClient, requestBase) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. -func (mr *MockMessageWriterMockRecorder) Write(ctx, out, msg, requestBase interface{}) *gomock.Call { +func (mr *MockMessageWriterMockRecorder) Write(ctx, out, msg, method, isClient, requestBase interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockMessageWriter)(nil).Write), ctx, out, msg, requestBase) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockMessageWriter)(nil).Write), ctx, out, msg, method, isClient, requestBase) } diff --git a/internal/mocks/remote/bytebuf.go b/internal/mocks/remote/bytebuf.go index ab468b07b6..ee5017f312 100644 --- a/internal/mocks/remote/bytebuf.go +++ b/internal/mocks/remote/bytebuf.go @@ -90,6 +90,20 @@ func (m *MockNocopyWrite) EXPECT() *MockNocopyWriteMockRecorder { return m.recorder } +// MallocAck mocks base method. +func (m *MockNocopyWrite) MallocAck(n int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MallocAck", n) + ret0, _ := ret[0].(error) + return ret0 +} + +// MallocAck indicates an expected call of MallocAck. +func (mr *MockNocopyWriteMockRecorder) MallocAck(n interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MallocAck", reflect.TypeOf((*MockNocopyWrite)(nil).MallocAck), n) +} + // WriteDirect mocks base method. func (m *MockNocopyWrite) WriteDirect(buf []byte, remainCap int) error { m.ctrl.T.Helper() @@ -236,20 +250,6 @@ func (mr *MockByteBufferMockRecorder) Malloc(n interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Malloc", reflect.TypeOf((*MockByteBuffer)(nil).Malloc), n) } -// MallocLen mocks base method. -func (m *MockByteBuffer) MallocLen() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MallocLen") - ret0, _ := ret[0].(int) - return ret0 -} - -// MallocLen indicates an expected call of MallocLen. -func (mr *MockByteBufferMockRecorder) MallocLen() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MallocLen", reflect.TypeOf((*MockByteBuffer)(nil).MallocLen)) -} - // NewBuffer mocks base method. func (m *MockByteBuffer) NewBuffer() remote.ByteBuffer { m.ctrl.T.Helper() @@ -310,18 +310,18 @@ func (mr *MockByteBufferMockRecorder) Read(p interface{}) *gomock.Call { } // ReadBinary mocks base method. -func (m *MockByteBuffer) ReadBinary(n int) ([]byte, error) { +func (m *MockByteBuffer) ReadBinary(p []byte) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadBinary", n) - ret0, _ := ret[0].([]byte) + ret := m.ctrl.Call(m, "ReadBinary", p) + ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // ReadBinary indicates an expected call of ReadBinary. -func (mr *MockByteBufferMockRecorder) ReadBinary(n interface{}) *gomock.Call { +func (mr *MockByteBufferMockRecorder) ReadBinary(p interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBinary", reflect.TypeOf((*MockByteBuffer)(nil).ReadBinary), n) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBinary", reflect.TypeOf((*MockByteBuffer)(nil).ReadBinary), p) } // ReadLen mocks base method. @@ -439,3 +439,17 @@ func (mr *MockByteBufferMockRecorder) WriteString(s interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteString", reflect.TypeOf((*MockByteBuffer)(nil).WriteString), s) } + +// WrittenLen mocks base method. +func (m *MockByteBuffer) WrittenLen() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WrittenLen") + ret0, _ := ret[0].(int) + return ret0 +} + +// WrittenLen indicates an expected call of WrittenLen. +func (mr *MockByteBufferMockRecorder) WrittenLen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WrittenLen", reflect.TypeOf((*MockByteBuffer)(nil).WrittenLen)) +} diff --git a/internal/mocks/update.sh b/internal/mocks/update.sh index 68afa94c56..350fb4ad79 100755 --- a/internal/mocks/update.sh +++ b/internal/mocks/update.sh @@ -17,7 +17,7 @@ files=( ../../pkg/remote/payload_codec.go remote/payload_codec.go remote ../../pkg/generic/generic_service.go generic/generic_service.go generic ../../pkg/klog/log.go klog/log.go klog -../../pkg/generic/thrift/thrift.go generic/thrift.go generic +../../internal/generic/thrift/thrift.go generic/thrift.go generic ../../pkg/discovery/discovery.go discovery/discovery.go discovery ../../pkg/loadbalance/loadbalancer.go loadbalance/loadbalancer.go loadbalance ../../pkg/proxy/proxy.go proxy/proxy.go proxy diff --git a/pkg/generic/binarythrift_codec.go b/pkg/generic/binarythrift_codec.go index eeca8304d3..5bb0828de8 100644 --- a/pkg/generic/binarythrift_codec.go +++ b/pkg/generic/binarythrift_codec.go @@ -21,6 +21,8 @@ import ( "encoding/binary" "fmt" + "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" @@ -74,8 +76,8 @@ func (c *binaryThriftCodec) Marshal(ctx context.Context, msg remote.Message, out return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("rawThriftBinaryCodec set seqID failed, err: %s", err.Error())) } } - out.WriteBinary(transBuff) - return nil + _, err := out.WriteBinary(transBuff) + return err } func (c *binaryThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { @@ -88,7 +90,8 @@ func (c *binaryThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, i return c.thriftCodec.Unmarshal(ctx, msg, in) } payloadLen := msg.PayloadLen() - transBuff, err := in.ReadBinary(payloadLen) + transBuff := dirtmake.Bytes(payloadLen, payloadLen) + _, err = in.ReadBinary(transBuff) if err != nil { return err } diff --git a/pkg/generic/binarythrift_codec_test.go b/pkg/generic/binarythrift_codec_test.go index 48fbda0fea..3faac6c211 100644 --- a/pkg/generic/binarythrift_codec_test.go +++ b/pkg/generic/binarythrift_codec_test.go @@ -20,8 +20,10 @@ import ( "context" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/internal/mocks" kt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" @@ -56,14 +58,21 @@ func TestBinaryThriftCodec(t *testing.T) { test.Assert(t, err == nil, err) test.Assert(t, seqID == 100, seqID) - rwbuf := remote.NewReaderWriterBuffer(1024) + conn := mocks.NewIOConn() + bw := bufiox.NewDefaultWriter(conn) + br := bufiox.NewDefaultReader(conn) + bb := remote.NewByteBufferFromBufiox(bw, br) + // change seqID to 1 - err = btc.Marshal(context.Background(), cliMsg, rwbuf) + err = btc.Marshal(context.Background(), cliMsg, bb) test.Assert(t, err == nil, err) seqID, err = GetSeqID(cliMsg.Data().(*Args).Request.(binaryReqType)) test.Assert(t, err == nil, err) test.Assert(t, seqID == 1, seqID) + wl := bw.WrittenLen() + bw.Flush() + // server side arg := &Args{} svrMsg := &mockMessage{ @@ -77,13 +86,13 @@ func TestBinaryThriftCodec(t *testing.T) { return arg }, PayloadLenFunc: func() int { - return rwbuf.ReadableLen() + return wl }, ServiceInfoFunc: func() *serviceinfo.ServiceInfo { return ServiceInfo(serviceinfo.Thrift) }, } - err = btc.Unmarshal(context.Background(), svrMsg, rwbuf) + err = btc.Unmarshal(context.Background(), svrMsg, bb) test.Assert(t, err == nil, err) reqBuf := svrMsg.Data().(*Args).Request.(binaryReqType) seqID, err = GetSeqID(reqBuf) @@ -112,24 +121,28 @@ func TestBinaryThriftCodecExceptionError(t *testing.T) { }, } - rwbuf := remote.NewReaderWriterBuffer(1024) + conn := mocks.NewIOConn() + bw := bufiox.NewDefaultWriter(conn) + br := bufiox.NewDefaultReader(conn) + bb := remote.NewByteBufferFromBufiox(bw, br) // test data is empty - err := btc.Marshal(ctx, cliMsg, rwbuf) + err := btc.Marshal(ctx, cliMsg, bb) test.Assert(t, err.Error() == "invalid marshal data in rawThriftBinaryCodec: nil") cliMsg.DataFunc = func() interface{} { return &remote.TransError{} } // empty method - err = btc.Marshal(ctx, cliMsg, rwbuf) + err = btc.Marshal(ctx, cliMsg, bb) test.Assert(t, err.Error() == "rawThriftBinaryCodec Marshal exception failed, err: empty methodName in thrift Marshal", err) cliMsg.RPCInfoFunc = func() rpcinfo.RPCInfo { return newMockRPCInfo() } - err = btc.Marshal(ctx, cliMsg, rwbuf) + err = btc.Marshal(ctx, cliMsg, bb) test.Assert(t, err == nil) - err = btc.Unmarshal(ctx, cliMsg, rwbuf) + bw.Flush() + err = btc.Unmarshal(ctx, cliMsg, bb) test.Assert(t, err.Error() == "unknown application exception") // test server role @@ -141,7 +154,7 @@ func TestBinaryThriftCodecExceptionError(t *testing.T) { Success: binaryReqType{}, } } - err = btc.Marshal(ctx, cliMsg, rwbuf) + err = btc.Marshal(ctx, cliMsg, bb) test.Assert(t, err == nil) } diff --git a/pkg/generic/generic_service.go b/pkg/generic/generic_service.go index dd6535550a..758c0d518b 100644 --- a/pkg/generic/generic_service.go +++ b/pkg/generic/generic_service.go @@ -19,8 +19,8 @@ package generic import ( "context" "fmt" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift/base" "github.com/cloudwego/kitex/internal/generic/proto" @@ -143,7 +143,7 @@ func (g *Args) GetOrSetBase() interface{} { } // Write ... -func (g *Args) Write(ctx context.Context, method string, out io.Writer) error { +func (g *Args) Write(ctx context.Context, method string, out bufiox.Writer) error { if err, ok := g.inner.(error); ok { return err } @@ -164,7 +164,7 @@ func (g *Args) WritePb(ctx context.Context, method string) (interface{}, error) } // Read ... -func (g *Args) Read(ctx context.Context, method string, dataLen int, in io.Reader) error { +func (g *Args) Read(ctx context.Context, method string, dataLen int, in bufiox.Reader) error { if err, ok := g.inner.(error); ok { return err } @@ -213,7 +213,7 @@ func (r *Result) SetCodec(inner interface{}) { } // Write ... -func (r *Result) Write(ctx context.Context, method string, out io.Writer) error { +func (r *Result) Write(ctx context.Context, method string, out bufiox.Writer) error { if err, ok := r.inner.(error); ok { return err } @@ -234,7 +234,7 @@ func (r *Result) WritePb(ctx context.Context, method string) (interface{}, error } // Read ... -func (r *Result) Read(ctx context.Context, method string, dataLen int, in io.Reader) error { +func (r *Result) Read(ctx context.Context, method string, dataLen int, in bufiox.Reader) error { if err, ok := r.inner.(error); ok { return err } diff --git a/pkg/generic/generic_service_test.go b/pkg/generic/generic_service_test.go index cc9bdb69d6..a30ab214cc 100644 --- a/pkg/generic/generic_service_test.go +++ b/pkg/generic/generic_service_test.go @@ -22,12 +22,12 @@ import ( "strings" "testing" - gbase "github.com/cloudwego/gopkg/protocol/thrift/base" + "github.com/cloudwego/gopkg/bufiox" "github.com/golang/mock/gomock" + mocksn "github.com/cloudwego/kitex/internal/mocks" mocks "github.com/cloudwego/kitex/internal/mocks/generic" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/utils" ) @@ -35,7 +35,11 @@ import ( func TestGenericService(t *testing.T) { ctx := context.Background() method := "test" - buffer := remote.NewReaderWriterBuffer(256) + + conn := mocksn.NewIOConn() + + wb := bufiox.NewDefaultWriter(conn) + rb := bufiox.NewDefaultReader(conn) ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -43,7 +47,7 @@ func TestGenericService(t *testing.T) { argWriteInner, resultWriteInner := mocks.NewMockMessageWriter(ctrl), mocks.NewMockMessageWriter(ctrl) rInner := mocks.NewMockMessageReader(ctrl) // Read expect - rInner.EXPECT().Read(ctx, method, buffer).Return("test", nil).AnyTimes() + rInner.EXPECT().Read(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return("test", nil).AnyTimes() // Args... arg := newGenericServiceCallArgs() @@ -54,21 +58,21 @@ func TestGenericService(t *testing.T) { test.Assert(t, base != nil) a.SetCodec(struct{}{}) // write not ok - err := a.Write(ctx, method, buffer) + err := a.Write(ctx, method, wb) test.Assert(t, err.Error() == "unexpected Args writer type: struct {}") // Write expect - argWriteInner.EXPECT().Write(ctx, buffer, a.Request, a.GetOrSetBase()).Return(nil) + argWriteInner.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) a.SetCodec(argWriteInner) // write ok - err = a.Write(ctx, method, buffer) + err = a.Write(ctx, method, wb) test.Assert(t, err == nil, err) // read not ok - err = a.Read(ctx, method, 0, buffer) + err = a.Read(ctx, method, 0, rb) test.Assert(t, strings.Contains(err.Error(), "unexpected Args reader type")) // read ok a.SetCodec(rInner) - err = a.Read(ctx, method, 0, buffer) + err = a.Read(ctx, method, 0, rb) test.Assert(t, err == nil, err) // Result... @@ -77,20 +81,20 @@ func TestGenericService(t *testing.T) { test.Assert(t, ok == true) // write not ok - err = r.Write(ctx, method, buffer) + err = r.Write(ctx, method, wb) test.Assert(t, err.Error() == "unexpected Result writer type: ") // Write expect - resultWriteInner.EXPECT().Write(ctx, buffer, r.Success, (*gbase.Base)(nil)).Return(nil).AnyTimes() + resultWriteInner.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() r.SetCodec(resultWriteInner) // write ok - err = r.Write(ctx, method, buffer) + err = r.Write(ctx, method, wb) test.Assert(t, err == nil) // read not ok - err = r.Read(ctx, method, 0, buffer) + err = r.Read(ctx, method, 0, rb) test.Assert(t, strings.Contains(err.Error(), "unexpected Result reader type")) // read ok r.SetCodec(rInner) - err = r.Read(ctx, method, 0, buffer) + err = r.Read(ctx, method, 0, rb) test.Assert(t, err == nil) r.SetSuccess(nil) diff --git a/pkg/protocol/bthrift/apache/apache.go b/pkg/protocol/bthrift/apache/apache.go index d227f95858..4c076e9bc5 100644 --- a/pkg/protocol/bthrift/apache/apache.go +++ b/pkg/protocol/bthrift/apache/apache.go @@ -18,9 +18,9 @@ package apache import ( "errors" - "io" "github.com/apache/thrift/lib/go/thrift" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift/apache" ) @@ -41,30 +41,24 @@ func checkTStruct(v interface{}) error { return nil } -func callThriftRead(r io.ReadWriter, v interface{}) error { +func callThriftRead(r bufiox.Reader, v interface{}) error { p, ok := v.(thrift.TStruct) if !ok { return errNotThriftTStruct } - t, ok := r.(byteBuffer) - if ok { - in := NewBinaryProtocol(t) - return p.Read(in) - } - in := thrift.NewTBinaryProtocol(apache.NewDefaultTransport(r), true, true) - return p.Read(in) + bp := NewBinaryProtocol(r, nil) + err := p.Read(bp) + bp.Recycle() + return err } -func callThriftWrite(w io.ReadWriter, v interface{}) error { +func callThriftWrite(w bufiox.Writer, v interface{}) error { p, ok := v.(thrift.TStruct) if !ok { return errNotThriftTStruct } - t, ok := w.(byteBuffer) - if ok { - out := NewBinaryProtocol(t) - return p.Write(out) - } - out := thrift.NewTBinaryProtocol(apache.NewDefaultTransport(w), true, true) - return p.Write(out) + bp := NewBinaryProtocol(nil, w) + err := p.Write(bp) + bp.Recycle() + return err } diff --git a/pkg/protocol/bthrift/apache/binary_protocol.go b/pkg/protocol/bthrift/apache/binary_protocol.go index 8a150c0dc4..db97e868a4 100644 --- a/pkg/protocol/bthrift/apache/binary_protocol.go +++ b/pkg/protocol/bthrift/apache/binary_protocol.go @@ -18,11 +18,11 @@ package apache import ( "context" - "encoding/binary" - "io" - "math" "sync" + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/pkg/remote/codec/perrors" ) @@ -40,56 +40,39 @@ var ( } ) -// byteBuffer is sub interfaces of remote.ByteBuffer -// the repeated definition here is to avoid dependency on remote packages -type byteBuffer interface { - io.ReadWriter - - // WriteString is a more efficient way to write string, using the unsafe method to convert the string to []byte. - WriteString(s string) (n int, err error) - - // WriteBinary writes the []byte directly. Callers must guarantee that the []byte doesn't change. - WriteBinary(b []byte) (n int, err error) - - // Malloc n bytes sequentially in the writer buffer. - Malloc(n int) (buf []byte, err error) - - // Next reads the next n bytes sequentially and returns the original buffer. - Next(n int) (p []byte, err error) - - // ReadString is a more efficient way to read string than Next. - ReadString(n int) (s string, err error) - - // ReadBinary like ReadString. - // Returns a copy of original buffer. - ReadBinary(n int) (p []byte, err error) - - // ReadableLen returns the total length of readable buffer. - // Return: -1 means unreadable. - ReadableLen() (n int) - - // Flush writes any malloc data to the underlying io.Writer. - // The malloced buffer must be set correctly. - Flush() (err error) -} - // BinaryProtocol was moved from cloudwego/kitex/pkg/remote/codec/thrift -// Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol type BinaryProtocol struct { - trans byteBuffer + r *thrift.BufferReader + w *thrift.BufferWriter + + br bufiox.Reader + bw bufiox.Writer } // NewBinaryProtocol ... -// Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol -func NewBinaryProtocol(t byteBuffer) *BinaryProtocol { +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift.NewBufferReader|NewBufferWriter +func NewBinaryProtocol(r bufiox.Reader, w bufiox.Writer) *BinaryProtocol { bp := bpPool.Get().(*BinaryProtocol) - bp.trans = t + if r != nil { + bp.r = thrift.NewBufferReader(r) + bp.br = r + } + if w != nil { + bp.w = thrift.NewBufferWriter(w) + bp.bw = w + } return bp } // Recycle ... func (p *BinaryProtocol) Recycle() { - p.trans = nil + if p.r != nil { + p.r.Recycle() + } + if p.w != nil { + p.w.Recycle() + } + *p = BinaryProtocol{} bpPool.Put(p) } @@ -99,17 +82,7 @@ func (p *BinaryProtocol) Recycle() { // WriteMessageBegin ... func (p *BinaryProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error { - version := uint32(VERSION_1) | uint32(typeID) - e := p.WriteI32(int32(version)) - if e != nil { - return e - } - e = p.WriteString(name) - if e != nil { - return e - } - e = p.WriteI32(seqID) - return e + return p.w.WriteMessageBegin(name, thrift.TMessageType(typeID), seqID) } // WriteMessageEnd ... @@ -129,12 +102,7 @@ func (p *BinaryProtocol) WriteStructEnd() error { // WriteFieldBegin ... func (p *BinaryProtocol) WriteFieldBegin(name string, typeID TType, id int16) error { - e := p.WriteByte(int8(typeID)) - if e != nil { - return e - } - e = p.WriteI16(id) - return e + return p.w.WriteFieldBegin(thrift.TType(typeID), id) } // WriteFieldEnd ... @@ -144,22 +112,12 @@ func (p *BinaryProtocol) WriteFieldEnd() error { // WriteFieldStop ... func (p *BinaryProtocol) WriteFieldStop() error { - e := p.WriteByte(STOP) - return e + return p.w.WriteFieldStop() } // WriteMapBegin ... func (p *BinaryProtocol) WriteMapBegin(keyType, valueType TType, size int) error { - e := p.WriteByte(int8(keyType)) - if e != nil { - return e - } - e = p.WriteByte(int8(valueType)) - if e != nil { - return e - } - e = p.WriteI32(int32(size)) - return e + return p.w.WriteMapBegin(thrift.TType(keyType), thrift.TType(valueType), size) } // WriteMapEnd ... @@ -169,12 +127,7 @@ func (p *BinaryProtocol) WriteMapEnd() error { // WriteListBegin ... func (p *BinaryProtocol) WriteListBegin(elemType TType, size int) error { - e := p.WriteByte(int8(elemType)) - if e != nil { - return e - } - e = p.WriteI32(int32(size)) - return e + return p.w.WriteListBegin(thrift.TType(elemType), size) } // WriteListEnd ... @@ -184,12 +137,7 @@ func (p *BinaryProtocol) WriteListEnd() error { // WriteSetBegin ... func (p *BinaryProtocol) WriteSetBegin(elemType TType, size int) error { - e := p.WriteByte(int8(elemType)) - if e != nil { - return e - } - e = p.WriteI32(int32(size)) - return e + return p.w.WriteSetBegin(thrift.TType(elemType), size) } // WriteSetEnd ... @@ -199,85 +147,42 @@ func (p *BinaryProtocol) WriteSetEnd() error { // WriteBool ... func (p *BinaryProtocol) WriteBool(value bool) error { - if value { - return p.WriteByte(1) - } - return p.WriteByte(0) + return p.w.WriteBool(value) } // WriteByte ... func (p *BinaryProtocol) WriteByte(value int8) error { - v, err := p.malloc(1) - if err != nil { - return err - } - v[0] = byte(value) - return err + return p.w.WriteByte(value) } // WriteI16 ... func (p *BinaryProtocol) WriteI16(value int16) error { - v, err := p.malloc(2) - if err != nil { - return err - } - binary.BigEndian.PutUint16(v, uint16(value)) - return err + return p.w.WriteI16(value) } // WriteI32 ... func (p *BinaryProtocol) WriteI32(value int32) error { - v, err := p.malloc(4) - if err != nil { - return err - } - binary.BigEndian.PutUint32(v, uint32(value)) - return err + return p.w.WriteI32(value) } // WriteI64 ... func (p *BinaryProtocol) WriteI64(value int64) error { - v, err := p.malloc(8) - if err != nil { - return err - } - binary.BigEndian.PutUint64(v, uint64(value)) - return err + return p.w.WriteI64(value) } // WriteDouble ... func (p *BinaryProtocol) WriteDouble(value float64) error { - return p.WriteI64(int64(math.Float64bits(value))) + return p.w.WriteDouble(value) } // WriteString ... func (p *BinaryProtocol) WriteString(value string) error { - len := len(value) - e := p.WriteI32(int32(len)) - if e != nil { - return e - } - _, e = p.trans.WriteString(value) - return e + return p.w.WriteString(value) } // WriteBinary ... func (p *BinaryProtocol) WriteBinary(value []byte) error { - e := p.WriteI32(int32(len(value))) - if e != nil { - return e - } - _, e = p.trans.WriteBinary(value) - return e -} - -// malloc ... -func (p *BinaryProtocol) malloc(size int) ([]byte, error) { - buf, err := p.trans.Malloc(size) - if err != nil { - return buf, perrors.NewProtocolError(err) - } - return buf, nil + return p.w.WriteBinary(value) } /** @@ -286,27 +191,10 @@ func (p *BinaryProtocol) malloc(size int) ([]byte, error) { // ReadMessageBegin ... func (p *BinaryProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) { - size, e := p.ReadI32() - if e != nil { - return "", typeID, 0, perrors.NewProtocolError(e) - } - if size > 0 { - return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Missing version in ReadMessageBegin") - } - typeID = TMessageType(size & 0x0ff) - version := int64(int64(size) & VERSION_MASK) - if version != VERSION_1 { - return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in ReadMessageBegin") - } - name, e = p.ReadString() - if e != nil { - return name, typeID, seqID, perrors.NewProtocolError(e) - } - seqID, e = p.ReadI32() - if e != nil { - return name, typeID, seqID, perrors.NewProtocolError(e) - } - return name, typeID, seqID, nil + var tid thrift.TMessageType + name, tid, seqID, err = p.r.ReadMessageBegin() + typeID = TMessageType(tid) + return } // ReadMessageEnd ... @@ -326,15 +214,10 @@ func (p *BinaryProtocol) ReadStructEnd() error { // ReadFieldBegin ... func (p *BinaryProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) { - t, err := p.ReadByte() - typeID = TType(t) - if err != nil { - return name, typeID, id, err - } - if t != STOP { - id, err = p.ReadI16() - } - return name, typeID, id, err + var tid thrift.TType + tid, id, err = p.r.ReadFieldBegin() + typeID = TType(tid) + return } // ReadFieldEnd ... @@ -344,29 +227,11 @@ func (p *BinaryProtocol) ReadFieldEnd() error { // ReadMapBegin ... func (p *BinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) { - k, e := p.ReadByte() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - kType = TType(k) - v, e := p.ReadByte() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - vType = TType(v) - size32, e := p.ReadI32() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - if size32 < 0 { - err = perrors.InvalidDataLength - return - } - size = int(size32) - return kType, vType, size, nil + var ktype, vtype thrift.TType + ktype, vtype, size, err = p.r.ReadMapBegin() + kType = TType(ktype) + vType = TType(vtype) + return } // ReadMapEnd ... @@ -376,23 +241,9 @@ func (p *BinaryProtocol) ReadMapEnd() error { // ReadListBegin ... func (p *BinaryProtocol) ReadListBegin() (elemType TType, size int, err error) { - b, e := p.ReadByte() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - elemType = TType(b) - size32, e := p.ReadI32() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - if size32 < 0 { - err = perrors.InvalidDataLength - return - } - size = int(size32) - + var etype thrift.TType + etype, size, err = p.r.ReadListBegin() + elemType = TType(etype) return } @@ -403,23 +254,10 @@ func (p *BinaryProtocol) ReadListEnd() error { // ReadSetBegin ... func (p *BinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) { - b, e := p.ReadByte() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - elemType = TType(b) - size32, e := p.ReadI32() - if e != nil { - err = perrors.NewProtocolError(e) - return - } - if size32 < 0 { - err = perrors.InvalidDataLength - return - } - size = int(size32) - return elemType, size, nil + var etype thrift.TType + etype, size, err = p.r.ReadSetBegin() + elemType = TType(etype) + return } // ReadSetEnd ... @@ -429,95 +267,47 @@ func (p *BinaryProtocol) ReadSetEnd() error { // ReadBool ... func (p *BinaryProtocol) ReadBool() (bool, error) { - b, e := p.ReadByte() - v := true - if b != 1 { - v = false - } - return v, e + return p.r.ReadBool() } // ReadByte ... func (p *BinaryProtocol) ReadByte() (value int8, err error) { - buf, err := p.next(1) - if err != nil { - return value, err - } - return int8(buf[0]), err + return p.r.ReadByte() } // ReadI16 ... func (p *BinaryProtocol) ReadI16() (value int16, err error) { - buf, err := p.next(2) - if err != nil { - return value, err - } - value = int16(binary.BigEndian.Uint16(buf)) - return value, err + return p.r.ReadI16() } // ReadI32 ... func (p *BinaryProtocol) ReadI32() (value int32, err error) { - buf, err := p.next(4) - if err != nil { - return value, err - } - value = int32(binary.BigEndian.Uint32(buf)) - return value, err + return p.r.ReadI32() } // ReadI64 ... func (p *BinaryProtocol) ReadI64() (value int64, err error) { - buf, err := p.next(8) - if err != nil { - return value, err - } - value = int64(binary.BigEndian.Uint64(buf)) - return value, err + return p.r.ReadI64() } // ReadDouble ... func (p *BinaryProtocol) ReadDouble() (value float64, err error) { - buf, err := p.next(8) - if err != nil { - return value, err - } - value = math.Float64frombits(binary.BigEndian.Uint64(buf)) - return value, err + return p.r.ReadDouble() } // ReadString ... func (p *BinaryProtocol) ReadString() (value string, err error) { - size, e := p.ReadI32() - if e != nil { - return "", e - } - if size < 0 { - err = perrors.InvalidDataLength - return - } - value, err = p.trans.ReadString(int(size)) - if err != nil { - return value, perrors.NewProtocolError(err) - } - return value, nil + return p.r.ReadString() } // ReadBinary ... func (p *BinaryProtocol) ReadBinary() ([]byte, error) { - size, e := p.ReadI32() - if e != nil { - return nil, e - } - if size < 0 { - return nil, perrors.InvalidDataLength - } - return p.trans.ReadBinary(int(size)) + return p.r.ReadBinary() } // Flush ... func (p *BinaryProtocol) Flush(ctx context.Context) (err error) { - err = p.trans.Flush() + err = p.bw.Flush() if err != nil { return perrors.NewProtocolError(err) } @@ -531,32 +321,18 @@ func (p *BinaryProtocol) Skip(fieldType TType) (err error) { // Transport ... func (p *BinaryProtocol) Transport() TTransport { - return ttransportByteBuffer{p.trans} -} - -// ByteBuffer ... -func (p *BinaryProtocol) ByteBuffer() byteBuffer { - return p.trans -} - -// next ... -func (p *BinaryProtocol) next(size int) ([]byte, error) { - buf, err := p.trans.Next(size) - if err != nil { - return buf, perrors.NewProtocolError(err) - } - return buf, nil + return ttransportByteBuffer{} } // ttransportByteBuffer ... // for exposing remote.ByteBuffer via p.Transport(), // mainly for testing purpose, see internal/mocks/athrift/utils.go -type ttransportByteBuffer struct { - byteBuffer -} +type ttransportByteBuffer struct{} func (ttransportByteBuffer) Close() error { panic("not implemented") } func (ttransportByteBuffer) Flush(ctx context.Context) (err error) { panic("not implemented") } func (ttransportByteBuffer) IsOpen() bool { panic("not implemented") } func (ttransportByteBuffer) Open() error { panic("not implemented") } -func (p ttransportByteBuffer) RemainingBytes() uint64 { return uint64(p.ReadableLen()) } +func (p ttransportByteBuffer) RemainingBytes() uint64 { panic("not implemented") } +func (ttransportByteBuffer) Read(p []byte) (n int, err error) { panic("not implemented") } +func (ttransportByteBuffer) Write(p []byte) (n int, err error) { panic("not implemented") } diff --git a/pkg/protocol/bthrift/binary.go b/pkg/protocol/bthrift/binary.go index 4f053d1378..260a42bc13 100644 --- a/pkg/protocol/bthrift/binary.go +++ b/pkg/protocol/bthrift/binary.go @@ -18,40 +18,28 @@ package bthrift import ( - "encoding/binary" - "fmt" - "math" - - "github.com/bytedance/gopkg/lang/span" gthrift "github.com/cloudwego/gopkg/protocol/thrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/utils" ) var ( // Binary protocol for bthrift. - Binary binaryProtocol - _ BTProtocol = binaryProtocol{} - spanCache = span.NewSpanCache(1024 * 1024) - spanCacheEnable bool = false + Binary binaryProtocol + _ BTProtocol = binaryProtocol{} ) -const binaryInplaceThreshold = 4096 // 4k - type binaryProtocol struct{} // SetSpanCache enable/disable binary protocol bytes/string allocator +// Deprecated: use github.com/cloudwego/gopkg/protocol/thrift.SetSpanCache func SetSpanCache(enable bool) { - spanCacheEnable = enable + gthrift.SetSpanCache(enable) } func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID thrift.TMessageType, seqid int32) int { - offset := 0 - version := uint32(thrift.VERSION_1) | uint32(typeID) - offset += Binary.WriteI32(buf, int32(version)) - offset += Binary.WriteString(buf[offset:], name) - offset += Binary.WriteI32(buf[offset:], seqid) - return offset + return gthrift.Binary.WriteMessageBegin(buf, name, gthrift.TMessageType(typeID), seqid) } func (binaryProtocol) WriteMessageEnd(buf []byte) int { @@ -67,7 +55,7 @@ func (binaryProtocol) WriteStructEnd(buf []byte) int { } func (binaryProtocol) WriteFieldBegin(buf []byte, name string, typeID thrift.TType, id int16) int { - return Binary.WriteByte(buf, int8(typeID)) + Binary.WriteI16(buf[1:], id) + return gthrift.Binary.WriteFieldBegin(buf, gthrift.TType(typeID), id) } func (binaryProtocol) WriteFieldEnd(buf []byte) int { @@ -75,13 +63,11 @@ func (binaryProtocol) WriteFieldEnd(buf []byte) int { } func (binaryProtocol) WriteFieldStop(buf []byte) int { - return Binary.WriteByte(buf, thrift.STOP) + return gthrift.Binary.WriteFieldStop(buf) } func (binaryProtocol) WriteMapBegin(buf []byte, keyType, valueType thrift.TType, size int) int { - return Binary.WriteByte(buf, int8(keyType)) + - Binary.WriteByte(buf[1:], int8(valueType)) + - Binary.WriteI32(buf[2:], int32(size)) + return gthrift.Binary.WriteMapBegin(buf, gthrift.TType(keyType), gthrift.TType(valueType), size) } func (binaryProtocol) WriteMapEnd(buf []byte) int { @@ -89,8 +75,7 @@ func (binaryProtocol) WriteMapEnd(buf []byte) int { } func (binaryProtocol) WriteListBegin(buf []byte, elemType thrift.TType, size int) int { - return Binary.WriteByte(buf, int8(elemType)) + - Binary.WriteI32(buf[1:], int32(size)) + return gthrift.Binary.WriteListBegin(buf, gthrift.TType(elemType), size) } func (binaryProtocol) WriteListEnd(buf []byte) int { @@ -98,8 +83,7 @@ func (binaryProtocol) WriteListEnd(buf []byte) int { } func (binaryProtocol) WriteSetBegin(buf []byte, elemType thrift.TType, size int) int { - return Binary.WriteByte(buf, int8(elemType)) + - Binary.WriteI32(buf[1:], int32(size)) + return gthrift.Binary.WriteSetBegin(buf, gthrift.TType(elemType), size) } func (binaryProtocol) WriteSetEnd(buf []byte) int { @@ -107,64 +91,49 @@ func (binaryProtocol) WriteSetEnd(buf []byte) int { } func (binaryProtocol) WriteBool(buf []byte, value bool) int { - if value { - return Binary.WriteByte(buf, 1) - } - return Binary.WriteByte(buf, 0) + return gthrift.Binary.WriteBool(buf, value) } func (binaryProtocol) WriteByte(buf []byte, value int8) int { - buf[0] = byte(value) - return 1 + return gthrift.Binary.WriteByte(buf, int8(value)) } func (binaryProtocol) WriteI16(buf []byte, value int16) int { - binary.BigEndian.PutUint16(buf, uint16(value)) - return 2 + return gthrift.Binary.WriteI16(buf, value) } func (binaryProtocol) WriteI32(buf []byte, value int32) int { - binary.BigEndian.PutUint32(buf, uint32(value)) - return 4 + return gthrift.Binary.WriteI32(buf, value) } func (binaryProtocol) WriteI64(buf []byte, value int64) int { - binary.BigEndian.PutUint64(buf, uint64(value)) - return 8 + return gthrift.Binary.WriteI64(buf, value) } func (binaryProtocol) WriteDouble(buf []byte, value float64) int { - return Binary.WriteI64(buf, int64(math.Float64bits(value))) + return gthrift.Binary.WriteDouble(buf, value) } func (binaryProtocol) WriteString(buf []byte, value string) int { - l := Binary.WriteI32(buf, int32(len(value))) - copy(buf[l:], value) - return l + len(value) + return gthrift.Binary.WriteString(buf, value) } func (binaryProtocol) WriteBinary(buf, value []byte) int { - l := Binary.WriteI32(buf, int32(len(value))) - copy(buf[l:], value) - return l + len(value) + return gthrift.Binary.WriteBinary(buf, value) } func (binaryProtocol) WriteStringNocopy(buf []byte, binaryWriter BinaryWriter, value string) int { - return Binary.WriteBinaryNocopy(buf, binaryWriter, stringToSliceByte(value)) + // can not inline + return gthrift.Binary.WriteBinaryNocopy(buf, binaryWriter, utils.StringToSliceByte(value)) } func (binaryProtocol) WriteBinaryNocopy(buf []byte, binaryWriter BinaryWriter, value []byte) int { - l := Binary.WriteI32(buf, int32(len(value))) - if binaryWriter != nil && len(value) > binaryInplaceThreshold { - binaryWriter.WriteDirect(value, len(buf[l:])) - return l - } - copy(buf[l:], value) - return l + len(value) + // can not inline + return gthrift.Binary.WriteBinaryNocopy(buf, binaryWriter, value) } func (binaryProtocol) MessageBeginLength(name string, _ thrift.TMessageType, _ int32) int { - return 4 + Binary.StringLength(name) + 4 + return gthrift.Binary.MessageBeginLength(name) } func (binaryProtocol) MessageEndLength() int { @@ -180,7 +149,7 @@ func (binaryProtocol) StructEndLength() int { } func (binaryProtocol) FieldBeginLength(name string, typeID thrift.TType, id int16) int { - return Binary.ByteLength(int8(typeID)) + Binary.I16Length(id) + return gthrift.Binary.FieldBeginLength() } func (binaryProtocol) FieldEndLength() int { @@ -188,13 +157,11 @@ func (binaryProtocol) FieldEndLength() int { } func (binaryProtocol) FieldStopLength() int { - return Binary.ByteLength(thrift.STOP) + return gthrift.Binary.FieldStopLength() } func (binaryProtocol) MapBeginLength(keyType, valueType thrift.TType, size int) int { - return Binary.ByteLength(int8(keyType)) + - Binary.ByteLength(int8(valueType)) + - Binary.I32Length(int32(size)) + return gthrift.Binary.MapBeginLength() } func (binaryProtocol) MapEndLength() int { @@ -202,8 +169,7 @@ func (binaryProtocol) MapEndLength() int { } func (binaryProtocol) ListBeginLength(elemType thrift.TType, size int) int { - return Binary.ByteLength(int8(elemType)) + - Binary.I32Length(int32(size)) + return gthrift.Binary.ListBeginLength() } func (binaryProtocol) ListEndLength() int { @@ -211,8 +177,7 @@ func (binaryProtocol) ListEndLength() int { } func (binaryProtocol) SetBeginLength(elemType thrift.TType, size int) int { - return Binary.ByteLength(int8(elemType)) + - Binary.I32Length(int32(size)) + return gthrift.Binary.SetBeginLength() } func (binaryProtocol) SetEndLength() int { @@ -220,85 +185,50 @@ func (binaryProtocol) SetEndLength() int { } func (binaryProtocol) BoolLength(value bool) int { - if value { - return Binary.ByteLength(1) - } - return Binary.ByteLength(0) + return gthrift.Binary.BoolLength() } func (binaryProtocol) ByteLength(value int8) int { - return 1 + return gthrift.Binary.ByteLength() } func (binaryProtocol) I16Length(value int16) int { - return 2 + return gthrift.Binary.I16Length() } func (binaryProtocol) I32Length(value int32) int { - return 4 + return gthrift.Binary.I32Length() } func (binaryProtocol) I64Length(value int64) int { - return 8 + return gthrift.Binary.I64Length() } func (binaryProtocol) DoubleLength(value float64) int { - return Binary.I64Length(int64(math.Float64bits(value))) + return gthrift.Binary.DoubleLength() } func (binaryProtocol) StringLength(value string) int { - return Binary.I32Length(int32(len(value))) + len(value) + return gthrift.Binary.StringLength(value) } func (binaryProtocol) BinaryLength(value []byte) int { - return Binary.I32Length(int32(len(value))) + len(value) + return gthrift.Binary.BinaryLength(value) } func (binaryProtocol) StringLengthNocopy(value string) int { - return Binary.BinaryLengthNocopy(stringToSliceByte(value)) + return gthrift.Binary.StringLengthNocopy(value) } func (binaryProtocol) BinaryLengthNocopy(value []byte) int { - l := Binary.I32Length(int32(len(value))) - return l + len(value) + return gthrift.Binary.BinaryLengthNocopy(value) } -var ( - errBadVersion = gthrift.NewProtocolException(gthrift.BAD_VERSION, "Bad version in ReadMessageBegin") - errMissingVersion = gthrift.NewProtocolException(gthrift.BAD_VERSION, "Missing version in ReadMessageBegin") - - errInvalidDataLength = gthrift.NewProtocolException(gthrift.INVALID_DATA, "Invalid data length") -) - func (binaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID thrift.TMessageType, seqid int32, length int, err error) { - size, l, e := Binary.ReadI32(buf) - length += l - if e != nil { - err = e - return - } - if size > 0 { - err = errMissingVersion - return - } - typeID = thrift.TMessageType(size & 0x0ff) - version := int64(size) & thrift.VERSION_MASK - if version != thrift.VERSION_1 { - err = errBadVersion - return - } - name, l, e = Binary.ReadString(buf[length:]) - length += l - if e != nil { - err = e - return - } - seqid, l, e = Binary.ReadI32(buf[length:]) - length += l - if e != nil { - err = e - return - } + var tid gthrift.TMessageType + // can not inline + name, tid, seqid, length, err = gthrift.Binary.ReadMessageBegin(buf) + typeID = thrift.TMessageType(tid) return } @@ -311,357 +241,87 @@ func (binaryProtocol) ReadStructBegin(_ []byte) (name string, length int, err er func (binaryProtocol) ReadStructEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadFieldBegin(buf []byte) (name string, typeID thrift.TType, id int16, length int, err error) { - t, l, e := Binary.ReadByte(buf) - length += l - typeID = thrift.TType(t) - if e != nil { - err = e - return - } - if t != thrift.STOP { - id, l, err = Binary.ReadI16(buf[length:]) - length += l - } + var tid gthrift.TType + tid, id, length, err = gthrift.Binary.ReadFieldBegin(buf) + typeID = thrift.TType(tid) return } func (binaryProtocol) ReadFieldEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadMapBegin(buf []byte) (keyType, valueType thrift.TType, size, length int, err error) { - k, l, e := Binary.ReadByte(buf) - length += l - if e != nil { - err = e - return - } - keyType = thrift.TType(k) - v, l, e := Binary.ReadByte(buf[length:]) - length += l - if e != nil { - err = e - return - } - valueType = thrift.TType(v) - size32, l, e := Binary.ReadI32(buf[length:]) - length += l - if e != nil { - err = e - return - } - if size32 < 0 { - err = errInvalidDataLength - return - } - size = int(size32) + var ktid, vtid gthrift.TType + ktid, vtid, size, length, err = gthrift.Binary.ReadMapBegin(buf) + keyType = thrift.TType(ktid) + valueType = thrift.TType(vtid) return } func (binaryProtocol) ReadMapEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadListBegin(buf []byte) (elemType thrift.TType, size, length int, err error) { - b, l, e := Binary.ReadByte(buf) - length += l - if e != nil { - err = e - return - } - elemType = thrift.TType(b) - size32, l, e := Binary.ReadI32(buf[length:]) - length += l - if e != nil { - err = e - return - } - if size32 < 0 { - err = errInvalidDataLength - return - } - size = int(size32) - + var tid gthrift.TType + tid, size, length, err = gthrift.Binary.ReadListBegin(buf) + elemType = thrift.TType(tid) return } func (binaryProtocol) ReadListEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadSetBegin(buf []byte) (elemType thrift.TType, size, length int, err error) { - b, l, e := Binary.ReadByte(buf) - length += l - if e != nil { - err = e - return - } - elemType = thrift.TType(b) - size32, l, e := Binary.ReadI32(buf[length:]) - length += l - if e != nil { - err = e - return - } - if size32 < 0 { - err = errInvalidDataLength - return - } - size = int(size32) + var tid gthrift.TType + tid, size, length, err = gthrift.Binary.ReadSetBegin(buf) + elemType = thrift.TType(tid) return } func (binaryProtocol) ReadSetEnd(_ []byte) (int, error) { return 0, nil } func (binaryProtocol) ReadBool(buf []byte) (value bool, length int, err error) { - b, l, e := Binary.ReadByte(buf) - v := true - if b != 1 { - v = false - } - return v, l, e + value, length, err = gthrift.Binary.ReadBool(buf) + return } -var errReadByte = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadByte] len(buf) < 1") - func (binaryProtocol) ReadByte(buf []byte) (value int8, length int, err error) { - if len(buf) < 1 { - return value, length, errReadByte - } - return int8(buf[0]), 1, err + value, length, err = gthrift.Binary.ReadByte(buf) + return } -var errReadI16 = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadI16] len(buf) < 2") - func (binaryProtocol) ReadI16(buf []byte) (value int16, length int, err error) { - if len(buf) < 2 { - return value, length, errReadI16 - } - value = int16(binary.BigEndian.Uint16(buf)) - return value, 2, nil + value, length, err = gthrift.Binary.ReadI16(buf) + return } -var errReadI32 = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadI32] len(buf) < 4") - func (binaryProtocol) ReadI32(buf []byte) (value int32, length int, err error) { - if len(buf) < 4 { - return value, length, errReadI32 - } - value = int32(binary.BigEndian.Uint32(buf)) - return value, 4, nil + value, length, err = gthrift.Binary.ReadI32(buf) + return } -var errReadI64 = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadI64] len(buf) < 8") - func (binaryProtocol) ReadI64(buf []byte) (value int64, length int, err error) { - if len(buf) < 8 { - return value, length, errReadI64 - } - value = int64(binary.BigEndian.Uint64(buf)) - return value, 8, nil + value, length, err = gthrift.Binary.ReadI64(buf) + return } -var errReadDouble = gthrift.NewProtocolException(gthrift.INVALID_DATA, "[ReadDouble] len(buf) < 8") - func (binaryProtocol) ReadDouble(buf []byte) (value float64, length int, err error) { - if len(buf) < 8 { - return value, length, errReadDouble - } - value = math.Float64frombits(binary.BigEndian.Uint64(buf)) - return value, 8, nil + value, length, err = gthrift.Binary.ReadDouble(buf) + return } -var errReadString = gthrift.NewProtocolException( - gthrift.INVALID_DATA, "[ReadString] the string size greater than buf length") - func (binaryProtocol) ReadString(buf []byte) (value string, length int, err error) { - size, l, e := Binary.ReadI32(buf) - length += l - if e != nil { - err = e - return - } - if size < 0 || int(size) > len(buf) { - return value, length, errReadString - } - if spanCacheEnable { - data := spanCache.Copy(buf[length : length+int(size)]) - value = sliceByteToString(data) - } else { - value = string(buf[length : length+int(size)]) - } - length += int(size) + // can not inline + value, length, err = gthrift.Binary.ReadString(buf) return } -var errReadBinary = gthrift.NewProtocolException( - gthrift.INVALID_DATA, "[ReadBinary] the binary size greater than buf length") - func (binaryProtocol) ReadBinary(buf []byte) (value []byte, length int, err error) { - _size, l, e := Binary.ReadI32(buf) - length += l - if e != nil { - err = e - return - } - size := int(_size) - if size < 0 || size > len(buf) { - return value, length, errReadBinary - } - if spanCacheEnable { - value = spanCache.Copy(buf[length : length+size]) - } else { - value = make([]byte, size) - copy(value, buf[length:length+size]) - } - length += size + // can not inline + value, length, err = gthrift.Binary.ReadBinary(buf) return } // Skip . func (binaryProtocol) Skip(buf []byte, fieldType thrift.TType) (length int, err error) { - return SkipDefaultDepth(buf, Binary, fieldType) -} - -// SkipDefaultDepth skips over the next data element from the provided input TProtocol object. -func SkipDefaultDepth(buf []byte, prot BTProtocol, typeID thrift.TType) (int, error) { - const defaultRecursionDepth = 64 // same as thrift.DEFAULT_RECURSION_DEPTH - return Skip(buf, prot, typeID, defaultRecursionDepth) -} - -var errSkipDepthLimit = gthrift.NewProtocolException(gthrift.DEPTH_LIMIT, "depth limit exceeded") - -// Skip skips over the next data element from the provided input TProtocol object. -func Skip(buf []byte, self BTProtocol, fieldType thrift.TType, maxDepth int) (length int, err error) { - if maxDepth <= 0 { - return 0, errSkipDepthLimit - } - - var l int - switch fieldType { - case thrift.BOOL: - length += 1 - return - case thrift.BYTE: - length += 1 - return - case thrift.I16: - length += 2 - return - case thrift.I32: - length += 4 - return - case thrift.I64: - length += 8 - return - case thrift.DOUBLE: - length += 8 - return - case thrift.STRING: - var sl int32 - sl, l, err = self.ReadI32(buf) - length += l + int(sl) - return - case thrift.STRUCT: - _, l, err = self.ReadStructBegin(buf) - length += l - if err != nil { - return - } - for { - _, typeID, _, l, e := self.ReadFieldBegin(buf[length:]) - length += l - if e != nil { - err = e - return - } - if typeID == thrift.STOP { - break - } - l, e = Skip(buf[length:], self, typeID, maxDepth-1) - length += l - if e != nil { - err = e - return - } - l, e = self.ReadFieldEnd(buf[length:]) - length += l - if e != nil { - err = e - return - } - } - l, e := self.ReadStructEnd(buf[length:]) - length += l - if e != nil { - err = e - } - return - case thrift.MAP: - keyType, valueType, size, l, e := self.ReadMapBegin(buf) - length += l - if e != nil { - err = e - return - } - for i := 0; i < size; i++ { - l, e := Skip(buf[length:], self, keyType, maxDepth-1) - length += l - if e != nil { - err = e - return - } - l, e = Skip(buf[length:], self, valueType, maxDepth-1) - length += l - if e != nil { - err = e - return - } - } - l, e = self.ReadMapEnd(buf[length:]) - length += l - if e != nil { - err = e - } - return - case thrift.SET: - elemType, size, l, e := self.ReadSetBegin(buf) - length += l - if e != nil { - err = e - return - } - for i := 0; i < size; i++ { - l, e = Skip(buf[length:], self, elemType, maxDepth-1) - length += l - if e != nil { - err = e - return - } - } - l, e = self.ReadSetEnd(buf[length:]) - length += l - if e != nil { - err = e - } - return - case thrift.LIST: - elemType, size, l, e := self.ReadListBegin(buf) - length += l - if e != nil { - err = e - return - } - for i := 0; i < size; i++ { - l, e = Skip(buf[length:], self, elemType, maxDepth-1) - length += l - if e != nil { - err = e - return - } - } - l, e = self.ReadListEnd(buf[length:]) - length += l - if e != nil { - err = e - } - return - default: - return 0, gthrift.NewProtocolException( - gthrift.INVALID_DATA, fmt.Sprintf("unknown data type %d", fieldType)) - } + // can not inline + length, err = gthrift.Binary.Skip(buf, gthrift.TType(fieldType)) + return } diff --git a/pkg/protocol/bthrift/compat.go b/pkg/protocol/bthrift/compat.go deleted file mode 100644 index 0e9cb58494..0000000000 --- a/pkg/protocol/bthrift/compat.go +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package bthrift - -import ( - "errors" - "io" - - "github.com/cloudwego/gopkg/protocol/thrift" - - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -// ApacheCodecAdapter converts a fastcodec struct to apache codec -type ApacheCodecAdapter struct { - p thrift.FastCodec -} - -// Write implements athrift.TStruct -func (p ApacheCodecAdapter) Write(tp athrift.TProtocol) error { - b := make([]byte, p.p.BLength()) - b = b[:p.p.FastWriteNocopy(b, nil)] - _, err := tp.Transport().Write(b) - return err -} - -// Read implements athrift.TStruct -func (p ApacheCodecAdapter) Read(tp athrift.TProtocol) error { - var err error - var b []byte - trans := tp.Transport() - n := trans.RemainingBytes() - if int64(n) < 0 { - return errors.New("unknown buffer len") - } - b = make([]byte, n) - _, err = io.ReadFull(trans, b) - if err == nil { - _, err = p.p.FastRead(b) - } - return err -} - -// ToApacheCodec converts a thrift.FastCodec to athrift.TStruct -func ToApacheCodec(p thrift.FastCodec) athrift.TStruct { - return &ApacheCodecAdapter{p} -} - -// UnpackApacheCodec unpacks the value returned by `ToApacheCodec` -func UnpackApacheCodec(v interface{}) interface{} { - a, ok := v.(*ApacheCodecAdapter) - if ok { - return a.p - } - return v -} diff --git a/pkg/protocol/bthrift/utils.go b/pkg/protocol/bthrift/utils.go deleted file mode 100644 index 2fd5b8f527..0000000000 --- a/pkg/protocol/bthrift/utils.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package bthrift - -import ( - "reflect" - "unsafe" -) - -// from utils.SliceByteToString for fixing cyclic import -func sliceByteToString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) -} - -// from utils.StringToSliceByte for fixing cyclic import -func stringToSliceByte(s string) []byte { - p := unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s)).Data) - var b []byte - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - hdr.Data = uintptr(p) - hdr.Cap = len(s) - hdr.Len = len(s) - return b -} diff --git a/pkg/remote/bufiox2buffer.go b/pkg/remote/bufiox2buffer.go new file mode 100644 index 0000000000..7f2e06db40 --- /dev/null +++ b/pkg/remote/bufiox2buffer.go @@ -0,0 +1,61 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package remote + +import ( + "io" + + "github.com/cloudwego/gopkg/bufiox" +) + +type bufioxBuffer struct { + io.ReadWriter + bufiox.Writer + bufiox.Reader +} + +// NewByteBufferFromBufiox is for compatibility with bufiox and ByteBuffer interfaces. +func NewByteBufferFromBufiox(bw bufiox.Writer, br bufiox.Reader) ByteBuffer { + return &bufioxBuffer{ + Writer: bw, + Reader: br, + } +} + +func (b *bufioxBuffer) ReadableLen() (n int) { + panic("not implement") +} + +func (b *bufioxBuffer) ReadString(n int) (s string, err error) { + panic("not implement") +} + +func (b *bufioxBuffer) WriteString(s string) (n int, err error) { + panic("not implement") +} + +func (b *bufioxBuffer) NewBuffer() ByteBuffer { + panic("not implement") +} + +func (b *bufioxBuffer) AppendBuffer(buf ByteBuffer) (err error) { + panic("not implement") +} + +func (b *bufioxBuffer) Bytes() (buf []byte, err error) { + panic("not implement") +} diff --git a/pkg/remote/bytebuf.go b/pkg/remote/bytebuf.go index bbd5dbbd29..2d9fea6437 100644 --- a/pkg/remote/bytebuf.go +++ b/pkg/remote/bytebuf.go @@ -75,13 +75,13 @@ type ByteBuffer interface { // ReadBinary like ReadString. // Returns a copy of original buffer. - ReadBinary(n int) (p []byte, err error) + ReadBinary(p []byte) (n int, err error) // Malloc n bytes sequentially in the writer buffer. Malloc(n int) (buf []byte, err error) - // MallocLen returns the total length of the buffer malloced. - MallocLen() (length int) + // WrittenLen returns the total length of the buffer writtenLen. + WrittenLen() (length int) // WriteString is a more efficient way to write string, using the unsafe method to convert the string to []byte. WriteString(s string) (n int, err error) diff --git a/pkg/remote/codec/default_codec.go b/pkg/remote/codec/default_codec.go index 2dde97837e..9e837ffa3d 100644 --- a/pkg/remote/codec/default_codec.go +++ b/pkg/remote/codec/default_codec.go @@ -106,16 +106,16 @@ type defaultCodec struct { // EncodePayload encode payload func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message, out remote.ByteBuffer) error { defer func() { - // notice: mallocLen() must exec before flush, or it will be reset + // notice: WrittenLen() must exec before flush, or it will be reset if ri := message.RPCInfo(); ri != nil { if ms := rpcinfo.AsMutableRPCStats(ri.Stats()); ms != nil { - ms.SetSendSize(uint64(out.MallocLen())) + ms.SetSendSize(uint64(out.WrittenLen())) } } }() var err error var framedLenField []byte - headerLen := out.MallocLen() + headerLen := out.WrittenLen() tp := message.ProtocolInfo().TransProto // 1. malloc framed field if needed @@ -137,14 +137,14 @@ func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message if framedLenField == nil { return perrors.NewProtocolErrorWithMsg("no buffer allocated for the framed length field") } - payloadLen = out.MallocLen() - headerLen + payloadLen = out.WrittenLen() - headerLen // FIXME: if the `out` buffer using copy to grow when the capacity is not enough, setting the pre-allocated `framedLenField` may not take effect. binary.BigEndian.PutUint32(framedLenField, uint32(payloadLen)) } else if message.ProtocolInfo().CodecType == serviceinfo.Protobuf { return perrors.NewProtocolErrorWithMsg("protobuf just support 'framed' trans proto") } if tp&transport.TTHeader == transport.TTHeader { - payloadLen = out.MallocLen() - Size32 + payloadLen = out.WrittenLen() - Size32 } err = checkPayloadSize(payloadLen, c.MaxSize) return err @@ -176,7 +176,7 @@ func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote. return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field") } // FIXME: if the `out` buffer using copy to grow when the capacity is not enough, setting the pre-allocated `totalLenField` may not take effect. - payloadLen := out.MallocLen() - Size32 + payloadLen := out.WrittenLen() - Size32 binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen)) } return nil @@ -310,7 +310,7 @@ func (c *defaultCodec) encodeMetaAndPayloadWithPayloadValidator(ctx context.Cont if totalLenField == nil { return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field") } - payloadLen := out.MallocLen() - Size32 + payloadLen := out.WrittenLen() - Size32 binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen)) return err } diff --git a/pkg/remote/codec/default_codec_test.go b/pkg/remote/codec/default_codec_test.go index b08e2bd000..dbe414bef4 100644 --- a/pkg/remote/codec/default_codec_test.go +++ b/pkg/remote/codec/default_codec_test.go @@ -521,7 +521,7 @@ func TestCornerCase(t *testing.T) { sendMsg.SetProtocolInfo(remote.NewProtocolInfo(transport.Framed, serviceinfo.Thrift)) buffer := mocksremote.NewMockByteBuffer(ctrl) - buffer.EXPECT().MallocLen().Return(1024).AnyTimes() + buffer.EXPECT().WrittenLen().Return(1024).AnyTimes() buffer.EXPECT().Malloc(gomock.Any()).Return(nil, errors.New("error malloc")).AnyTimes() err := (&defaultCodec{}).EncodePayload(context.Background(), sendMsg, buffer) test.Assert(t, err.Error() == "error malloc") diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index fb4e8ad051..a50680dbaa 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -48,7 +48,7 @@ func TestTTHeaderCodec(t *testing.T) { // encode buf := tb.NewBuffer() totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, buf) - binary.BigEndian.PutUint32(totalLenField, uint32(buf.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(buf.WrittenLen()-Size32+mockPayloadLen)) test.Assert(t, err == nil, err) buf.Flush() @@ -76,7 +76,7 @@ func TestTTHeaderCodecWithTransInfo(t *testing.T) { // encode buf := tb.NewBuffer() totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, buf) - binary.BigEndian.PutUint32(totalLenField, uint32(buf.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(buf.WrittenLen()-Size32+mockPayloadLen)) test.Assert(t, err == nil, err) buf.Flush() @@ -111,7 +111,7 @@ func TestTTHeaderCodecWithTransInfoWithGDPRToken(t *testing.T) { // encode buf := tb.NewBuffer() totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, buf) - binary.BigEndian.PutUint32(totalLenField, uint32(buf.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(buf.WrittenLen()-Size32+mockPayloadLen)) test.Assert(t, err == nil, err) buf.Flush() @@ -147,7 +147,7 @@ func TestTTHeaderCodecWithTransInfoFromMetaInfoGDPRToken(t *testing.T) { // encode buf := tb.NewBuffer() totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, buf) - binary.BigEndian.PutUint32(totalLenField, uint32(buf.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(buf.WrittenLen()-Size32+mockPayloadLen)) test.Assert(t, err == nil, err) buf.Flush() @@ -179,7 +179,7 @@ func TestFillBasicInfoOfTTHeader(t *testing.T) { sendMsg.TransInfo().TransIntInfo()[transmeta.FromService] = mockServiceName buf := tb.NewBuffer() totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, buf) - binary.BigEndian.PutUint32(totalLenField, uint32(buf.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(buf.WrittenLen()-Size32+mockPayloadLen)) test.Assert(t, err == nil, err) buf.Flush() // decode @@ -196,7 +196,7 @@ func TestFillBasicInfoOfTTHeader(t *testing.T) { sendMsg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr] = mockAddr buf = tb.NewBuffer() totalLenField, err = ttHeaderCodec.encode(ctx, sendMsg, buf) - binary.BigEndian.PutUint32(totalLenField, uint32(buf.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(buf.WrittenLen()-Size32+mockPayloadLen)) test.Assert(t, err == nil, err) buf.Flush() // decode @@ -219,7 +219,7 @@ func BenchmarkTTHeaderCodec(b *testing.B) { // encode out := remote.NewWriterBuffer(256) totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, out) - binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(out.WrittenLen()-Size32+mockPayloadLen)) test.Assert(b, err == nil, err) // decode @@ -249,7 +249,7 @@ func BenchmarkTTHeaderWithTransInfoParallel(b *testing.B) { // encode out := remote.NewWriterBuffer(256) totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, out) - binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(out.WrittenLen()-Size32+mockPayloadLen)) test.Assert(b, err == nil, err) // decode @@ -282,7 +282,7 @@ func BenchmarkTTHeaderCodecParallel(b *testing.B) { // encode out := remote.NewWriterBuffer(256) totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, out) - binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen)) + binary.BigEndian.PutUint32(totalLenField, uint32(out.WrittenLen()-Size32+mockPayloadLen)) test.Assert(b, err == nil, err) // decode @@ -450,7 +450,7 @@ func (t ttHeader) encode2(ctx context.Context, message remote.Message, payloadBu padding := (4 - headerInfoSize%4) % 4 headerInfoSize += padding binary.BigEndian.PutUint16(headerSizeField, uint16(headerInfoSize/4)) - totalLen := TTHeaderMetaSize - Size32 + headerInfoSize + payloadBuf.MallocLen() + totalLen := TTHeaderMetaSize - Size32 + headerInfoSize + payloadBuf.WrittenLen() binary.BigEndian.PutUint32(totalLenField, uint32(totalLen)) // 3. header info, malloc and write diff --git a/pkg/remote/codec/thrift/deprecated.go b/pkg/remote/codec/thrift/deprecated.go index b6f0029d9f..2bf1cce3d9 100644 --- a/pkg/remote/codec/thrift/deprecated.go +++ b/pkg/remote/codec/thrift/deprecated.go @@ -37,7 +37,29 @@ type MessageWriter interface { // TODO: this func should be removed in the future. it's exposed accidentally. // Deprecated: Use `SkipDecoder` + `ApplicationException` of `cloudwego/gopkg/protocol/thrift` instead. func UnmarshalThriftException(tProt athrift.TProtocol) error { - return unmarshalThriftException(tProt.Transport()) + var m string + var t int32 + for { + _, tp, id, err := tProt.ReadFieldBegin() + if err != nil { + return err + } + if tp == athrift.STOP { + break + } + switch { + case id == 1 && tp == athrift.STRING: // Msg + m, err = tProt.ReadString() + case id == 2 && tp == athrift.I32: // TypeID + t, err = tProt.ReadI32() + default: + err = tProt.Skip(tp) + } + if err != nil { + return err + } + } + return remote.NewTransErrorWithMsg(t, m) } // BinaryProtocol ... @@ -47,5 +69,5 @@ type BinaryProtocol = athrift.BinaryProtocol // NewBinaryProtocol ... // Deprecated: use github.com/apache/thrift/lib/go/thrift.NewTBinaryProtocol func NewBinaryProtocol(t remote.ByteBuffer) *athrift.BinaryProtocol { - return athrift.NewBinaryProtocol(t) + return athrift.NewBinaryProtocol(t, t) } diff --git a/pkg/remote/codec/thrift/deprecated_test.go b/pkg/remote/codec/thrift/deprecated_test.go index 3b57615903..795cdf7bb9 100644 --- a/pkg/remote/codec/thrift/deprecated_test.go +++ b/pkg/remote/codec/thrift/deprecated_test.go @@ -18,17 +18,19 @@ package thrift import ( "context" - "encoding/binary" "testing" + "github.com/cloudwego/gopkg/bufiox" + + "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" - "github.com/cloudwego/kitex/pkg/remote" ) func TestMessage(t *testing.T) { - trans := remote.NewReaderWriterBuffer(-1) - prot := NewBinaryProtocol(trans) + conn := mocks.NewIOConn() + bw, br := bufiox.NewDefaultWriter(conn), bufiox.NewDefaultReader(conn) + prot := thrift.NewBinaryProtocol(br, bw) // check write name := "name" @@ -37,12 +39,6 @@ func TestMessage(t *testing.T) { err := prot.WriteMessageBegin(name, typeID, seqID) test.Assert(t, err == nil, err) - tmp, _ := trans.Bytes() - test.Assert(t, binary.BigEndian.Uint32(tmp[:4]) == 0x80010001) - test.Assert(t, binary.BigEndian.Uint32(tmp[4:8]) == uint32(len(name))) - test.Assert(t, string(tmp[8:8+len(name)]) == name) - test.Assert(t, binary.BigEndian.Uint32(tmp[8+len(name):12+len(name)]) == uint32(seqID)) - err = prot.WriteMessageEnd() test.Assert(t, err == nil, err) @@ -60,8 +56,9 @@ func TestMessage(t *testing.T) { } func TestStruct(t *testing.T) { - trans := remote.NewReaderWriterBuffer(-1) - prot := NewBinaryProtocol(trans) + conn := mocks.NewIOConn() + bw, br := bufiox.NewDefaultWriter(conn), bufiox.NewDefaultReader(conn) + prot := thrift.NewBinaryProtocol(br, bw) name := "struct" err := prot.WriteStructBegin(name) @@ -77,8 +74,9 @@ func TestStruct(t *testing.T) { } func TestField(t *testing.T) { - trans := remote.NewReaderWriterBuffer(-1) - prot := NewBinaryProtocol(trans) + conn := mocks.NewIOConn() + bw, br := bufiox.NewDefaultWriter(conn), bufiox.NewDefaultReader(conn) + prot := thrift.NewBinaryProtocol(br, bw) name := "name" var fieldID int16 = 1 @@ -88,11 +86,6 @@ func TestField(t *testing.T) { test.Assert(t, err == nil, err) err = prot.WriteFieldStop() test.Assert(t, err == nil, err) - tmp, _ := trans.Bytes() - test.Assert(t, len(tmp) == 4) - test.Assert(t, tmp[0] == thrift.STRUCT) - test.Assert(t, binary.BigEndian.Uint16(tmp[1:]) == uint16(fieldID)) - test.Assert(t, tmp[3] == thrift.STOP) err = prot.Flush(context.Background()) test.Assert(t, err == nil, err) @@ -112,19 +105,15 @@ func TestField(t *testing.T) { } func TestMap(t *testing.T) { - trans := remote.NewReaderWriterBuffer(-1) - prot := NewBinaryProtocol(trans) + conn := mocks.NewIOConn() + bw, br := bufiox.NewDefaultWriter(conn), bufiox.NewDefaultReader(conn) + prot := thrift.NewBinaryProtocol(br, bw) size := 10 err := prot.WriteMapBegin(thrift.I32, thrift.BOOL, size) test.Assert(t, err == nil, err) err = prot.WriteMapEnd() test.Assert(t, err == nil, err) - tmp, _ := trans.Bytes() - test.Assert(t, len(tmp) == 6) - test.Assert(t, tmp[0] == thrift.I32) - test.Assert(t, tmp[1] == thrift.BOOL) - test.Assert(t, binary.BigEndian.Uint32(tmp[2:]) == uint32(size)) err = prot.Flush(context.Background()) test.Assert(t, err == nil, err) @@ -139,18 +128,15 @@ func TestMap(t *testing.T) { } func TestList(t *testing.T) { - trans := remote.NewReaderWriterBuffer(-1) - prot := NewBinaryProtocol(trans) + conn := mocks.NewIOConn() + bw, br := bufiox.NewDefaultWriter(conn), bufiox.NewDefaultReader(conn) + prot := thrift.NewBinaryProtocol(br, bw) size := 10 err := prot.WriteListBegin(thrift.I64, size) test.Assert(t, err == nil, err) err = prot.WriteListEnd() test.Assert(t, err == nil, err) - tmp, _ := trans.Bytes() - test.Assert(t, len(tmp) == 5) - test.Assert(t, tmp[0] == thrift.I64) - test.Assert(t, binary.BigEndian.Uint32(tmp[1:]) == uint32(size)) err = prot.Flush(context.Background()) test.Assert(t, err == nil, err) @@ -164,18 +150,15 @@ func TestList(t *testing.T) { } func TestSet(t *testing.T) { - trans := remote.NewReaderWriterBuffer(-1) - prot := NewBinaryProtocol(trans) + conn := mocks.NewIOConn() + bw, br := bufiox.NewDefaultWriter(conn), bufiox.NewDefaultReader(conn) + prot := thrift.NewBinaryProtocol(br, bw) size := 10 err := prot.WriteSetBegin(thrift.STRING, size) test.Assert(t, err == nil, err) err = prot.WriteSetEnd() test.Assert(t, err == nil, err) - tmp, _ := trans.Bytes() - test.Assert(t, len(tmp) == 5) - test.Assert(t, tmp[0] == thrift.STRING) - test.Assert(t, binary.BigEndian.Uint32(tmp[1:]) == uint32(size)) err = prot.Flush(context.Background()) test.Assert(t, err == nil, err) @@ -189,8 +172,9 @@ func TestSet(t *testing.T) { } func TestConst(t *testing.T) { - trans := remote.NewReaderWriterBuffer(-1) - prot := NewBinaryProtocol(trans) + conn := mocks.NewIOConn() + bw, br := bufiox.NewDefaultWriter(conn), bufiox.NewDefaultReader(conn) + prot := thrift.NewBinaryProtocol(br, bw) n := 0 err := prot.WriteBool(false) @@ -217,19 +201,11 @@ func TestConst(t *testing.T) { err = prot.WriteBinary([]byte{7}) n += 4 + 1 test.Assert(t, err == nil, err) + test.Assert(t, bw.WrittenLen() == n) + err = prot.Flush(context.Background()) test.Assert(t, err == nil, err) - tmp, _ := trans.Bytes() - test.Assert(t, len(tmp) == n, len(tmp)) - test.Assert(t, tmp[0] == 0x0) - test.Assert(t, tmp[1] == 0x1) - test.Assert(t, tmp[3] == 0x2) - test.Assert(t, tmp[7] == 0x3) - test.Assert(t, tmp[15] == 0x4) - test.Assert(t, string(tmp[28:29]) == "6") - test.Assert(t, tmp[33] == 0x7) - err = prot.Flush(context.Background()) test.Assert(t, err == nil, err) diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index ce175342f9..c2bb587d1d 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -20,8 +20,8 @@ import ( "context" "errors" "fmt" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/apache" @@ -167,7 +167,7 @@ func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out re } // encodeFastThrift encode with the FastCodec way -func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, msg thrift.FastCodec) error { +func encodeFastThrift(out bufiox.Writer, methodName string, msgType remote.MessageType, seqID int32, msg thrift.FastCodec) error { nw, _ := out.(remote.NocopyWrite) // nocopy write is a special implementation of linked buffer, only bytebuffer implement NocopyWrite do FastWrite msgBeginLen := thrift.Binary.MessageBeginLength(methodName) @@ -177,7 +177,7 @@ func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.M } // If fast write enabled, the underlying buffer maybe large than the correct buffer, // so we need to save the mallocLen before fast write and correct the real mallocLen after codec - mallocLen := out.MallocLen() + mallocLen := out.WrittenLen() offset := thrift.Binary.WriteMessageBegin(buf, methodName, thrift.TMessageType(msgType), seqID) _ = msg.FastWriteNocopy(buf[offset:], nw) if nw == nil { @@ -187,12 +187,12 @@ func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.M return nw.MallocAck(mallocLen) } -func encodeGenericThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, msg genericWriter) error { - binaryWriter := thrift.NewBinaryWriter() - binaryWriter.WriteMessageBegin(method, thrift.TMessageType(msgType), seqID) - if _, err := out.Write(binaryWriter.Bytes()); err != nil { +func encodeGenericThrift(out bufiox.Writer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, msg genericWriter) error { + binaryWriter := thrift.NewBufferWriter(out) + if err := binaryWriter.WriteMessageBegin(method, thrift.TMessageType(msgType), seqID); err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) } + binaryWriter.Recycle() if err := msg.Write(ctx, method, out); err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error())) } @@ -200,7 +200,7 @@ func encodeGenericThrift(out remote.ByteBuffer, ctx context.Context, method stri } // encodeBasicThrift encode with the old apache thrift way (slow) -func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}) error { +func encodeBasicThrift(out bufiox.Writer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}) error { if err := verifyMarshalBasicThriftDataType(data); err != nil { return err } @@ -221,8 +221,8 @@ func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { // TODO(xiaost): Refactor the code after v0.11.0 is released. Unifying checking and fallback logic. - br := thrift.NewBinaryReader(in) - defer br.Release() + br := thrift.NewBufferReader(in) + defer br.Recycle() methodName, msgType, seqID, err := br.ReadMessageBegin() if err != nil { @@ -291,11 +291,11 @@ func (c thriftCodec) Name() string { } type genericWriter interface { // used by pkg/generic - Write(ctx context.Context, method string, w io.Writer) error + Write(ctx context.Context, method string, w bufiox.Writer) error } type genericReader interface { // used by pkg/generic - Read(ctx context.Context, method string, dataLen int, r io.Reader) error + Read(ctx context.Context, method string, dataLen int, r bufiox.Reader) error } // ThriftMsgFastCodec ... diff --git a/pkg/remote/codec/thrift/thrift_data.go b/pkg/remote/codec/thrift/thrift_data.go index 4ff0ca9d9c..b9cdb9d112 100644 --- a/pkg/remote/codec/thrift/thrift_data.go +++ b/pkg/remote/codec/thrift/thrift_data.go @@ -17,11 +17,10 @@ package thrift import ( - "bytes" "context" "fmt" - "io" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/apache" @@ -80,11 +79,13 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([ } // fallback to old thrift way (slow) - buf := bytes.NewBuffer(make([]byte, 0, marshalThriftBufferSize)) - if err := apache.ThriftWrite(buf, data); err != nil { + buf := make([]byte, 0, marshalThriftBufferSize) + bw := bufiox.NewBytesWriter(&buf) + if err := apache.ThriftWrite(bw, data); err != nil { return nil, err } - return buf.Bytes(), nil + _ = bw.Flush() + return buf, nil } // verifyMarshalBasicThriftDataType verifies whether data could be marshaled by old thrift way @@ -95,7 +96,7 @@ func verifyMarshalBasicThriftDataType(data interface{}) error { return nil } -func unmarshalThriftException(in io.Reader) error { +func unmarshalThriftException(in bufiox.Reader) error { d := thrift.NewSkipDecoder(in) defer d.Release() b, err := d.Next(thrift.STRUCT) @@ -117,7 +118,7 @@ func UnmarshalThriftData(ctx context.Context, codec remote.PayloadCodec, method if !ok { c = defaultCodec } - trans := remote.NewReaderBuffer(buf) + trans := bufiox.NewBytesReader(buf) defer trans.Release(nil) return c.unmarshalThriftData(trans, data, len(buf)) } @@ -130,7 +131,7 @@ func (c thriftCodec) fastMessageUnmarshalAvailable(data interface{}, payloadLen return ok } -func (c thriftCodec) fastUnmarshal(trans remote.ByteBuffer, data interface{}, dataLen int) error { +func (c thriftCodec) fastUnmarshal(trans bufiox.Reader, data interface{}, dataLen int) error { msg := data.(thrift.FastCodec) if dataLen > 0 { buf, err := trans.Next(dataLen) @@ -156,7 +157,7 @@ func (c thriftCodec) fastUnmarshal(trans remote.ByteBuffer, data interface{}, da // unmarshalThriftData only decodes the data (after methodName, msgType and seqId) // method is only used for generic calls -func (c thriftCodec) unmarshalThriftData(trans remote.ByteBuffer, data interface{}, dataLen int) error { +func (c thriftCodec) unmarshalThriftData(trans bufiox.Reader, data interface{}, dataLen int) error { // decode with hyper unmarshal if c.IsSet(FrugalRead) && c.hyperMessageUnmarshalAvailable(data, dataLen) { return c.hyperUnmarshal(trans, data, dataLen) @@ -185,7 +186,7 @@ func (c thriftCodec) unmarshalThriftData(trans remote.ByteBuffer, data interface return decodeBasicThriftData(trans, data) } -func (c thriftCodec) hyperUnmarshal(trans remote.ByteBuffer, data interface{}, dataLen int) error { +func (c thriftCodec) hyperUnmarshal(trans bufiox.Reader, data interface{}, dataLen int) error { if dataLen > 0 { buf, err := trans.Next(dataLen) if err != nil { @@ -216,7 +217,7 @@ func verifyUnmarshalBasicThriftDataType(data interface{}) error { } // decodeBasicThriftData decode thrift body the old way (slow) -func decodeBasicThriftData(trans remote.ByteBuffer, data interface{}) error { +func decodeBasicThriftData(trans bufiox.Reader, data interface{}) error { var err error if err = verifyUnmarshalBasicThriftDataType(data); err != nil { return err @@ -227,7 +228,7 @@ func decodeBasicThriftData(trans remote.ByteBuffer, data interface{}) error { return nil } -func getSkippedStructBuffer(trans remote.ByteBuffer) ([]byte, error) { +func getSkippedStructBuffer(trans bufiox.Reader) ([]byte, error) { sd := thrift.NewSkipDecoder(trans) buf, err := sd.Next(thrift.STRUCT) if err != nil { diff --git a/pkg/remote/codec/thrift/thrift_data_test.go b/pkg/remote/codec/thrift/thrift_data_test.go index ca275936b8..323637c777 100644 --- a/pkg/remote/codec/thrift/thrift_data_test.go +++ b/pkg/remote/codec/thrift/thrift_data_test.go @@ -22,11 +22,13 @@ import ( "strings" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" + athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" "github.com/cloudwego/kitex/pkg/remote" ) @@ -54,32 +56,11 @@ func TestMarshalThriftData(t *testing.T) { test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf) }) t.Run("BasicCodec", func(t *testing.T) { - buf, err := MarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), bthrift.ToApacheCodec(mockReq)) - test.Assert(t, err == nil, err) - test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf) - }) - // FrugalCodec: in thrift_frugal_amd64_test.go: TestMarshalThriftDataFrugal -} - -func Test_decodeBasicThriftData(t *testing.T) { - t.Run("empty-input", func(t *testing.T) { - req := &mocks.MockReq{} - trans := remote.NewReaderBuffer([]byte{}) - err := decodeBasicThriftData(trans, bthrift.ToApacheCodec(req)) + _, err := MarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), mockReq) test.Assert(t, err != nil, err) + // test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf) }) - t.Run("invalid-input", func(t *testing.T) { - req := &mocks.MockReq{} - trans := remote.NewReaderBuffer([]byte{0xff}) - err := decodeBasicThriftData(trans, bthrift.ToApacheCodec(req)) - test.Assert(t, err != nil, err) - }) - t.Run("normal-input", func(t *testing.T) { - req := &mocks.MockReq{} - trans := remote.NewReaderBuffer(mockReqThrift) - err := decodeBasicThriftData(trans, bthrift.ToApacheCodec(req)) - checkDecodeResult(t, err, req) - }) + // FrugalCodec: in thrift_frugal_amd64_test.go: TestMarshalThriftDataFrugal } func checkDecodeResult(t *testing.T, err error, req *mocks.MockReq) { @@ -103,8 +84,8 @@ func TestUnmarshalThriftData(t *testing.T) { }) t.Run("BasicCodec", func(t *testing.T) { req := &mocks.MockReq{} - err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), "mock", mockReqThrift, bthrift.ToApacheCodec(req)) - checkDecodeResult(t, err, req) + err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), "mock", mockReqThrift, req) + test.Assert(t, err != nil, err) }) // FrugalCodec: in thrift_frugal_amd64_test.go: TestUnmarshalThriftDataFrugal } @@ -113,7 +94,7 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { t.Run("FastCodec with SkipDecoder enabled", func(t *testing.T) { req := &mocks.MockReq{} codec := &thriftCodec{FastRead | EnableSkipDecoder} - trans := remote.NewReaderBuffer(mockReqThrift) + trans := bufiox.NewBytesReader(mockReqThrift) // specify dataLen with 0 so that skipDecoder works err := codec.unmarshalThriftData(trans, req, 0) checkDecodeResult(t, err, &mocks.MockReq{ @@ -136,7 +117,7 @@ func TestThriftCodec_unmarshalThriftData(t *testing.T) { 15 /* list */, 0, 3 /* id=3 */, 6 /* item:I16 */, 0, 0, 0, 1 /* length=1 */, 0, 1, /* I16=1 */ 0, /* end of struct */ } - trans := remote.NewReaderBuffer(faultMockReqThrift) + trans := bufiox.NewBytesReader(faultMockReqThrift) // specify dataLen with 0 so that skipDecoder works err := codec.unmarshalThriftData(trans, req, 0) test.Assert(t, err != nil, err) @@ -153,7 +134,7 @@ func TestUnmarshalThriftException(t *testing.T) { test.Assert(t, n == len(b), n) // unmarshal - tProtRead := NewBinaryProtocol(remote.NewReaderBuffer(b)) + tProtRead := athrift.NewBinaryProtocol(bufiox.NewBytesReader(b), nil) err := UnmarshalThriftException(tProtRead) transErr, ok := err.(*remote.TransError) test.Assert(t, ok, err) @@ -166,7 +147,7 @@ func Test_getSkippedStructBuffer(t *testing.T) { faultThrift := []byte{ 11 /* string */, 0, 1 /* id=1 */, 0, 0, 0, 6 /* length=6 */, 104, 101, 108, 108, 111, /* "hello" */ } - trans := remote.NewReaderBuffer(faultThrift) + trans := bufiox.NewBytesReader(faultThrift) _, err := getSkippedStructBuffer(trans) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), "caught in SkipDecoder Next phase")) diff --git a/pkg/remote/codec/thrift/thrift_frugal.go b/pkg/remote/codec/thrift/thrift_frugal.go index 173095a28f..fc77bc34d2 100644 --- a/pkg/remote/codec/thrift/thrift_frugal.go +++ b/pkg/remote/codec/thrift/thrift_frugal.go @@ -21,6 +21,7 @@ import ( "reflect" "github.com/cloudwego/frugal" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/kitex/internal/utils/safemcache" @@ -51,7 +52,7 @@ func (c thriftCodec) hyperMessageUnmarshalAvailable(data interface{}, payloadLen return true } -func (c thriftCodec) hyperMarshal(out remote.ByteBuffer, methodName string, msgType remote.MessageType, +func (c thriftCodec) hyperMarshal(out bufiox.Writer, methodName string, msgType remote.MessageType, seqID int32, data interface{}, ) error { // calculate and malloc message buffer @@ -61,7 +62,7 @@ func (c thriftCodec) hyperMarshal(out remote.ByteBuffer, methodName string, msgT if err != nil { return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error())) } - mallocLen := out.MallocLen() + mallocLen := out.WrittenLen() // encode message offset := thrift.Binary.WriteMessageBegin(buf, methodName, thrift.TMessageType(msgType), seqID) diff --git a/pkg/remote/codec/thrift/thrift_frugal_test.go b/pkg/remote/codec/thrift/thrift_frugal_test.go index ba2f83f999..de739eebaf 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_test.go @@ -22,6 +22,8 @@ import ( "strings" "testing" + "github.com/cloudwego/gopkg/bufiox" + mocks "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" @@ -98,10 +100,11 @@ func TestFrugalCodec(t *testing.T) { // MockNoTagArgs cannot be marshaled sendMsg := initNoTagSendMsg(transport.TTHeader) - out := tb.NewBuffer() - err := codec.Marshal(ctx, sendMsg, out) + bw, _ := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, nil) + err := codec.Marshal(ctx, sendMsg, bb) test.Assert(t, err != nil) - out.Flush() + bw.Flush() }) t.Run("configure frugal and data has tag", func(t *testing.T) { ctx := context.Background() @@ -121,9 +124,10 @@ func TestFrugalCodec(t *testing.T) { // MockNoTagArgs cannot be marshaled sendMsg := initNoTagSendMsg(transport.TTHeader) - out := tb.NewBuffer() - err := codec.Marshal(ctx, sendMsg, out) - out.Flush() + bw, _ := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, nil) + err := codec.Marshal(ctx, sendMsg, bb) + bw.Flush() test.Assert(t, err != nil) }) t.Run("configure frugal and SkipDecoder for Buffer Protocol", func(t *testing.T) { @@ -141,18 +145,20 @@ func testFrugalDataConversion(t *testing.T, ctx context.Context, codec remote.Pa t.Run(tb.Name, func(t *testing.T) { // encode client side sendMsg := initFrugalTagSendMsg(protocol) - buf := tb.NewBuffer() - err := codec.Marshal(ctx, sendMsg, buf) + bw, br := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, br) + err := codec.Marshal(ctx, sendMsg, bb) test.Assert(t, err == nil, err) - buf.Flush() + wl := bw.WrittenLen() + bw.Flush() // decode server side recvMsg := initFrugalTagRecvMsg() if protocol != transport.PurePayload { - recvMsg.SetPayloadLen(buf.ReadableLen()) + recvMsg.SetPayloadLen(wl) } test.Assert(t, err == nil, err) - err = codec.Unmarshal(ctx, recvMsg, buf) + err = codec.Unmarshal(ctx, recvMsg, bb) test.Assert(t, err == nil, err) // compare Args @@ -221,7 +227,7 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { t.Run("Frugal with SkipDecoder enabled", func(t *testing.T) { req := &MockFrugalTagReq{} codec := &thriftCodec{FrugalRead | EnableSkipDecoder} - trans := remote.NewReaderBuffer(mockReqThrift) + trans := bufiox.NewBytesReader(mockReqThrift) // specify dataLen with 0 so that skipDecoder works err := codec.unmarshalThriftData(trans, req, 0) checkDecodeResult(t, err, &mocks.MockReq{ @@ -244,7 +250,7 @@ func TestThriftCodec_unmarshalThriftDataFrugal(t *testing.T) { 15 /* list */, 0, 3 /* id=3 */, 6 /* item:I16 */, 0, 0, 0, 1 /* length=1 */, 0, 1, /* I16=1 */ 0, /* end of struct */ } - trans := remote.NewReaderBuffer(faultMockReqThrift) + trans := bufiox.NewBytesReader(faultMockReqThrift) // specify dataLen with 0 so that skipDecoder works err := codec.unmarshalThriftData(trans, req, 0) test.Assert(t, err != nil, err) diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index 12308545fa..35ee4dabec 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -19,16 +19,15 @@ package thrift import ( "context" "errors" - "io" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/internal/mocks" mt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" "github.com/cloudwego/kitex/pkg/remote" netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -42,18 +41,20 @@ var ( transportBuffers = []struct { Name string - NewBuffer func() remote.ByteBuffer + NewBuffer func() (bufiox.Writer, bufiox.Reader) }{ { Name: "BytesBuffer", - NewBuffer: func() remote.ByteBuffer { - return remote.NewReaderWriterBuffer(1024) + NewBuffer: func() (bufiox.Writer, bufiox.Reader) { + conn := mocks.NewIOConn() + return bufiox.NewDefaultWriter(conn), bufiox.NewDefaultReader(conn) }, }, { Name: "NetpollBuffer", - NewBuffer: func() remote.ByteBuffer { - return netpolltrans.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024)) + NewBuffer: func() (bufiox.Writer, bufiox.Reader) { + bb := netpolltrans.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024)) + return bb, bb }, }, } @@ -64,18 +65,18 @@ func init() { } type mockWithContext struct { - ReadFunc func(ctx context.Context, method string, dataLen int, oprot io.Reader) error - WriteFunc func(ctx context.Context, method string, oprot io.Writer) error + ReadFunc func(ctx context.Context, method string, dataLen int, oprot bufiox.Reader) error + WriteFunc func(ctx context.Context, method string, oprot bufiox.Writer) error } -func (m *mockWithContext) Read(ctx context.Context, method string, dataLen int, oprot io.Reader) error { +func (m *mockWithContext) Read(ctx context.Context, method string, dataLen int, oprot bufiox.Reader) error { if m.ReadFunc != nil { return m.ReadFunc(ctx, method, dataLen, oprot) } return nil } -func (m *mockWithContext) Write(ctx context.Context, method string, oprot io.Writer) error { +func (m *mockWithContext) Write(ctx context.Context, method string, oprot bufiox.Writer) error { if m.WriteFunc != nil { return m.WriteFunc(ctx, method, oprot) } @@ -87,28 +88,30 @@ func TestWithContext(t *testing.T) { t.Run(tb.Name, func(t *testing.T) { ctx := context.Background() - req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, oprot io.Writer) error { + req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, oprot bufiox.Writer) error { return nil }} ink := rpcinfo.NewInvocation("", "mock") ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client) msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, svcInfo.PayloadCodec)) - buf := tb.NewBuffer() - err := payloadCodec.Marshal(ctx, msg, buf) + bw, br := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, br) + err := payloadCodec.Marshal(ctx, msg, bb) test.Assert(t, err == nil, err) - buf.Flush() + wl := bw.WrittenLen() + bw.Flush() { - resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, dataLen int, oprot io.Reader) error { + resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, dataLen int, oprot bufiox.Reader) error { return nil }} ink := rpcinfo.NewInvocation("", "mock") ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) msg := remote.NewMessage(resp, svcInfo, ri, remote.Call, remote.Client) msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, svcInfo.PayloadCodec)) - msg.SetPayloadLen(buf.ReadableLen()) - err = payloadCodec.Unmarshal(ctx, msg, buf) + msg.SetPayloadLen(wl) + err = payloadCodec.Unmarshal(ctx, msg, bb) test.Assert(t, err == nil, err) } }) @@ -121,17 +124,19 @@ func TestNormal(t *testing.T) { t.Run(tb.Name, func(t *testing.T) { ctx := context.Background() // encode client side - sendMsg := initSendMsg(transport.TTHeader, false) - buf := tb.NewBuffer() - err := payloadCodec.Marshal(ctx, sendMsg, buf) + sendMsg := initSendMsg(transport.TTHeader) + bw, br := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, br) + err := payloadCodec.Marshal(ctx, sendMsg, bb) test.Assert(t, err == nil, err) - buf.Flush() + wl := bw.WrittenLen() + bw.Flush() // decode server side - recvMsg := initRecvMsg(false) - recvMsg.SetPayloadLen(buf.ReadableLen()) + recvMsg := initRecvMsg() + recvMsg.SetPayloadLen(wl) test.Assert(t, err == nil, err) - err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + err = payloadCodec.Unmarshal(ctx, recvMsg, bb) test.Assert(t, err == nil, err) // compare Req Arg @@ -144,17 +149,19 @@ func TestNormal(t *testing.T) { t.Run(tb.Name+"Basic", func(t *testing.T) { ctx := context.Background() // encode client side - sendMsg := initSendMsg(transport.TTHeader, true) - buf := tb.NewBuffer() - err := payloadCodec.Marshal(ctx, sendMsg, buf) + sendMsg := initSendMsg(transport.TTHeader) + bw, br := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, br) + err := payloadCodec.Marshal(ctx, sendMsg, bb) test.Assert(t, err == nil, err) - buf.Flush() + wl := bw.WrittenLen() + bw.Flush() // decode server side - recvMsg := initRecvMsg(true) - recvMsg.SetPayloadLen(buf.ReadableLen()) + recvMsg := initRecvMsg() + recvMsg.SetPayloadLen(wl) test.Assert(t, err == nil, err) - err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + err = payloadCodec.Unmarshal(ctx, recvMsg, bb) test.Assert(t, err == nil, err) // compare Req Arg @@ -169,16 +176,18 @@ func TestNormal(t *testing.T) { // encode client side sendMsg := newMsg(remote.NewTransErrorWithMsg(1, "hello")) sendMsg.SetMessageType(remote.Exception) - buf := tb.NewBuffer() - err := payloadCodec.Marshal(ctx, sendMsg, buf) + bw, br := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, br) + err := payloadCodec.Marshal(ctx, sendMsg, bb) test.Assert(t, err == nil, err) - buf.Flush() + wl := bw.WrittenLen() + bw.Flush() // decode server side recvMsg := newMsg(nil) - recvMsg.SetPayloadLen(buf.ReadableLen()) + recvMsg.SetPayloadLen(wl) test.Assert(t, err == nil, err) - err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + err = payloadCodec.Unmarshal(ctx, recvMsg, bb) test.Assert(t, err != nil) te, ok := err.(*remote.TransError) test.Assert(t, ok) @@ -196,17 +205,19 @@ func BenchmarkNormalParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { // encode // client side - sendMsg := initSendMsg(transport.TTHeader, false) - buf := tb.NewBuffer() - err := payloadCodec.Marshal(ctx, sendMsg, buf) + sendMsg := initSendMsg(transport.TTHeader) + bw, br := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, br) + err := payloadCodec.Marshal(ctx, sendMsg, bb) test.Assert(b, err == nil, err) - buf.Flush() + wl := bw.WrittenLen() + bw.Flush() // decode server side - recvMsg := initRecvMsg(false) - recvMsg.SetPayloadLen(buf.ReadableLen()) + recvMsg := initRecvMsg() + recvMsg.SetPayloadLen(wl) test.Assert(b, err == nil, err) - err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + err = payloadCodec.Unmarshal(ctx, recvMsg, bb) test.Assert(b, err == nil, err) // compare Req Arg @@ -237,16 +248,18 @@ func TestException(t *testing.T) { transErr := remote.NewTransErrorWithMsg(remote.UnknownMethod, errInfo) // encode server side errMsg := initServerErrorMsg(transport.TTHeader, ri, transErr) - buf := tb.NewBuffer() - err := payloadCodec.Marshal(ctx, errMsg, buf) + bw, br := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, br) + err := payloadCodec.Marshal(ctx, errMsg, bb) test.Assert(t, err == nil, err) - buf.Flush() + wl := bw.WrittenLen() + bw.Flush() // decode client side recvMsg := initClientRecvMsg(ri) - recvMsg.SetPayloadLen(buf.ReadableLen()) + recvMsg.SetPayloadLen(wl) test.Assert(t, err == nil, err) - err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + err = payloadCodec.Unmarshal(ctx, recvMsg, bb) test.Assert(t, err != nil) transErr, ok := err.(*remote.TransError) test.Assert(t, ok, err) @@ -275,11 +288,13 @@ func TestSkipDecoder(t *testing.T) { desc string codec remote.PayloadCodec protocol transport.Protocol + wantErr bool }{ { desc: "Disable SkipDecoder, fallback to Apache Thrift Codec for Buffer Protocol", codec: NewThriftCodec(), protocol: transport.PurePayload, + wantErr: true, }, { desc: "Disable SkipDecoder, using FastCodec for TTHeader Protocol", @@ -302,18 +317,24 @@ func TestSkipDecoder(t *testing.T) { for _, tb := range transportBuffers { t.Run(tc.desc+"#"+tb.Name, func(t *testing.T) { // encode client side - sendMsg := initSendMsg(tc.protocol, true) // always use Basic to test skipdecodec - buf := tb.NewBuffer() - err := tc.codec.Marshal(context.Background(), sendMsg, buf) + sendMsg := initSendMsg(tc.protocol) // always use Basic to test skipdecodec + bw, br := tb.NewBuffer() + bb := remote.NewByteBufferFromBufiox(bw, br) + err := tc.codec.Marshal(context.Background(), sendMsg, bb) test.Assert(t, err == nil, err) - buf.Flush() + wl := bw.WrittenLen() + bw.Flush() // decode server side - recvMsg := initRecvMsg(true) + recvMsg := initRecvMsg() if tc.protocol != transport.PurePayload { - recvMsg.SetPayloadLen(buf.ReadableLen()) + recvMsg.SetPayloadLen(wl) + } + err = tc.codec.Unmarshal(context.Background(), recvMsg, bb) + if tc.wantErr { + test.Assert(t, err != nil) + return } - err = tc.codec.Unmarshal(context.Background(), recvMsg, buf) test.Assert(t, err == nil, err) // compare Req Arg @@ -323,35 +344,28 @@ func TestSkipDecoder(t *testing.T) { } } -func toApacheCodec(v bool, data thrift.FastCodec) interface{} { - if v { - return bthrift.ToApacheCodec(data) - } - return data -} - func newMsg(data interface{}) remote.Message { ink := rpcinfo.NewInvocation("", "mock") ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) return remote.NewMessage(data, svcInfo, ri, remote.Call, remote.Client) } -func initSendMsg(tp transport.Protocol, basic bool) remote.Message { +func initSendMsg(tp transport.Protocol) remote.Message { var _args mt.MockTestArgs // fastcodec only, if basic is true -> apachecodec _args.Req = prepareReq() - msg := newMsg(toApacheCodec(basic, &_args)) + msg := newMsg(&_args) msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) return msg } -func initRecvMsg(basic bool) remote.Message { +func initRecvMsg() remote.Message { var _args mt.MockTestArgs // fastcodec only, if basic is true -> apachecodec - return newMsg(toApacheCodec(basic, &_args)) + return newMsg(&_args) } func compare(t *testing.T, sendMsg, recvMsg remote.Message) { - sendReq := bthrift.UnpackApacheCodec(sendMsg.Data()).(*mt.MockTestArgs).Req - recvReq := bthrift.UnpackApacheCodec(recvMsg.Data()).(*mt.MockTestArgs).Req + sendReq := sendMsg.Data().(*mt.MockTestArgs).Req + recvReq := recvMsg.Data().(*mt.MockTestArgs).Req test.Assert(t, sendReq.Msg == recvReq.Msg) test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList)) test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap)) diff --git a/pkg/remote/default_bytebuf.go b/pkg/remote/default_bytebuf.go index fe7b3f3a94..9bc4208ea3 100644 --- a/pkg/remote/default_bytebuf.go +++ b/pkg/remote/default_bytebuf.go @@ -182,17 +182,17 @@ func (b *defaultByteBuffer) ReadString(n int) (s string, err error) { // ReadBinary like ReadString. // Returns a copy of original buffer. -func (b *defaultByteBuffer) ReadBinary(n int) (p []byte, err error) { +func (b *defaultByteBuffer) ReadBinary(p []byte) (n int, err error) { if b.status&BitReadable == 0 { - return p, errors.New("unreadable buffer, cannot support ReadBinary") + return 0, errors.New("unreadable buffer, cannot support ReadBinary") } + n = len(p) var buf []byte if buf, err = b.Next(n); err != nil { - return p, err + return 0, err } - p = dirtmake.Bytes(n, n) copy(p, buf) - return p, nil + return n, nil } // Malloc n bytes sequentially in the writer buffer. @@ -206,8 +206,8 @@ func (b *defaultByteBuffer) Malloc(n int) (buf []byte, err error) { return b.buff[currWIdx:b.writeIdx], nil } -// MallocLen returns the total length of the buffer malloced. -func (b *defaultByteBuffer) MallocLen() (length int) { +// WrittenLen returns the total length of the buffer written. +func (b *defaultByteBuffer) WrittenLen() (length int) { if b.status&BitWritable == 0 { return -1 } diff --git a/pkg/remote/default_bytebuf_test.go b/pkg/remote/default_bytebuf_test.go index b3fcfb527e..657de26f4a 100644 --- a/pkg/remote/default_bytebuf_test.go +++ b/pkg/remote/default_bytebuf_test.go @@ -64,7 +64,7 @@ func checkWritable(t *testing.T, buf ByteBuffer) { test.Assert(t, err == nil, err) test.Assert(t, len(p) == len(msg)) copy(p, msg) - l := buf.MallocLen() + l := buf.WrittenLen() test.Assert(t, l == len(msg)) l, err = buf.WriteString(msg) test.Assert(t, err == nil, err) @@ -106,7 +106,8 @@ func checkReadable(t *testing.T, buf ByteBuffer) { s, err = buf.ReadString(len(msg)) test.Assert(t, err == nil, err) test.Assert(t, s == msg) - p, err = buf.ReadBinary(len(msg)) + p = make([]byte, len(msg)) + _, err = buf.ReadBinary(p) test.Assert(t, err == nil, err) test.Assert(t, string(p) == msg) p = make([]byte, len(msg)) @@ -120,7 +121,7 @@ func checkUnwritable(t *testing.T, buf ByteBuffer) { msg := "hello world" _, err := buf.Malloc(len(msg)) test.Assert(t, err != nil) - l := buf.MallocLen() + l := buf.WrittenLen() test.Assert(t, l == -1, l) _, err = buf.WriteString(msg) test.Assert(t, err != nil) @@ -148,7 +149,8 @@ func checkUnreadable(t *testing.T, buf ByteBuffer) { test.Assert(t, n == 0) _, err = buf.ReadString(len(msg)) test.Assert(t, err != nil) - _, err = buf.ReadBinary(len(msg)) + b := make([]byte, len(msg)) + _, err = buf.ReadBinary(b) test.Assert(t, err != nil) p := make([]byte, len(msg)) n, err = buf.Read(p) diff --git a/pkg/remote/trans/gonet/bytebuffer.go b/pkg/remote/trans/gonet/bytebuffer.go index dfae2731c2..566d8d6c48 100644 --- a/pkg/remote/trans/gonet/bytebuffer.go +++ b/pkg/remote/trans/gonet/bytebuffer.go @@ -132,13 +132,17 @@ func (rw *bufferReadWriter) ReadString(n int) (s string, err error) { return } -func (rw *bufferReadWriter) ReadBinary(n int) (p []byte, err error) { +func (rw *bufferReadWriter) ReadBinary(p []byte) (n int, err error) { if !rw.readable() { - return p, errors.New("unreadable buffer, cannot support ReadBinary") + return 0, errors.New("unreadable buffer, cannot support ReadBinary") } - if p, err = rw.reader.ReadBinary(n); err == nil { - rw.readSize += n + n = len(p) + var buf []byte + if buf, err = rw.reader.Next(n); err != nil { + return 0, err } + copy(p, buf) + rw.readSize += n return } @@ -163,7 +167,7 @@ func (rw *bufferReadWriter) Malloc(n int) (buf []byte, err error) { return rw.writer.Malloc(n) } -func (rw *bufferReadWriter) MallocLen() (length int) { +func (rw *bufferReadWriter) WrittenLen() (length int) { if !rw.writable() { return -1 } diff --git a/pkg/remote/trans/gonet/bytebuffer_test.go b/pkg/remote/trans/gonet/bytebuffer_test.go index c6da1e621e..82b138175a 100644 --- a/pkg/remote/trans/gonet/bytebuffer_test.go +++ b/pkg/remote/trans/gonet/bytebuffer_test.go @@ -110,7 +110,8 @@ func testRead(t *testing.T, buf remote.ByteBuffer) { } test.Assert(t, s == msg) - p, err = buf.ReadBinary(msgLen) + p = make([]byte, msgLen) + _, err = buf.ReadBinary(p) if err != nil { t.Logf("ReadBinary failed, err=%s", err.Error()) t.FailNow() @@ -140,7 +141,8 @@ func testReadFailed(t *testing.T, buf remote.ByteBuffer) { _, err = buf.ReadString(len(msg)) test.Assert(t, err != nil) - _, err = buf.ReadBinary(len(msg)) + b := make([]byte, len(msg)) + _, err = buf.ReadBinary(b) test.Assert(t, err != nil) n, err = buf.Read(p) @@ -160,7 +162,7 @@ func testWrite(t *testing.T, buf remote.ByteBuffer) { test.Assert(t, len(p) == msgLen) copy(p, msg) - l := buf.MallocLen() + l := buf.WrittenLen() test.Assert(t, l == msgLen) l, err = buf.WriteString(msg) @@ -188,7 +190,7 @@ func testWriteFailed(t *testing.T, buf remote.ByteBuffer) { _, err := buf.Malloc(len(msg)) test.Assert(t, err != nil) - l := buf.MallocLen() + l := buf.WrittenLen() test.Assert(t, l == -1) _, err = buf.WriteString(msg) diff --git a/pkg/remote/trans/invoke/message_test.go b/pkg/remote/trans/invoke/message_test.go index 19a450e507..c1c7ed856b 100644 --- a/pkg/remote/trans/invoke/message_test.go +++ b/pkg/remote/trans/invoke/message_test.go @@ -41,7 +41,8 @@ func Test_message_Request(t *testing.T) { } rb := msg.GetRequestReaderByteBuffer() test.Assert(t, rb != nil) - got, err := rb.ReadBinary(len(want)) + got := make([]byte, len(want)) + _, err = rb.ReadBinary(got) if err != nil { t.Fatal(err) } diff --git a/pkg/remote/trans/netpoll/bytebuf.go b/pkg/remote/trans/netpoll/bytebuf.go index 4a6bb3b5ee..43825d3da2 100644 --- a/pkg/remote/trans/netpoll/bytebuf.go +++ b/pkg/remote/trans/netpoll/bytebuf.go @@ -129,13 +129,17 @@ func (b *netpollByteBuffer) ReadString(n int) (s string, err error) { // ReadBinary like ReadString. // Returns a copy of original buffer. -func (b *netpollByteBuffer) ReadBinary(n int) (p []byte, err error) { +func (b *netpollByteBuffer) ReadBinary(p []byte) (n int, err error) { if b.status&remote.BitReadable == 0 { - return p, errors.New("unreadable buffer, cannot support ReadBinary") + return 0, errors.New("unreadable buffer, cannot support ReadBinary") } - if p, err = b.reader.ReadBinary(n); err == nil { - b.readSize += n + n = len(p) + var buf []byte + if buf, err = b.reader.Next(n); err != nil { + return 0, err } + copy(p, buf) + b.readSize += n return } @@ -155,8 +159,8 @@ func (b *netpollByteBuffer) MallocAck(n int) (err error) { return b.writer.MallocAck(n) } -// MallocLen returns the total length of the buffer malloced. -func (b *netpollByteBuffer) MallocLen() (length int) { +// WrittenLen returns the total length of the buffer written. +func (b *netpollByteBuffer) WrittenLen() (length int) { if b.status&remote.BitWritable == 0 { return -1 } diff --git a/pkg/remote/trans/netpoll/bytebuf_test.go b/pkg/remote/trans/netpoll/bytebuf_test.go index 2642513076..1ecd44a2cf 100644 --- a/pkg/remote/trans/netpoll/bytebuf_test.go +++ b/pkg/remote/trans/netpoll/bytebuf_test.go @@ -98,7 +98,7 @@ func checkWritable(t *testing.T, buf remote.ByteBuffer) { test.Assert(t, len(p) == len(msg)) copy(p, msg) - l := buf.MallocLen() + l := buf.WrittenLen() test.Assert(t, l == len(msg)) l, err = buf.WriteString(msg) @@ -138,7 +138,8 @@ func checkReadable(t *testing.T, buf remote.ByteBuffer) { test.Assert(t, err == nil, err) test.Assert(t, s == msg) - p, err = buf.ReadBinary(len(msg)) + p = make([]byte, len(msg)) + _, err = buf.ReadBinary(p) test.Assert(t, err == nil, err) test.Assert(t, string(p) == msg) } @@ -150,7 +151,7 @@ func checkUnwritable(t *testing.T, buf remote.ByteBuffer) { _, err := buf.Malloc(len(msg)) test.Assert(t, err != nil) - l := buf.MallocLen() + l := buf.WrittenLen() test.Assert(t, l == -1, l) _, err = buf.WriteString(msg) @@ -193,7 +194,8 @@ func checkUnreadable(t *testing.T, buf remote.ByteBuffer) { _, err = buf.ReadString(len(msg)) test.Assert(t, err != nil) - _, err = buf.ReadBinary(len(msg)) + b := make([]byte, len(msg)) + _, err = buf.ReadBinary(b) test.Assert(t, err != nil) n, err = buf.Read(p) diff --git a/pkg/remote/trans/netpoll/http_client_handler_test.go b/pkg/remote/trans/netpoll/http_client_handler_test.go index df4f90a832..b7d4bd1063 100644 --- a/pkg/remote/trans/netpoll/http_client_handler_test.go +++ b/pkg/remote/trans/netpoll/http_client_handler_test.go @@ -180,7 +180,8 @@ func TestSkipToBody(t *testing.T) { err := skipToBody(reader) test.Assert(t, err == nil) - getBody, err := reader.ReadBinary(reader.ReadableLen()) + getBody := make([]byte, reader.ReadableLen()) + _, err = reader.ReadBinary(getBody) test.Assert(t, err == nil) test.Assert(t, strings.Compare(string(getBody), wantBody) == 0) } diff --git a/pkg/remote/trans/nphttp2/buffer.go b/pkg/remote/trans/nphttp2/buffer.go index 2c95061a2c..d3c21516ee 100644 --- a/pkg/remote/trans/nphttp2/buffer.go +++ b/pkg/remote/trans/nphttp2/buffer.go @@ -145,7 +145,7 @@ func (b *buffer) ReadString(n int) (s string, err error) { panic("implement me") } -func (b *buffer) ReadBinary(n int) (p []byte, err error) { +func (b *buffer) ReadBinary(p []byte) (n int, err error) { panic("implement me") } @@ -153,7 +153,7 @@ func (b *buffer) Malloc(n int) (buf []byte, err error) { panic("implement me") } -func (b *buffer) MallocLen() (length int) { +func (b *buffer) WrittenLen() (length int) { panic("implement me") } diff --git a/pkg/remote/trans/nphttp2/server_handler_test.go b/pkg/remote/trans/nphttp2/server_handler_test.go index 79d518944b..a1a0d6f1dc 100644 --- a/pkg/remote/trans/nphttp2/server_handler_test.go +++ b/pkg/remote/trans/nphttp2/server_handler_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" @@ -63,6 +64,15 @@ func TestServerHandler(t *testing.T) { return ctx, nil }, }) + opt.SvcSearcher = mocksremote.NewMockSvcSearcher(map[string]*serviceinfo.ServiceInfo{ + "Greeter": { + Methods: map[string]serviceinfo.MethodInfo{ + "SayHello": serviceinfo.NewMethodInfo(func(ctx context.Context, handler, args, result interface{}) error { + return nil + }, func() interface{} { return nil }, func() interface{} { return nil }, false), + }, + }, + }, nil) msg := newMockNewMessage() msg.ProtocolInfoFunc = func() remote.ProtocolInfo { return remote.NewProtocolInfo(transport.PurePayload, serviceinfo.Protobuf) @@ -95,8 +105,9 @@ func TestServerHandler(t *testing.T) { test.Assert(t, err == nil, err) // test SetInvokeHandleFunc() - svrHdl := handler.(*svrTransHandler) - svrHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + var calledInvoke int32 + handler.(remote.InvokeHandleFuncSetter).SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + atomic.StoreInt32(&calledInvoke, 1) return nil }) @@ -120,6 +131,7 @@ func TestServerHandler(t *testing.T) { time.Sleep(time.Millisecond * 50) test.Assert(t, serviceName.Load().(string) == "Greeter", serviceName.Load()) test.Assert(t, methodName.Load().(string) == "SayHello", methodName.Load()) + test.Assert(t, atomic.LoadInt32(&calledInvoke) == 1) // test OnError() handler.OnError(context.Background(), context.Canceled, npConn) diff --git a/pkg/utils/thrift.go b/pkg/utils/thrift.go index 0ce7af5eb4..dd1979b5a4 100644 --- a/pkg/utils/thrift.go +++ b/pkg/utils/thrift.go @@ -17,10 +17,10 @@ package utils import ( - "bytes" "errors" "fmt" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/apache" @@ -46,12 +46,12 @@ func (t *ThriftMessageCodec) Encode(method string, msgType athrift.TMessageType, } b := make([]byte, thrift.Binary.MessageBeginLength(method)) _ = thrift.Binary.WriteMessageBegin(b, method, thrift.TMessageType(msgType), seqID) - buf := &bytes.Buffer{} - buf.Write(b) + buf := bufiox.NewBytesWriter(&b) if err := apache.ThriftWrite(buf, msg); err != nil { return nil, err } - return buf.Bytes(), nil + _ = buf.Flush() + return b, nil } // Decode do thrift message decode, notice: msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result @@ -71,25 +71,25 @@ func (t *ThriftMessageCodec) Decode(b []byte, msg athrift.TStruct) (method strin } return } - err = apache.ThriftRead(bytes.NewBuffer(b), msg) + err = apache.ThriftRead(bufiox.NewBytesReader(b), msg) return } // Serialize serialize message into bytes. This is normal thrift serialize func. // Notice: Binary generic use Encode instead of Serialize. -func (t *ThriftMessageCodec) Serialize(msg athrift.TStruct) ([]byte, error) { - buf := &bytes.Buffer{} - if err := apache.ThriftWrite(buf, msg); err != nil { +func (t *ThriftMessageCodec) Serialize(msg athrift.TStruct) (b []byte, err error) { + buf := bufiox.NewBytesWriter(&b) + if err = apache.ThriftWrite(buf, msg); err != nil { return nil, err } - return buf.Bytes(), nil + _ = buf.Flush() + return b, nil } // Deserialize deserialize bytes into message. This is normal thrift deserialize func. // Notice: Binary generic use Decode instead of Deserialize. func (t *ThriftMessageCodec) Deserialize(msg athrift.TStruct, b []byte) (err error) { - buf := bytes.NewBuffer(b) - return apache.ThriftRead(buf, msg) + return apache.ThriftRead(bufiox.NewBytesReader(b), msg) } // MarshalError convert go error to thrift exception, and encode exception over buffered binary transport. diff --git a/pkg/utils/thrift_test.go b/pkg/utils/thrift_test.go deleted file mode 100644 index 6fea7afc0f..0000000000 --- a/pkg/utils/thrift_test.go +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Copyright 2021 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package utils - -import ( - "errors" - "testing" - - mt "github.com/cloudwego/kitex/internal/mocks/thrift" - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - athrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" -) - -func TestRPCCodec(t *testing.T) { - rc := NewThriftMessageCodec() - - req1 := mt.NewMockReq() - req1.Msg = "Hello Kitex" - strMap := make(map[string]string) - strMap["aa"] = "aa" - strMap["bb"] = "bb" - req1.StrMap = strMap - - args1 := mt.NewMockTestArgs() - args1.Req = req1 - - // encode - buf, err := rc.Encode("mockMethod", athrift.CALL, 100, bthrift.ToApacheCodec(args1)) - test.Assert(t, err == nil, err) - - var argsDecode1 mt.MockTestArgs - // decode - method, seqID, err := rc.Decode(buf, bthrift.ToApacheCodec(&argsDecode1)) - - test.Assert(t, err == nil, err) - test.Assert(t, method == "mockMethod") - test.Assert(t, seqID == 100) - test.Assert(t, argsDecode1.Req.Msg == req1.Msg) - test.Assert(t, len(argsDecode1.Req.StrMap) == len(req1.StrMap)) - for k := range argsDecode1.Req.StrMap { - test.Assert(t, argsDecode1.Req.StrMap[k] == req1.StrMap[k]) - } - - // *** reuse ThriftMessageCodec - req2 := mt.NewMockReq() - req2.Msg = "Hello Kitex1" - strMap = make(map[string]string) - strMap["cc"] = "cc" - strMap["dd"] = "dd" - req2.StrMap = strMap - args2 := mt.NewMockTestArgs() - args2.Req = req2 - // encode - buf, err = rc.Encode("mockMethod1", athrift.CALL, 101, bthrift.ToApacheCodec(args2)) - test.Assert(t, err == nil, err) - - // decode - var argsDecode2 mt.MockTestArgs - method, seqID, err = rc.Decode(buf, bthrift.ToApacheCodec(&argsDecode2)) - - test.Assert(t, err == nil, err) - test.Assert(t, method == "mockMethod1") - test.Assert(t, seqID == 101) - test.Assert(t, argsDecode2.Req.Msg == req2.Msg) - test.Assert(t, len(argsDecode2.Req.StrMap) == len(req2.StrMap)) - for k := range argsDecode2.Req.StrMap { - test.Assert(t, argsDecode2.Req.StrMap[k] == req2.StrMap[k]) - } -} - -func TestSerializer(t *testing.T) { - rc := NewThriftMessageCodec() - - req := mt.NewMockReq() - req.Msg = "Hello Kitex" - strMap := make(map[string]string) - strMap["aa"] = "aa" - strMap["bb"] = "bb" - req.StrMap = strMap - - args := mt.NewMockTestArgs() - args.Req = req - - b, err := rc.Serialize(bthrift.ToApacheCodec(args)) - test.Assert(t, err == nil, err) - - var args2 mt.MockTestArgs - err = rc.Deserialize(bthrift.ToApacheCodec(&args2), b) - test.Assert(t, err == nil, err) - - test.Assert(t, args2.Req.Msg == req.Msg) - test.Assert(t, len(args2.Req.StrMap) == len(req.StrMap)) - for k := range args2.Req.StrMap { - test.Assert(t, args2.Req.StrMap[k] == req.StrMap[k]) - } -} - -func TestException(t *testing.T) { - errMsg := "my error" - b := MarshalError("some method", errors.New(errMsg)) - err := UnmarshalError(b) - test.Assert(t, err.Error() == errMsg, err) -} diff --git a/server/invoke.go b/server/invoke.go index 83ea70752a..6b24c9393f 100644 --- a/server/invoke.go +++ b/server/invoke.go @@ -19,7 +19,6 @@ package server // Invoker is for calling handler function wrapped by Kitex suites without connection. import ( - "context" "errors" internal_server "github.com/cloudwego/kitex/internal/server" @@ -42,42 +41,8 @@ type Invoker interface { } type tInvoker struct { + invoke.Handler *server - - h invoke.Handler -} - -// invokerMetaDecoder is used to update `PayloadLen` of `remote.Message`. -// It fixes kitex returning err when apache codec is not available due to msg.PayloadLen() == 0. -// Because users may not add transport header like transport.Framed -// to invoke.Message when calling msg.SetRequestBytes. -// This is NOT expected and it's caused by kitex design fault. -type invokerMetaDecoder struct { - remote.Codec - - d remote.MetaDecoder -} - -func (d *invokerMetaDecoder) DecodeMeta(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { - err := d.d.DecodeMeta(ctx, msg, in) - if err != nil { - return err - } - // cool ... no need to do anything. - // added transport header? - if msg.PayloadLen() > 0 { - return nil - } - // use the whole buffer - // coz for invoker remote.ByteBuffer always contains the whole msg payload - if n := in.ReadableLen(); n > 0 { - msg.SetPayloadLen(n) - } - return nil -} - -func (d *invokerMetaDecoder) DecodePayload(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { - return d.d.DecodePayload(ctx, msg, in) } // NewInvoker creates new Invoker. @@ -86,13 +51,6 @@ func NewInvoker(opts ...Option) Invoker { opt: internal_server.NewOptions(opts), svcs: newServices(), } - if codec, ok := s.opt.RemoteOpt.Codec.(remote.MetaDecoder); ok { - // see comment on type `invokerMetaDecoder` - s.opt.RemoteOpt.Codec = &invokerMetaDecoder{ - Codec: s.opt.RemoteOpt.Codec, - d: codec, - } - } s.init() return &tInvoker{ server: s, @@ -112,7 +70,7 @@ func (s *tInvoker) Init() (err error) { doAddBoundHandler(transInfoHdlr, s.server.opt.RemoteOpt) } s.Lock() - s.h, err = s.newInvokeHandler() + s.Handler, err = s.newInvokeHandler() s.Unlock() if err != nil { return err @@ -125,7 +83,7 @@ func (s *tInvoker) Init() (err error) { // Call implements the InvokeCaller interface. func (s *tInvoker) Call(msg invoke.Message) error { - return s.h.Call(msg) + return s.Handler.Call(msg) } func (s *tInvoker) newInvokeHandler() (handler invoke.Handler, err error) { diff --git a/server/invoke_test.go b/server/invoke_test.go index 30a5178df7..f5533eb2c2 100644 --- a/server/invoke_test.go +++ b/server/invoke_test.go @@ -18,6 +18,7 @@ package server import ( "context" + "encoding/binary" "strings" "sync/atomic" "testing" @@ -49,9 +50,11 @@ func TestInvokerCall(t *testing.T) { args := mocks.NewMockArgs() // call success + hl := make([]byte, 4) b, _ := thrift.MarshalFastMsg("mock", thrift.CALL, 0, args.(thrift.FastCodec)) + binary.BigEndian.PutUint32(hl, uint32(len(b))) msg := invoke.NewMessage(nil, nil) - err = msg.SetRequestBytes(b) + err = msg.SetRequestBytes(append(hl, b...)) test.Assert(t, err == nil) err = invoker.Call(msg) if err != nil { @@ -65,9 +68,11 @@ func TestInvokerCall(t *testing.T) { test.Assert(t, gotErr.Load() == nil) // call fails + hl = make([]byte, 4) b, _ = thrift.MarshalFastMsg("mockError", thrift.CALL, 0, args.(thrift.FastCodec)) + binary.BigEndian.PutUint32(hl, uint32(len(b))) msg = invoke.NewMessage(nil, nil) - err = msg.SetRequestBytes(b) + err = msg.SetRequestBytes(append(hl, b...)) test.Assert(t, err == nil) err = invoker.Call(msg) if err != nil { From 8e2005e95e674871ad6b7e7b12014420c5b1ee57 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Thu, 29 Aug 2024 16:32:12 +0800 Subject: [PATCH 61/70] chore: update dependency (#1518) --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 9df77e662e..676a371c6f 100644 --- a/go.mod +++ b/go.mod @@ -5,16 +5,16 @@ go 1.18 require ( github.com/apache/thrift v0.13.0 github.com/bytedance/gopkg v0.1.1 - github.com/bytedance/sonic v1.12.1 + github.com/bytedance/sonic v1.12.2 github.com/cloudwego/configmanager v0.2.2 - github.com/cloudwego/dynamicgo v0.3.0 + github.com/cloudwego/dynamicgo v0.4.0 github.com/cloudwego/fastpb v0.0.5 github.com/cloudwego/frugal v0.2.0 github.com/cloudwego/gopkg v0.1.1-0.20240829032745-024f019d8487 github.com/cloudwego/localsession v0.0.2 github.com/cloudwego/netpoll v0.6.3 github.com/cloudwego/runtimex v0.1.0 - github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083 + github.com/cloudwego/thriftgo v0.3.16-0.20240807060045-993835609a55 github.com/golang/mock v1.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 github.com/jhump/protoreflect v1.8.2 diff --git a/go.sum b/go.sum index 8a9617da8b..a3ef1a7035 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQ github.com/bytedance/gopkg v0.1.1 h1:3azzgSkiaw79u24a+w9arfH8OfnQQ4MHUt9lJFREEaE= github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic v1.12.1 h1:jWl5Qz1fy7X1ioY74WqO0KjAMtAGQs4sYnjiEBiyX24= -github.com/bytedance/sonic v1.12.1/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= +github.com/bytedance/sonic v1.12.2 h1:oaMFuRTpMHYLpCntGca65YWt5ny+wAceDERTkT2L9lg= +github.com/bytedance/sonic v1.12.2/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= @@ -18,8 +18,8 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/configmanager v0.2.2 h1:sVrJB8gWYTlPV2OS3wcgJSO9F2/9Zbkmcm1Z7jempOU= github.com/cloudwego/configmanager v0.2.2/go.mod h1:ppiyU+5TPLonE8qMVi/pFQk2eL3Q4P7d4hbiNJn6jwI= -github.com/cloudwego/dynamicgo v0.3.0 h1:2/jOD3cMn8YVWGmVybrn74YulmhxW8d4BPyy9pja5eo= -github.com/cloudwego/dynamicgo v0.3.0/go.mod h1:vPHEegW2xqjuDE8NAui+2D93RivFv18eWsyD9VRtORM= +github.com/cloudwego/dynamicgo v0.4.0 h1:wQqNRNiSQaLkbcn3sfpEJGZsz3xf8Il4P/3DcENsrFI= +github.com/cloudwego/dynamicgo v0.4.0/go.mod h1:zgWk2oz56EyH790LJSxrTz1j01GJBO964jJQ/y7qjJc= github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZU= github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= @@ -34,8 +34,8 @@ github.com/cloudwego/netpoll v0.6.3 h1:t+ndlwBFjQZimUj3ul31DwI45t18eOr2pcK3juZZm github.com/cloudwego/netpoll v0.6.3/go.mod h1:kaqvfZ70qd4T2WtIIpCOi5Cxyob8viEpzLhCrTrz3HM= github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= github.com/cloudwego/runtimex v0.1.0/go.mod h1:23vL/HGV0W8nSCHbe084AgEBdDV4rvXenEUMnUNvUd8= -github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083 h1:KiEGBvsyAyUrFrpEi/e77K0SWTLK8FMHhSQ5c9kFJic= -github.com/cloudwego/thriftgo v0.3.16-0.20240805092707-81e5f6692083/go.mod h1:R4a+4aVDI0V9YCTfpNgmvbkq/9ThKgF7Om8Z0I36698= +github.com/cloudwego/thriftgo v0.3.16-0.20240807060045-993835609a55 h1:9KZEXU56Al2yaDsGGkYbjcdHtnRenxJXHJNJSLQjLME= +github.com/cloudwego/thriftgo v0.3.16-0.20240807060045-993835609a55/go.mod h1:AdLEJJVGW/ZJYvkkYAZf5SaJH+pA3OyC801WSwqcBwI= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From 8fad71e22fe5c59eea3362f993d090f2ad6b4d93 Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Thu, 29 Aug 2024 16:57:29 +0800 Subject: [PATCH 62/70] chore: update dependency again (#1519) --- go.mod | 2 +- go.sum | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 676a371c6f..fbbb81c6c0 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/cloudwego/frugal v0.2.0 github.com/cloudwego/gopkg v0.1.1-0.20240829032745-024f019d8487 github.com/cloudwego/localsession v0.0.2 - github.com/cloudwego/netpoll v0.6.3 + github.com/cloudwego/netpoll v0.6.4-0.20240823082441-5c544da5550d github.com/cloudwego/runtimex v0.1.0 github.com/cloudwego/thriftgo v0.3.16-0.20240807060045-993835609a55 github.com/golang/mock v1.6.0 diff --git a/go.sum b/go.sum index a3ef1a7035..dbf562b954 100644 --- a/go.sum +++ b/go.sum @@ -3,7 +3,7 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= -github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.1.1 h1:3azzgSkiaw79u24a+w9arfH8OfnQQ4MHUt9lJFREEaE= github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= @@ -30,8 +30,8 @@ github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= github.com/cloudwego/localsession v0.0.2/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= -github.com/cloudwego/netpoll v0.6.3 h1:t+ndlwBFjQZimUj3ul31DwI45t18eOr2pcK3juZZm+E= -github.com/cloudwego/netpoll v0.6.3/go.mod h1:kaqvfZ70qd4T2WtIIpCOi5Cxyob8viEpzLhCrTrz3HM= +github.com/cloudwego/netpoll v0.6.4-0.20240823082441-5c544da5550d h1:G4yCpvx0Ok2TMurfrlPs44tuvroaIjwdPFQH38OzaUQ= +github.com/cloudwego/netpoll v0.6.4-0.20240823082441-5c544da5550d/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= github.com/cloudwego/runtimex v0.1.0/go.mod h1:23vL/HGV0W8nSCHbe084AgEBdDV4rvXenEUMnUNvUd8= github.com/cloudwego/thriftgo v0.3.16-0.20240807060045-993835609a55 h1:9KZEXU56Al2yaDsGGkYbjcdHtnRenxJXHJNJSLQjLME= From e402e37ca9f888e6f9a96eb25e4fd414c91b9667 Mon Sep 17 00:00:00 2001 From: YangruiEmma Date: Thu, 29 Aug 2024 18:03:06 +0800 Subject: [PATCH 63/70] chore: remove json-iterator dependency (#1521) --- go.mod | 2 -- go.sum | 6 ------ internal/generic/thrift/http.go | 4 ++-- internal/generic/thrift/json.go | 4 ++-- pkg/utils/json_fuzz_test.go | 6 ++++-- pkg/utils/json_test.go | 34 ++++++--------------------------- 6 files changed, 14 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index fbbb81c6c0..209b67ce89 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/golang/mock v1.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 github.com/jhump/protoreflect v1.8.2 - github.com/json-iterator/go v1.1.12 github.com/tidwall/gjson v1.17.3 golang.org/x/net v0.24.0 golang.org/x/sync v0.8.0 @@ -39,7 +38,6 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/iancoleman/strcase v0.2.0 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index dbf562b954..599818ac6a 100644 --- a/go.sum +++ b/go.sum @@ -72,7 +72,6 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -82,8 +81,6 @@ github.com/iancoleman/strcase v0.2.0 h1:05I4QRnGpI0m37iZQRuskXh+w77mr6Z41lwQzuHL github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/jhump/protoreflect v1.8.2 h1:k2xE7wcUomeqwY0LDCYA16y4WWfyTcMx5mKhk0d4ua0= github.com/jhump/protoreflect v1.8.2/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= @@ -94,8 +91,6 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 h1:uiS4zKYKJVj5F3ID+5iylfKPsEQmBEOucSD9Vgmn0i0= github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5/go.mod h1:I8AX+yW//L8Hshx6+a1m3bYkwXkpsVjA2795vP4f4oQ= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= @@ -108,7 +103,6 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/internal/generic/thrift/http.go b/internal/generic/thrift/http.go index 9f72931004..4859231161 100644 --- a/internal/generic/thrift/http.go +++ b/internal/generic/thrift/http.go @@ -21,13 +21,13 @@ import ( "fmt" "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/bytedance/sonic" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" dthrift "github.com/cloudwego/dynamicgo/thrift" "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" - jsoniter "github.com/json-iterator/go" "github.com/cloudwego/kitex/pkg/generic/descriptor" ) @@ -52,7 +52,7 @@ type WriteHTTPRequest struct { var ( _ MessageWriter = (*WriteHTTPRequest)(nil) - customJson = jsoniter.Config{ + customJson = sonic.Config{ EscapeHTML: true, UseNumber: true, }.Froze() diff --git a/internal/generic/thrift/json.go b/internal/generic/thrift/json.go index 4ba24e3896..d3052bc82c 100644 --- a/internal/generic/thrift/json.go +++ b/internal/generic/thrift/json.go @@ -22,13 +22,13 @@ import ( "strconv" "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/bytedance/sonic" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" dthrift "github.com/cloudwego/dynamicgo/thrift" "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" - jsoniter "github.com/json-iterator/go" "github.com/tidwall/gjson" "github.com/cloudwego/kitex/pkg/generic/descriptor" @@ -245,7 +245,7 @@ func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient boo } // resp is map - respNode, err := jsoniter.Marshal(resp) + respNode, err := sonic.Marshal(resp) if err != nil { return nil, perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("response marshal failed. err:%#v", err)) } diff --git a/pkg/utils/json_fuzz_test.go b/pkg/utils/json_fuzz_test.go index 40b5b3b720..54678d0578 100644 --- a/pkg/utils/json_fuzz_test.go +++ b/pkg/utils/json_fuzz_test.go @@ -24,12 +24,14 @@ import ( "reflect" "testing" + "github.com/bytedance/sonic" + "github.com/cloudwego/kitex/internal/test" ) func FuzzJSONStr2Map(f *testing.F) { mapInfo := prepareMap() - jsonRet, _ := jsoni.MarshalToString(mapInfo) + jsonRet, _ := sonic.MarshalString(mapInfo) f.Add(`{}`) f.Add(`{"":""}`) f.Add(jsonRet) @@ -51,7 +53,7 @@ func FuzzJSONStr2Map(f *testing.F) { func FuzzMap2JSON(f *testing.F) { mapInfo := prepareMap() - jsonRet, _ := jsoni.MarshalToString(mapInfo) + jsonRet, _ := sonic.MarshalString(mapInfo) f.Add(`{}`) f.Add(`{"":""}`) f.Add(jsonRet) diff --git a/pkg/utils/json_test.go b/pkg/utils/json_test.go index c4af6a1a07..bc752e598a 100644 --- a/pkg/utils/json_test.go +++ b/pkg/utils/json_test.go @@ -25,13 +25,11 @@ import ( "testing" "unsafe" - jsoniter "github.com/json-iterator/go" + "github.com/bytedance/sonic" "github.com/cloudwego/kitex/internal/test" ) -var jsoni = jsoniter.ConfigCompatibleWithStandardLibrary - var samples = []struct { name string m map[string]string @@ -56,7 +54,7 @@ func BenchmarkMap2JSONStr(b *testing.B) { func BenchmarkJSONStr2Map(b *testing.B) { for _, s := range samples { b.Run(s.name, func(b *testing.B) { - j, _ := jsoni.MarshalToString(s.m) + j, _ := sonic.MarshalString(s.m) _, _ = JSONStr2Map(j) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -77,7 +75,7 @@ func BenchmarkJSONMarshal(b *testing.B) { func BenchmarkJSONUnmarshal(b *testing.B) { mapInfo := prepareMap() - jsonRet, _ := jsoni.MarshalToString(mapInfo) + jsonRet, _ := sonic.MarshalString(mapInfo) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -86,26 +84,6 @@ func BenchmarkJSONUnmarshal(b *testing.B) { } } -func BenchmarkJSONIterMarshal(b *testing.B) { - mapInfo := prepareMap() - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - jsoni.MarshalToString(mapInfo) - } -} - -func BenchmarkJSONIterUnmarshal(b *testing.B) { - mapInfo := prepareMap() - jsonRet, _ := Map2JSONStr(mapInfo) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - var map1 map[string]string - jsoni.UnmarshalFromString(jsonRet, &map1) - } -} - // TestMap2JSONStr test convert map to json string func TestMap2JSONStr(t *testing.T) { j, e := Map2JSONStr(map[string]string{}) @@ -238,11 +216,11 @@ func TestJSONStr2Map(t *testing.T) { test.Assert(t, htmlEscapeCharMapRet["&<>\u2028\u2029"] == "&<>\u2028\u2029", fmt.Sprintf("%+v", htmlEscapeCharMapRet)) } -// TestJSONUtil compare return between encoding/json, json-iterator and json.go +// TestJSONUtil compare return between encoding/json, sonic and json.go func TestJSONUtil(t *testing.T) { mapInfo := prepareMap() jsonRet1, _ := json.Marshal(mapInfo) - jsonRet2, _ := jsoni.MarshalToString(mapInfo) + jsonRet2, _ := sonic.MarshalString(mapInfo) jsonRet, _ := Map2JSONStr(mapInfo) jsonRet3, _ := _Map2JSONStr(mapInfo) @@ -264,7 +242,7 @@ func TestJSONUtil(t *testing.T) { } mapRetIter := make(map[string]string) - jsoni.UnmarshalFromString(string(jsonRet1), &mapRetIter) + sonic.UnmarshalString(string(jsonRet1), &mapRetIter) mapRet, err := JSONStr2Map(string(jsonRet1)) test.Assert(t, err == nil) From eed574035e1699e4ad6fa6a89ce6297492348471 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:27:09 +0800 Subject: [PATCH 64/70] chore(generic): make generic streaming APIs internal (#1522) --- client/genericclient/stream.go | 14 ++++++++------ pkg/generic/grpcjsonpb_test/generic_init.go | 4 +++- pkg/generic/grpcjsonpb_test/generic_test.go | 2 ++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/client/genericclient/stream.go b/client/genericclient/stream.go index c568e0b281..2544b3ccfb 100644 --- a/client/genericclient/stream.go +++ b/client/genericclient/stream.go @@ -14,6 +14,8 @@ * limitations under the License. */ +// NOTE: The basic generic streaming functions have been completed, but the interface needs adjustments. The feature is expected to be released later. + package genericclient import ( @@ -46,11 +48,11 @@ type BidirectionalStreaming interface { Recv() (resp interface{}, err error) } -func NewStreamingClient(destService string, g generic.Generic, opts ...client.Option) (Client, error) { - return NewStreamingClientWithServiceInfo(destService, g, StreamingServiceInfo(g), opts...) +func newStreamingClient(destService string, g generic.Generic, opts ...client.Option) (Client, error) { + return newStreamingClientWithServiceInfo(destService, g, StreamingServiceInfo(g), opts...) } -func NewStreamingClientWithServiceInfo(destService string, g generic.Generic, svcInfo *serviceinfo.ServiceInfo, opts ...client.Option) (Client, error) { +func newStreamingClientWithServiceInfo(destService string, g generic.Generic, svcInfo *serviceinfo.ServiceInfo, opts ...client.Option) (Client, error) { var options []client.Option options = append(options, client.WithGeneric(g)) options = append(options, client.WithDestService(destService)) @@ -105,7 +107,7 @@ type clientStreamingClient struct { methodInfo serviceinfo.MethodInfo } -func NewClientStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (ClientStreaming, error) { +func newClientStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (ClientStreaming, error) { gCli, ok := genericCli.(*genericServiceClient) if !ok { return nil, errors.New("invalid generic client") @@ -140,7 +142,7 @@ type serverStreamingClient struct { methodInfo serviceinfo.MethodInfo } -func NewServerStreaming(ctx context.Context, genericCli Client, method string, req interface{}, callOpts ...callopt.Option) (ServerStreaming, error) { +func newServerStreaming(ctx context.Context, genericCli Client, method string, req interface{}, callOpts ...callopt.Option) (ServerStreaming, error) { gCli, ok := genericCli.(*genericServiceClient) if !ok { return nil, errors.New("invalid generic client") @@ -176,7 +178,7 @@ type bidirectionalStreamingClient struct { methodInfo serviceinfo.MethodInfo } -func NewBidirectionalStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (BidirectionalStreaming, error) { +func newBidirectionalStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (BidirectionalStreaming, error) { gCli, ok := genericCli.(*genericServiceClient) if !ok { return nil, errors.New("invalid generic client") diff --git a/pkg/generic/grpcjsonpb_test/generic_init.go b/pkg/generic/grpcjsonpb_test/generic_init.go index 5c83f4d4bb..0d3672916f 100644 --- a/pkg/generic/grpcjsonpb_test/generic_init.go +++ b/pkg/generic/grpcjsonpb_test/generic_init.go @@ -16,6 +16,7 @@ package test +/* import ( "context" "fmt" @@ -34,7 +35,7 @@ import ( ) func newGenericClient(g generic.Generic, targetIPPort string) genericclient.Client { - cli, err := genericclient.NewStreamingClient("destService", g, + cli, err := genericclient.newStreamingClient("destService", g, client.WithTransportProtocol(transport.GRPC), client.WithHostPorts(targetIPPort), ) @@ -146,3 +147,4 @@ func (s *StreamingTestImpl) BidirectionalStreamingTest(stream mock.Mock_Bidirect wg.Wait() return } +*/ diff --git a/pkg/generic/grpcjsonpb_test/generic_test.go b/pkg/generic/grpcjsonpb_test/generic_test.go index e61750c781..6fa83ef430 100644 --- a/pkg/generic/grpcjsonpb_test/generic_test.go +++ b/pkg/generic/grpcjsonpb_test/generic_test.go @@ -16,6 +16,7 @@ package test +/* import ( "context" "fmt" @@ -161,3 +162,4 @@ func initMockTestServer(handler mock.Mock, address string) server.Server { addr, _ := net.ResolveTCPAddr("tcp", address) return newMockTestServer(handler, addr) } +*/ From 347bcf077ce2654e277584f41094d0da4b1d1eb0 Mon Sep 17 00:00:00 2001 From: Yi Duan Date: Tue, 3 Sep 2024 10:26:51 +0800 Subject: [PATCH 65/70] fix: move json-iterator back to support marshal `map[any]any` (#1525) --- go.mod | 2 ++ go.sum | 6 ++++++ internal/generic/thrift/http.go | 1 + internal/generic/thrift/json.go | 4 ++-- 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 209b67ce89..fbbb81c6c0 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 github.com/jhump/protoreflect v1.8.2 + github.com/json-iterator/go v1.1.12 github.com/tidwall/gjson v1.17.3 golang.org/x/net v0.24.0 golang.org/x/sync v0.8.0 @@ -38,6 +39,7 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/iancoleman/strcase v0.2.0 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 599818ac6a..dbf562b954 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,7 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -81,6 +82,8 @@ github.com/iancoleman/strcase v0.2.0 h1:05I4QRnGpI0m37iZQRuskXh+w77mr6Z41lwQzuHL github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/jhump/protoreflect v1.8.2 h1:k2xE7wcUomeqwY0LDCYA16y4WWfyTcMx5mKhk0d4ua0= github.com/jhump/protoreflect v1.8.2/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= @@ -91,6 +94,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5 h1:uiS4zKYKJVj5F3ID+5iylfKPsEQmBEOucSD9Vgmn0i0= github.com/modern-go/gls v0.0.0-20220109145502-612d0167dce5/go.mod h1:I8AX+yW//L8Hshx6+a1m3bYkwXkpsVjA2795vP4f4oQ= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= @@ -103,6 +108,7 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/internal/generic/thrift/http.go b/internal/generic/thrift/http.go index 4859231161..5bb0d9e8a2 100644 --- a/internal/generic/thrift/http.go +++ b/internal/generic/thrift/http.go @@ -55,6 +55,7 @@ var ( customJson = sonic.Config{ EscapeHTML: true, UseNumber: true, + CopyString: true, }.Froze() ) diff --git a/internal/generic/thrift/json.go b/internal/generic/thrift/json.go index d3052bc82c..2e177fce52 100644 --- a/internal/generic/thrift/json.go +++ b/internal/generic/thrift/json.go @@ -22,7 +22,6 @@ import ( "strconv" "github.com/bytedance/gopkg/lang/dirtmake" - "github.com/bytedance/sonic" "github.com/cloudwego/dynamicgo/conv" "github.com/cloudwego/dynamicgo/conv/t2j" dthrift "github.com/cloudwego/dynamicgo/thrift" @@ -34,6 +33,7 @@ import ( "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" "github.com/cloudwego/kitex/pkg/utils" + jsoniter "github.com/json-iterator/go" ) type JSONReaderWriter struct { @@ -245,7 +245,7 @@ func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient boo } // resp is map - respNode, err := sonic.Marshal(resp) + respNode, err := jsoniter.Marshal(resp) if err != nil { return nil, perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("response marshal failed. err:%#v", err)) } From 29a1a7909c9fc0af81b6ab5cd40a613d90b94ff4 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Tue, 3 Sep 2024 19:50:07 +0800 Subject: [PATCH 66/70] perf(grpc): bdp ping rate limit (#1527) --- pkg/remote/trans/nphttp2/grpc/bdp_estimator.go | 5 ++++- pkg/remote/trans/nphttp2/grpc/bdp_estimator_test.go | 12 ++++++++++-- pkg/remote/trans/nphttp2/grpc/http2_client.go | 4 +--- pkg/remote/trans/nphttp2/grpc/http2_server.go | 4 +--- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/pkg/remote/trans/nphttp2/grpc/bdp_estimator.go b/pkg/remote/trans/nphttp2/grpc/bdp_estimator.go index e4aa6fa20d..7710e9fc36 100644 --- a/pkg/remote/trans/nphttp2/grpc/bdp_estimator.go +++ b/pkg/remote/trans/nphttp2/grpc/bdp_estimator.go @@ -48,6 +48,9 @@ const ( // Easter-egg: what does the ping message say? var bdpPing = &ping{data: [8]byte{2, 4, 16, 16, 9, 14, 7, 7}} +// allow only one bdpPing per bdpPingInterval +var bdpPingInterval = time.Second + type bdpEstimator struct { // sentAt is the time when the ping was sent. sentAt time.Time @@ -94,7 +97,7 @@ func (b *bdpEstimator) add(n uint32) bool { if b.bdp == bdpLimit { return false } - if !b.isSent { + if !b.isSent && time.Since(b.sentAt) >= bdpPingInterval { b.isSent = true b.sample = n b.sentAt = time.Time{} diff --git a/pkg/remote/trans/nphttp2/grpc/bdp_estimator_test.go b/pkg/remote/trans/nphttp2/grpc/bdp_estimator_test.go index 8da2a017fa..9a1be7a457 100644 --- a/pkg/remote/trans/nphttp2/grpc/bdp_estimator_test.go +++ b/pkg/remote/trans/nphttp2/grpc/bdp_estimator_test.go @@ -18,11 +18,16 @@ package grpc import ( "testing" + "time" "github.com/cloudwego/kitex/internal/test" ) func TestBdp(t *testing.T) { + oribdpPingInterval := bdpPingInterval + defer func() { bdpPingInterval = oribdpPingInterval }() + bdpPingInterval = 10 * time.Millisecond // shorter time for testing + // init bdp estimator bdpEst := &bdpEstimator{ bdp: initialWindowSize, @@ -50,11 +55,14 @@ func TestBdp(t *testing.T) { // receive normal ping bdpEst.calculate([8]byte{0, 0, 0, 0, 0, 0, 0, 0}) + test.Assert(t, false == bdpEst.add(0)) // due to bdpPingInterval + time.Sleep(2 * bdpPingInterval) + size = 10000 // calculate 15 times for c := 0; c < 15; c++ { sent = bdpEst.add(uint32(size)) - test.Assert(t, sent) + test.Assert(t, sent, c) bdpEst.timesnap(bdpPing.data) // mock the situation that network delay is very long and data is very big @@ -65,6 +73,6 @@ func TestBdp(t *testing.T) { // receive bdp ack and calculate again bdpEst.calculate(bdpPing.data) - + bdpEst.sentAt = time.Time{} // reset for bdpPingInterval } } diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 7af2571fdf..58b2eaba01 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -191,9 +191,7 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, t.initialWindowSize = opts.InitialWindowSize dynamicWindow = false } - if false && dynamicWindow { - // we force disable dynamic window here coz it's sending too many ping frames... - // and it may not work as expected when running on top of netpoll. + if dynamicWindow { t.bdpEst = &bdpEstimator{ bdp: initialWindowSize, updateFlowControl: t.updateFlowControl, diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index 25af3c4ffc..c2b84efe13 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -230,9 +230,7 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ bufferPool: newBufferPool(), } t.controlBuf = newControlBuffer(t.done) - if false && dynamicWindow { - // we force disable dynamic window here coz it's sending too many ping frames... - // and it may not work as expected when running on top of netpoll. + if dynamicWindow { t.bdpEst = &bdpEstimator{ bdp: initialWindowSize, updateFlowControl: t.updateFlowControl, From 85f9c6b8b99117b166a37000ddd1d03bc39a6f17 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Wed, 4 Sep 2024 10:49:07 +0800 Subject: [PATCH 67/70] chore(ci): use blank identifier to fix ci check failure (#1528) --- .../genericclient/generic_stream_service.go | 2 +- client/genericclient/stream.go | 25 +++++++++++++------ internal/generic/thrift/json.go | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/client/genericclient/generic_stream_service.go b/client/genericclient/generic_stream_service.go index 7074207f2d..631509ec39 100644 --- a/client/genericclient/generic_stream_service.go +++ b/client/genericclient/generic_stream_service.go @@ -21,7 +21,7 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) -func StreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo { +func streamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo { return newClientStreamingServiceInfo(g) } diff --git a/client/genericclient/stream.go b/client/genericclient/stream.go index 2544b3ccfb..6cdd54730d 100644 --- a/client/genericclient/stream.go +++ b/client/genericclient/stream.go @@ -31,25 +31,36 @@ import ( "github.com/cloudwego/kitex/pkg/streaming" ) -type ClientStreaming interface { +// NOTE: this is a temporary adjustment for ci check. remove it after fully completing the generic streaming support +var ( + _ clientStreaming = nil + _ serverStreaming = nil + _ bidirectionalStreaming = nil + _ = newStreamingClient + _ = newClientStreaming + _ = newServerStreaming + _ = newBidirectionalStreaming +) + +type clientStreaming interface { streaming.Stream Send(req interface{}) error CloseAndRecv() (resp interface{}, err error) } -type ServerStreaming interface { +type serverStreaming interface { streaming.Stream Recv() (resp interface{}, err error) } -type BidirectionalStreaming interface { +type bidirectionalStreaming interface { streaming.Stream Send(req interface{}) error Recv() (resp interface{}, err error) } func newStreamingClient(destService string, g generic.Generic, opts ...client.Option) (Client, error) { - return newStreamingClientWithServiceInfo(destService, g, StreamingServiceInfo(g), opts...) + return newStreamingClientWithServiceInfo(destService, g, streamingServiceInfo(g), opts...) } func newStreamingClientWithServiceInfo(destService string, g generic.Generic, svcInfo *serviceinfo.ServiceInfo, opts ...client.Option) (Client, error) { @@ -107,7 +118,7 @@ type clientStreamingClient struct { methodInfo serviceinfo.MethodInfo } -func newClientStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (ClientStreaming, error) { +func newClientStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (clientStreaming, error) { gCli, ok := genericCli.(*genericServiceClient) if !ok { return nil, errors.New("invalid generic client") @@ -142,7 +153,7 @@ type serverStreamingClient struct { methodInfo serviceinfo.MethodInfo } -func newServerStreaming(ctx context.Context, genericCli Client, method string, req interface{}, callOpts ...callopt.Option) (ServerStreaming, error) { +func newServerStreaming(ctx context.Context, genericCli Client, method string, req interface{}, callOpts ...callopt.Option) (serverStreaming, error) { gCli, ok := genericCli.(*genericServiceClient) if !ok { return nil, errors.New("invalid generic client") @@ -178,7 +189,7 @@ type bidirectionalStreamingClient struct { methodInfo serviceinfo.MethodInfo } -func newBidirectionalStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (BidirectionalStreaming, error) { +func newBidirectionalStreaming(ctx context.Context, genericCli Client, method string, callOpts ...callopt.Option) (bidirectionalStreaming, error) { gCli, ok := genericCli.(*genericServiceClient) if !ok { return nil, errors.New("invalid generic client") diff --git a/internal/generic/thrift/json.go b/internal/generic/thrift/json.go index 2e177fce52..4ba24e3896 100644 --- a/internal/generic/thrift/json.go +++ b/internal/generic/thrift/json.go @@ -28,12 +28,12 @@ import ( "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/thrift/base" + jsoniter "github.com/json-iterator/go" "github.com/tidwall/gjson" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" "github.com/cloudwego/kitex/pkg/utils" - jsoniter "github.com/json-iterator/go" ) type JSONReaderWriter struct { From a3c62f28e3ad453c8b2c37238ad01893e581b347 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Thu, 5 Sep 2024 17:17:14 +0800 Subject: [PATCH 68/70] chore(generic): move generic APIs back to external package (revert) (#1531) --- internal/generic/proto/type.go | 4 +- pkg/generic/binarythrift_codec.go | 2 +- pkg/generic/generic_service.go | 4 +- pkg/generic/httppbthrift_codec.go | 4 +- pkg/generic/httppbthrift_codec_test.go | 2 +- pkg/generic/httpthrift_codec.go | 2 +- pkg/generic/httpthrift_codec_test.go | 2 +- pkg/generic/jsonpb_codec.go | 2 +- pkg/generic/jsonpb_codec_test.go | 2 +- pkg/generic/jsonthrift_codec.go | 2 +- pkg/generic/jsonthrift_codec_test.go | 2 +- pkg/generic/mapthrift_codec.go | 2 +- pkg/generic/mapthrift_codec_test.go | 2 +- pkg/generic/pb_descriptor_provider.go | 2 +- pkg/generic/pbidl_provider.go | 2 +- {internal => pkg}/generic/proto/json.go | 0 {internal => pkg}/generic/proto/json_test.go | 0 {internal => pkg}/generic/proto/protobuf.go | 0 pkg/generic/proto/type.go | 39 +++++++++++++++++++ {internal => pkg}/generic/thrift/binary.go | 0 {internal => pkg}/generic/thrift/http.go | 0 .../generic/thrift/http_fallback.go | 0 .../generic/thrift/http_go116plus_amd64.go | 0 {internal => pkg}/generic/thrift/http_pb.go | 2 +- {internal => pkg}/generic/thrift/json.go | 0 .../generic/thrift/json_fallback.go | 0 .../generic/thrift/json_go116plus_amd64.go | 0 pkg/generic/thrift/parse.go | 3 +- {internal => pkg}/generic/thrift/read.go | 0 {internal => pkg}/generic/thrift/read_test.go | 0 {internal => pkg}/generic/thrift/struct.go | 0 {internal => pkg}/generic/thrift/thrift.go | 0 {internal => pkg}/generic/thrift/util.go | 2 +- {internal => pkg}/generic/thrift/util_test.go | 8 ++-- {internal => pkg}/generic/thrift/write.go | 0 .../generic/thrift/write_test.go | 0 36 files changed, 65 insertions(+), 25 deletions(-) rename {internal => pkg}/generic/proto/json.go (100%) rename {internal => pkg}/generic/proto/json_test.go (100%) rename {internal => pkg}/generic/proto/protobuf.go (100%) create mode 100644 pkg/generic/proto/type.go rename {internal => pkg}/generic/thrift/binary.go (100%) rename {internal => pkg}/generic/thrift/http.go (100%) rename {internal => pkg}/generic/thrift/http_fallback.go (100%) rename {internal => pkg}/generic/thrift/http_go116plus_amd64.go (100%) rename {internal => pkg}/generic/thrift/http_pb.go (98%) rename {internal => pkg}/generic/thrift/json.go (100%) rename {internal => pkg}/generic/thrift/json_fallback.go (100%) rename {internal => pkg}/generic/thrift/json_go116plus_amd64.go (100%) rename {internal => pkg}/generic/thrift/read.go (100%) rename {internal => pkg}/generic/thrift/read_test.go (100%) rename {internal => pkg}/generic/thrift/struct.go (100%) rename {internal => pkg}/generic/thrift/thrift.go (100%) rename {internal => pkg}/generic/thrift/util.go (98%) rename {internal => pkg}/generic/thrift/util_test.go (95%) rename {internal => pkg}/generic/thrift/write.go (100%) rename {internal => pkg}/generic/thrift/write_test.go (100%) diff --git a/internal/generic/proto/type.go b/internal/generic/proto/type.go index 65af0b2cfa..9dfee958a0 100644 --- a/internal/generic/proto/type.go +++ b/internal/generic/proto/type.go @@ -1,5 +1,5 @@ /* - * Copyright 2021 CloudWeGo Authors + * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,12 +26,14 @@ type ( MessageDescriptor = *desc.MessageDescriptor ) +// TODO(marina.sakai): remove this type Message interface { Marshal() ([]byte, error) TryGetFieldByNumber(fieldNumber int) (interface{}, error) TrySetFieldByNumber(fieldNumber int, val interface{}) error } +// TODO(marina.sakai): modify this func NewMessage(descriptor MessageDescriptor) Message { return dynamic.NewMessage(descriptor) } diff --git a/pkg/generic/binarythrift_codec.go b/pkg/generic/binarythrift_codec.go index 5bb0828de8..8e26b5bf70 100644 --- a/pkg/generic/binarythrift_codec.go +++ b/pkg/generic/binarythrift_codec.go @@ -23,7 +23,7 @@ import ( "github.com/bytedance/gopkg/lang/dirtmake" - "github.com/cloudwego/kitex/internal/generic/thrift" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" diff --git a/pkg/generic/generic_service.go b/pkg/generic/generic_service.go index 758c0d518b..bc184f2612 100644 --- a/pkg/generic/generic_service.go +++ b/pkg/generic/generic_service.go @@ -23,8 +23,8 @@ import ( "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift/base" - "github.com/cloudwego/kitex/internal/generic/proto" - "github.com/cloudwego/kitex/internal/generic/thrift" + "github.com/cloudwego/kitex/pkg/generic/proto" + "github.com/cloudwego/kitex/pkg/generic/thrift" codecProto "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/httppbthrift_codec.go b/pkg/generic/httppbthrift_codec.go index 503f6bbcf6..ea59513f81 100644 --- a/pkg/generic/httppbthrift_codec.go +++ b/pkg/generic/httppbthrift_codec.go @@ -26,9 +26,9 @@ import ( "github.com/jhump/protoreflect/desc" - "github.com/cloudwego/kitex/internal/generic/proto" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" + "github.com/cloudwego/kitex/pkg/generic/proto" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/serviceinfo" diff --git a/pkg/generic/httppbthrift_codec_test.go b/pkg/generic/httppbthrift_codec_test.go index 63025854a5..74587826ea 100644 --- a/pkg/generic/httppbthrift_codec_test.go +++ b/pkg/generic/httppbthrift_codec_test.go @@ -24,8 +24,8 @@ import ( "reflect" "testing" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/httpthrift_codec.go b/pkg/generic/httpthrift_codec.go index 3486a41cf5..e142811145 100644 --- a/pkg/generic/httpthrift_codec.go +++ b/pkg/generic/httpthrift_codec.go @@ -25,8 +25,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/serviceinfo" diff --git a/pkg/generic/httpthrift_codec_test.go b/pkg/generic/httpthrift_codec_test.go index 9f69ba5b26..1852288de7 100644 --- a/pkg/generic/httpthrift_codec_test.go +++ b/pkg/generic/httpthrift_codec_test.go @@ -24,8 +24,8 @@ import ( "github.com/bytedance/sonic" "github.com/cloudwego/dynamicgo/conv" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/jsonpb_codec.go b/pkg/generic/jsonpb_codec.go index 9cbeb735fb..ac58e21ed1 100644 --- a/pkg/generic/jsonpb_codec.go +++ b/pkg/generic/jsonpb_codec.go @@ -25,7 +25,7 @@ import ( "github.com/cloudwego/dynamicgo/conv" dproto "github.com/cloudwego/dynamicgo/proto" - "github.com/cloudwego/kitex/internal/generic/proto" + "github.com/cloudwego/kitex/pkg/generic/proto" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" diff --git a/pkg/generic/jsonpb_codec_test.go b/pkg/generic/jsonpb_codec_test.go index 7778aca1cb..8913d3369b 100644 --- a/pkg/generic/jsonpb_codec_test.go +++ b/pkg/generic/jsonpb_codec_test.go @@ -23,8 +23,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" dproto "github.com/cloudwego/dynamicgo/proto" - gproto "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/internal/test" + gproto "github.com/cloudwego/kitex/pkg/generic/proto" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/jsonthrift_codec.go b/pkg/generic/jsonthrift_codec.go index 20d2b5e0f6..377d507ca1 100644 --- a/pkg/generic/jsonthrift_codec.go +++ b/pkg/generic/jsonthrift_codec.go @@ -23,8 +23,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" diff --git a/pkg/generic/jsonthrift_codec_test.go b/pkg/generic/jsonthrift_codec_test.go index 6d00d9a57a..82ff0bb1fe 100644 --- a/pkg/generic/jsonthrift_codec_test.go +++ b/pkg/generic/jsonthrift_codec_test.go @@ -21,8 +21,8 @@ import ( "github.com/cloudwego/dynamicgo/conv" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/mapthrift_codec.go b/pkg/generic/mapthrift_codec.go index 5284820c59..78e2a449e4 100644 --- a/pkg/generic/mapthrift_codec.go +++ b/pkg/generic/mapthrift_codec.go @@ -21,8 +21,8 @@ import ( "errors" "sync/atomic" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/serviceinfo" diff --git a/pkg/generic/mapthrift_codec_test.go b/pkg/generic/mapthrift_codec_test.go index 635f59963b..86630b842d 100644 --- a/pkg/generic/mapthrift_codec_test.go +++ b/pkg/generic/mapthrift_codec_test.go @@ -19,8 +19,8 @@ package generic import ( "testing" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/generic/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" ) diff --git a/pkg/generic/pb_descriptor_provider.go b/pkg/generic/pb_descriptor_provider.go index 27c7e29d22..2a2e115453 100644 --- a/pkg/generic/pb_descriptor_provider.go +++ b/pkg/generic/pb_descriptor_provider.go @@ -19,7 +19,7 @@ package generic import ( dproto "github.com/cloudwego/dynamicgo/proto" - "github.com/cloudwego/kitex/internal/generic/proto" + "github.com/cloudwego/kitex/pkg/generic/proto" ) // PbDescriptorProvider provide service descriptor diff --git a/pkg/generic/pbidl_provider.go b/pkg/generic/pbidl_provider.go index 587f4e9314..d115c51054 100644 --- a/pkg/generic/pbidl_provider.go +++ b/pkg/generic/pbidl_provider.go @@ -24,7 +24,7 @@ import ( dproto "github.com/cloudwego/dynamicgo/proto" "github.com/jhump/protoreflect/desc/protoparse" - "github.com/cloudwego/kitex/internal/generic/proto" + "github.com/cloudwego/kitex/pkg/generic/proto" ) type PbContentProvider struct { diff --git a/internal/generic/proto/json.go b/pkg/generic/proto/json.go similarity index 100% rename from internal/generic/proto/json.go rename to pkg/generic/proto/json.go diff --git a/internal/generic/proto/json_test.go b/pkg/generic/proto/json_test.go similarity index 100% rename from internal/generic/proto/json_test.go rename to pkg/generic/proto/json_test.go diff --git a/internal/generic/proto/protobuf.go b/pkg/generic/proto/protobuf.go similarity index 100% rename from internal/generic/proto/protobuf.go rename to pkg/generic/proto/protobuf.go diff --git a/pkg/generic/proto/type.go b/pkg/generic/proto/type.go new file mode 100644 index 0000000000..6237b2d206 --- /dev/null +++ b/pkg/generic/proto/type.go @@ -0,0 +1,39 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package proto + +import ( + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/dynamic" +) + +type ( + ServiceDescriptor = *desc.ServiceDescriptor + MessageDescriptor = *desc.MessageDescriptor +) + +// Deprecated: this interface will be removed in v0.12.0 +type Message interface { + Marshal() ([]byte, error) + TryGetFieldByNumber(fieldNumber int) (interface{}, error) + TrySetFieldByNumber(fieldNumber int, val interface{}) error +} + +// Deprecated: this API will be removed in v0.12.0 +func NewMessage(descriptor MessageDescriptor) Message { + return dynamic.NewMessage(descriptor) +} diff --git a/internal/generic/thrift/binary.go b/pkg/generic/thrift/binary.go similarity index 100% rename from internal/generic/thrift/binary.go rename to pkg/generic/thrift/binary.go diff --git a/internal/generic/thrift/http.go b/pkg/generic/thrift/http.go similarity index 100% rename from internal/generic/thrift/http.go rename to pkg/generic/thrift/http.go diff --git a/internal/generic/thrift/http_fallback.go b/pkg/generic/thrift/http_fallback.go similarity index 100% rename from internal/generic/thrift/http_fallback.go rename to pkg/generic/thrift/http_fallback.go diff --git a/internal/generic/thrift/http_go116plus_amd64.go b/pkg/generic/thrift/http_go116plus_amd64.go similarity index 100% rename from internal/generic/thrift/http_go116plus_amd64.go rename to pkg/generic/thrift/http_go116plus_amd64.go diff --git a/internal/generic/thrift/http_pb.go b/pkg/generic/thrift/http_pb.go similarity index 98% rename from internal/generic/thrift/http_pb.go rename to pkg/generic/thrift/http_pb.go index 869400b1de..2ec4ff7997 100644 --- a/internal/generic/thrift/http_pb.go +++ b/pkg/generic/thrift/http_pb.go @@ -27,8 +27,8 @@ import ( "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/dynamic" - "github.com/cloudwego/kitex/internal/generic/proto" "github.com/cloudwego/kitex/pkg/generic/descriptor" + "github.com/cloudwego/kitex/pkg/generic/proto" ) type HTTPPbReaderWriter struct { diff --git a/internal/generic/thrift/json.go b/pkg/generic/thrift/json.go similarity index 100% rename from internal/generic/thrift/json.go rename to pkg/generic/thrift/json.go diff --git a/internal/generic/thrift/json_fallback.go b/pkg/generic/thrift/json_fallback.go similarity index 100% rename from internal/generic/thrift/json_fallback.go rename to pkg/generic/thrift/json_fallback.go diff --git a/internal/generic/thrift/json_go116plus_amd64.go b/pkg/generic/thrift/json_go116plus_amd64.go similarity index 100% rename from internal/generic/thrift/json_go116plus_amd64.go rename to pkg/generic/thrift/json_go116plus_amd64.go diff --git a/pkg/generic/thrift/parse.go b/pkg/generic/thrift/parse.go index 5d3520ead2..7576a7953b 100644 --- a/pkg/generic/thrift/parse.go +++ b/pkg/generic/thrift/parse.go @@ -26,7 +26,6 @@ import ( "github.com/cloudwego/thriftgo/parser" "github.com/cloudwego/thriftgo/semantic" - "github.com/cloudwego/kitex/internal/generic/thrift" "github.com/cloudwego/kitex/pkg/generic/descriptor" "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" @@ -308,7 +307,7 @@ func parseType(t *parser.Type, tree *parser.Thrift, cache map[string]*descriptor if ty, ok := cache[t.Name]; ok { return ty, nil } - typePkg, typeName := thrift.SplitType(t.Name) + typePkg, typeName := splitType(t.Name) if typePkg != "" { ref, ok := tree.GetReference(typePkg) if !ok { diff --git a/internal/generic/thrift/read.go b/pkg/generic/thrift/read.go similarity index 100% rename from internal/generic/thrift/read.go rename to pkg/generic/thrift/read.go diff --git a/internal/generic/thrift/read_test.go b/pkg/generic/thrift/read_test.go similarity index 100% rename from internal/generic/thrift/read_test.go rename to pkg/generic/thrift/read_test.go diff --git a/internal/generic/thrift/struct.go b/pkg/generic/thrift/struct.go similarity index 100% rename from internal/generic/thrift/struct.go rename to pkg/generic/thrift/struct.go diff --git a/internal/generic/thrift/thrift.go b/pkg/generic/thrift/thrift.go similarity index 100% rename from internal/generic/thrift/thrift.go rename to pkg/generic/thrift/thrift.go diff --git a/internal/generic/thrift/util.go b/pkg/generic/thrift/util.go similarity index 98% rename from internal/generic/thrift/util.go rename to pkg/generic/thrift/util.go index 5a5b020ba2..163ad70766 100644 --- a/internal/generic/thrift/util.go +++ b/pkg/generic/thrift/util.go @@ -33,7 +33,7 @@ func assertType(expected, but descriptor.Type) error { return fmt.Errorf("need %s type, but got: %s", expected, but) } -func SplitType(t string) (pkg, name string) { +func splitType(t string) (pkg, name string) { idx := strings.LastIndex(t, ".") if idx == -1 { return "", t diff --git a/internal/generic/thrift/util_test.go b/pkg/generic/thrift/util_test.go similarity index 95% rename from internal/generic/thrift/util_test.go rename to pkg/generic/thrift/util_test.go index 313bbaf912..56ea0e594d 100644 --- a/internal/generic/thrift/util_test.go +++ b/pkg/generic/thrift/util_test.go @@ -26,19 +26,19 @@ import ( ) func TestSplitType(t *testing.T) { - pkg, name := SplitType(".A") + pkg, name := splitType(".A") test.Assert(t, pkg == "") test.Assert(t, name == "A") - pkg, name = SplitType("foo.bar.A") + pkg, name = splitType("foo.bar.A") test.Assert(t, pkg == "foo.bar") test.Assert(t, name == "A") - pkg, name = SplitType("A") + pkg, name = splitType("A") test.Assert(t, pkg == "") test.Assert(t, name == "A") - pkg, name = SplitType("") + pkg, name = splitType("") test.Assert(t, pkg == "") test.Assert(t, name == "") } diff --git a/internal/generic/thrift/write.go b/pkg/generic/thrift/write.go similarity index 100% rename from internal/generic/thrift/write.go rename to pkg/generic/thrift/write.go diff --git a/internal/generic/thrift/write_test.go b/pkg/generic/thrift/write_test.go similarity index 100% rename from internal/generic/thrift/write_test.go rename to pkg/generic/thrift/write_test.go From 6f2934b4c3b4139e3b6755452d74f9802275223c Mon Sep 17 00:00:00 2001 From: Li2CO3 Date: Thu, 5 Sep 2024 19:37:44 +0800 Subject: [PATCH 69/70] chore: update dependency (#1532) --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index fbbb81c6c0..d3ebd5ad50 100644 --- a/go.mod +++ b/go.mod @@ -10,11 +10,11 @@ require ( github.com/cloudwego/dynamicgo v0.4.0 github.com/cloudwego/fastpb v0.0.5 github.com/cloudwego/frugal v0.2.0 - github.com/cloudwego/gopkg v0.1.1-0.20240829032745-024f019d8487 + github.com/cloudwego/gopkg v0.1.1 github.com/cloudwego/localsession v0.0.2 - github.com/cloudwego/netpoll v0.6.4-0.20240823082441-5c544da5550d + github.com/cloudwego/netpoll v0.6.4 github.com/cloudwego/runtimex v0.1.0 - github.com/cloudwego/thriftgo v0.3.16-0.20240807060045-993835609a55 + github.com/cloudwego/thriftgo v0.3.17 github.com/golang/mock v1.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 github.com/jhump/protoreflect v1.8.2 diff --git a/go.sum b/go.sum index dbf562b954..26af1d5b78 100644 --- a/go.sum +++ b/go.sum @@ -24,18 +24,18 @@ github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZ github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= -github.com/cloudwego/gopkg v0.1.1-0.20240829032745-024f019d8487 h1:JmCA5LJYdhLY8/TfngV/DXBtu8IsLBuo0tu+dfN5iQk= -github.com/cloudwego/gopkg v0.1.1-0.20240829032745-024f019d8487/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= +github.com/cloudwego/gopkg v0.1.1 h1:UgmQ1BbiawhMoD8VjzJvwdc6Z3fuFcZR7XUibBKZ1k0= +github.com/cloudwego/gopkg v0.1.1/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= github.com/cloudwego/localsession v0.0.2/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= -github.com/cloudwego/netpoll v0.6.4-0.20240823082441-5c544da5550d h1:G4yCpvx0Ok2TMurfrlPs44tuvroaIjwdPFQH38OzaUQ= -github.com/cloudwego/netpoll v0.6.4-0.20240823082441-5c544da5550d/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= +github.com/cloudwego/netpoll v0.6.4 h1:z/dA4sOTUQof6zZIO4QNnLBXsDFFFEos9OOGloR6kno= +github.com/cloudwego/netpoll v0.6.4/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= github.com/cloudwego/runtimex v0.1.0/go.mod h1:23vL/HGV0W8nSCHbe084AgEBdDV4rvXenEUMnUNvUd8= -github.com/cloudwego/thriftgo v0.3.16-0.20240807060045-993835609a55 h1:9KZEXU56Al2yaDsGGkYbjcdHtnRenxJXHJNJSLQjLME= -github.com/cloudwego/thriftgo v0.3.16-0.20240807060045-993835609a55/go.mod h1:AdLEJJVGW/ZJYvkkYAZf5SaJH+pA3OyC801WSwqcBwI= +github.com/cloudwego/thriftgo v0.3.17 h1:k0iQe2jEAN1WhPsXWvatwHzoxObUSX2Nw5NqdnywS8k= +github.com/cloudwego/thriftgo v0.3.17/go.mod h1:AdLEJJVGW/ZJYvkkYAZf5SaJH+pA3OyC801WSwqcBwI= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From 1d26ba8392e828e51381655ff39b593f12f027f8 Mon Sep 17 00:00:00 2001 From: alice <90381261+alice-yyds@users.noreply.github.com> Date: Thu, 5 Sep 2024 19:47:31 +0800 Subject: [PATCH 70/70] chore: update version v0.11.0 --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index 644c709d45..96c9e79792 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package kitex // Name and Version info of this framework, used for statistics and debug const ( Name = "Kitex" - Version = "v0.10.3" + Version = "v0.11.0" )