Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(thrift): skipdecoder #5

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions protocol/thrift/binaryreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@ import (
"sync"
)

type nextIface interface {
Next(n int) ([]byte, error)
}

type discardIface interface {
Discard(n int) (int, error)
}

// BinaryReader represents a reader for binary protocol
type BinaryReader struct {
r nextIface
Expand All @@ -53,8 +45,7 @@ func NewBinaryReader(r io.Reader) *BinaryReader {
if nextr, ok := r.(nextIface); ok {
ret.r = nextr
} else {
nextr := poolNextReader.Get().(*nextReader)
nextr.Reset(r)
nextr := newNextReader(r)
ret.r = nextr
ret.d = nextr
}
Expand All @@ -65,7 +56,7 @@ func NewBinaryReader(r io.Reader) *BinaryReader {
func (r *BinaryReader) Release() {
nextr, ok := r.r.(*nextReader)
if ok {
poolNextReader.Put(nextr)
nextr.Release()
}
r.reset()
poolBinaryReader.Put(r)
Expand Down
165 changes: 165 additions & 0 deletions protocol/thrift/skipdecoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"encoding/binary"
"fmt"
"io"
"sync"
)

var poolSkipDecoder = sync.Pool{
New: func() interface{} {
return &SkipDecoder{}
},
}

// SkipDecoder scans the underlying io.Reader and returns the bytes of a type
type SkipDecoder struct {
p skipReaderIface
}

// NewSkipDecoder ... call Release if no longer use
func NewSkipDecoder(r io.Reader) *SkipDecoder {
p := poolSkipDecoder.Get().(*SkipDecoder)
p.Reset(r)
return p
}

// Reset ...
func (p *SkipDecoder) Reset(r io.Reader) {
// fast path without returning to pool if remote.ByteBuffer && *skipByteBuffer
if buf, ok := r.(remoteByteBuffer); ok {
if p.p != nil {
r, ok := p.p.(*skipByteBuffer)
if ok {
r.Reset(buf)
return
}
p.p.Release()
}
p.p = newSkipByteBuffer(buf)
return
}

// not remote.ByteBuffer

if p.p != nil {
p.p.Release()
}
p.p = newSkipReader(r)
}

// Release ...
func (p *SkipDecoder) Release() {
p.p.Release()
p.p = nil
poolSkipDecoder.Put(p)
}

// Next skips a specific type and returns its bytes
func (p *SkipDecoder) Next(t TType) (buf []byte, err error) {
if err := p.skip(t, defaultRecursionDepth); err != nil {
return nil, err
}
return p.p.Bytes()
}

func (p *SkipDecoder) skip(t TType, maxdepth int) error {
if maxdepth == 0 {
return errDepthLimitExceeded
}
if sz := typeToSize[t]; sz > 0 {
_, err := p.p.Next(int(sz))
return err
}
switch t {
case STRING:
b, err := p.p.Next(4)
if err != nil {
return err
}
sz := int(binary.BigEndian.Uint32(b))
if sz < 0 {
return errNegativeSize
}
if _, err := p.p.Next(sz); err != nil {
return err
}
case STRUCT:
for {
b, err := p.p.Next(1) // TType
if err != nil {
return err
}
tp := TType(b[0])
if tp == STOP {
break
}
if _, err := p.p.Next(2); err != nil { // Field ID
return err
}
if err := p.skip(tp, maxdepth-1); err != nil {
return err
}
}
case MAP:
b, err := p.p.Next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len
if err != nil {
return err
}
kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:]))
if sz < 0 {
return errNegativeSize
}
ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt])
if ksz > 0 && vsz > 0 {
_, err := p.p.Next(int(sz) * (ksz + vsz))
return err
}
for i := int32(0); i < sz; i++ {
if err := p.skip(kt, maxdepth-1); err != nil {
return err
}
if err := p.skip(vt, maxdepth-1); err != nil {
return err
}
}
case SET, LIST:
b, err := p.p.Next(5) // 1 byte value type, 4 bytes Len
if err != nil {
return err
}
vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:]))
if sz < 0 {
return errNegativeSize
}
if vsz := typeToSize[vt]; vsz > 0 {
_, err := p.p.Next(int(sz) * int(vsz))
return err
}
for i := int32(0); i < sz; i++ {
if err := p.skip(vt, maxdepth-1); err != nil {
return err
}
}
default:
return NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t))
}
return nil
}
175 changes: 175 additions & 0 deletions protocol/thrift/skipdecoder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"bytes"
"math/rand"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestSkipDecoder(t *testing.T) {
x := BinaryProtocol{}
// byte
b := x.AppendByte([]byte(nil), 1)
sz0 := len(b)

// string
b = x.AppendString(b, strings.Repeat("hello", 500)) // larger than buffer
sz1 := len(b)

// list<i32>
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
sz2 := len(b)

// list<string>
b = x.AppendListBegin(b, STRING, 1)
b = x.AppendString(b, "hello")
sz3 := len(b)

// list<list<i32>>
b = x.AppendListBegin(b, LIST, 1)
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
sz4 := len(b)

// map<i32, i64>
b = x.AppendMapBegin(b, I32, I64, 1)
b = x.AppendI32(b, 1)
b = x.AppendI64(b, 2)
sz5 := len(b)

// map<i32, string>
b = x.AppendMapBegin(b, I32, STRING, 1)
b = x.AppendI32(b, 1)
b = x.AppendString(b, "hello")
sz6 := len(b)

// map<string, i64>
b = x.AppendMapBegin(b, STRING, I64, 1)
b = x.AppendString(b, "hello")
b = x.AppendI64(b, 2)
sz7 := len(b)

// map<i32, list<i32>>
b = x.AppendMapBegin(b, I32, LIST, 1)
b = x.AppendI32(b, 1)
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
sz8 := len(b)

// map<list<i32>, i32>
b = x.AppendMapBegin(b, LIST, I32, 1)
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
b = x.AppendI32(b, 1)
sz9 := len(b)

// struct i32, list<i32>
b = x.AppendFieldBegin(b, I32, 1)
b = x.AppendI32(b, 1)
b = x.AppendFieldBegin(b, LIST, 1)
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
b = x.AppendFieldStop(b)
sz10 := len(b)

r := NewSkipDecoder(bytes.NewReader(b))
defer r.Release()

readn := 0
b, err := r.Next(BYTE) // byte
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz0, readn)
b, err = r.Next(STRING) // string
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz1, readn)
b, err = r.Next(LIST) // list<i32>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz2, readn)
b, err = r.Next(LIST) // list<string>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz3, readn)
b, err = r.Next(LIST) // list<list<i32>>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz4, readn)
b, err = r.Next(MAP) // map<i32, i64>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz5, readn)
b, err = r.Next(MAP) // map<i32, string>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz6, readn)
b, err = r.Next(MAP) // map<string, i64>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz7, readn)
b, err = r.Next(MAP) // map<i32, list<i32>>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz8, readn)
b, err = r.Next(MAP) // map<list<i32>, i32>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz9, readn)
b, err = r.Next(STRUCT) // struct i32, list<i32>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz10, readn)

{ // other cases
// errDepthLimitExceeded
b = b[:0]
for i := 0; i < defaultRecursionDepth+1; i++ {
b = x.AppendFieldBegin(b, STRUCT, 1)
}
r := NewSkipDecoder(bytes.NewReader(b))
_, err := r.Next(STRUCT)
require.Same(t, errDepthLimitExceeded, err)

// unknown type
_, err = r.Next(TType(122))
require.Error(t, err)
}
}

func TestSkipDecoderReset(t *testing.T) {
x := BinaryProtocol{}
b := x.AppendString([]byte(nil), "hello")

r := NewSkipDecoder(nil)
for i := 0; i < 10; i++ {
if rand.Intn(2) == 1 { // random skipreader to test Reset
r.Reset(&remoteByteBufferImplForT{b: b})
} else {
r.Reset(bytes.NewReader(b))
}
retb, err := r.Next(STRING)
require.NoError(t, err)
require.Equal(t, b, retb)
}
}
Loading
Loading