Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(thrift): new pkg for deprecating apache #6

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions protocol/thrift/apache/apache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Package apache contains code for working with apache thrift indirectly
//
// It acts as a bridge between generated code which relies on apache codec like:
//
// Write(p thrift.TProtocol) error
// Read(p thrift.TProtocol) error
//
// and kitex ecosystem.
//
// Because we're deprecating apache thrift, all kitex ecosystem code will not rely on apache thrift
// except one pkg: `github.com/cloudwego/kitex/pkg/protocol/bthrift`. Why is the package chosen?
// All legacy generated code relies on it, and we may not be able to update the code in a brief timeframe.
// So the package is chosen to register `thrift.NewTBinaryProtocol` to this package in order to use it
// without importing `github.com/apache/thrift`
//
// ThriftRead or ThriftWrite is implemented for calling Read/Write
// without knowing the interface of `thrift.TProtocol`.
// Since we already have `thrift.NewTBinaryProtocol`, we only need to check:
// if the return value of `thrift.NewTBinaryProtocol` implements
// the input which is `thrift.TProtocol` of Read/Write
//
// For new generated code,
// it no longer uses the `github.com/cloudwego/kitex/pkg/protocol/bthrift`
package apache

import (
"errors"
"fmt"
"reflect"
)

var (
newTBinaryProtocol reflect.Value

rvTrue = reflect.ValueOf(true) // for calling NewTBinaryProtocol
)

var (
ttransportType = reflect.TypeOf((*TTransport)(nil)).Elem()
errorType = reflect.TypeOf((*error)(nil)).Elem()
)

var (
errNoNewTBinaryProtocol = errors.New("thrift.NewTBinaryProtocol method not registered. Make sure you're using apache/thrift == 0.13.0 and clouwdwego/kitex >= 0.11.0")
errNotPointer = errors.New("input not pointer")
errNoReadMethod = errors.New("thrift.TStruct `Read` method not found")
errNoWriteMethod = errors.New("thrift.TStruct `Write` method not found")

errMethodType = errors.New("method type not match")
errNewFuncType = errors.New("function type not match")
)

func errNewFuncTypeNotMatch(t reflect.Type) error {
const expect = "func(thrift.TTransport, bool, bool) *thrift.TBinaryProtocol"
return fmt.Errorf("%w:\n\texpect: %s\n\t got: %s", errNewFuncType, expect, t)
}

func errReadWriteMethodNotMatch(t reflect.Type) error {
const expect = "func(thrift.TProtocol) error"
return fmt.Errorf("%w:\n\texpect: %s\n\t got: %s", errMethodType, expect, t)
}

// RegisterNewTBinaryProtocol accepts `thrift.NewTBinaryProtocol` func and save it for later use.
func RegisterNewTBinaryProtocol(fn interface{}) error {
v := reflect.ValueOf(fn)
t := v.Type()

// check it's func
if t.Kind() != reflect.Func {
return errNewFuncTypeNotMatch(t)
}

// check "func(thrift.TTransport, bool, bool) *thrift.TBinaryProtocol"
// can also check with t.String() instead of field by field?
if t.NumIn() != 3 ||
!t.In(0).Implements(ttransportType) ||
t.In(1).Kind() != reflect.Bool ||
t.In(2).Kind() != reflect.Bool {
return errNewFuncTypeNotMatch(t)
}
if t.NumOut() != 1 {
// not checking if it's thrift.TProtocol
// but in ThriftRead/ThriftWrite, we will check if it implements the input of Read/Write
// so we can make it easier to test.
return errNewFuncTypeNotMatch(t)
}
newTBinaryProtocol = v
return nil
}

func checkThriftReadWriteFuncType(t reflect.Type) error {
if !newTBinaryProtocol.IsValid() {
return errNoNewTBinaryProtocol
}

// checks `func(thrift.TProtocol) error`
if t.NumIn() != 1 || t.In(0).Kind() != reflect.Interface ||
!newTBinaryProtocol.Type().Out(0).Implements(t.In(0)) {
return errReadWriteMethodNotMatch(t)
}
if t.NumOut() != 1 ||
!t.Out(0).Implements(errorType) {
return errReadWriteMethodNotMatch(t)
}
return nil
}

// ThriftRead calls Read method of v.
//
// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol`
// before using this func.
func ThriftRead(t TTransport, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
// Read/Write method is always pointer receiver
return errNotPointer
}
rfunc := rv.MethodByName("Read")

// check Read func signature: func(thrift.TProtocol) error
if !rfunc.IsValid() || rfunc.Kind() != reflect.Func {
return errNoReadMethod
}
if err := checkThriftReadWriteFuncType(rfunc.Type()); err != nil {
return err
}

// iprot := NewTBinaryProtocol(t, true, true)
iprot := newTBinaryProtocol.Call([]reflect.Value{reflect.ValueOf(t), rvTrue, rvTrue})[0]

// err := v.Read(iprot)
err := rfunc.Call([]reflect.Value{iprot})[0]
if err.IsNil() {
return nil
}
return err.Interface().(error)
}

// ThriftWrite calls Write method of v.
//
// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol`
// before using this func.
func ThriftWrite(t TTransport, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
// Read/Write method is always pointer receiver
return errNotPointer
}
wfunc := rv.MethodByName("Write")

// check Write func signature: func(thrift.TProtocol) error
if !wfunc.IsValid() || wfunc.Kind() != reflect.Func {
return errNoWriteMethod
}
if err := checkThriftReadWriteFuncType(wfunc.Type()); err != nil {
return err
}

// oprot := NewTBinaryProtocol(t, true, true)
oprot := newTBinaryProtocol.Call([]reflect.Value{reflect.ValueOf(t), rvTrue, rvTrue})[0]

// err := v.Write(oprot)
err := wfunc.Call([]reflect.Value{oprot})[0]
if err.IsNil() {
return nil
}
return err.Interface().(error)
}
142 changes: 142 additions & 0 deletions protocol/thrift/apache/apache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package apache

import (
"bytes"
"encoding/json"
"errors"
"io"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRegisterNewTBinaryProtocol(t *testing.T) {
{ // case: not func type
fn := 1
err := RegisterNewTBinaryProtocol(fn)
t.Log(err)
assert.ErrorIs(t, err, errNewFuncType)
}

{ // case: args err
fn := func(_ TTransport, _ bool, _ int) {}
err := RegisterNewTBinaryProtocol(fn)
t.Log(err)
assert.ErrorIs(t, err, errNewFuncType)
}

{ // case: ret err
fn := func(_ TTransport, _, _ bool) {}
err := RegisterNewTBinaryProtocol(fn)
t.Log(err)
assert.ErrorIs(t, err, errNewFuncType)
}

{ // case: no err
fn := func(_ TTransport, _, _ bool) error { return nil }
err := RegisterNewTBinaryProtocol(fn)
assert.NoError(t, err)
assert.True(t, newTBinaryProtocol.IsValid())
newTBinaryProtocol = reflect.Value{} // reset
}
}

type TestingWriteRead struct {
Msg string

mockErr error
}

func (t *TestingWriteRead) Read(r io.Reader) error {
if t.mockErr != nil {
return t.mockErr
}
return json.NewDecoder(r).Decode(t)
}

func (t *TestingWriteRead) Write(w io.Writer) error {
if t.mockErr != nil {
return t.mockErr
}
return json.NewEncoder(w).Encode(t)
}

func TestThriftWriteRead(t *testing.T) {
called := 0
fn := func(trans TTransport, b0, b1 bool) *bytes.Buffer {
assert.True(t, b0)
assert.True(t, b1)
called++
return trans.(BufferTransport).Buffer
}
err := RegisterNewTBinaryProtocol(fn)
require.NoError(t, err)
defer func() { newTBinaryProtocol = reflect.Value{} }()

buf := &bytes.Buffer{}
p0 := &TestingWriteRead{Msg: "hello"}
err = ThriftWrite(BufferTransport{buf}, p0) // calls p0.Write
require.NoError(t, err)
require.Equal(t, 1, called)

p1 := &TestingWriteRead{}
err = ThriftRead(BufferTransport{buf}, p1) // calls p1.Read
require.NoError(t, err)
require.Equal(t, 2, called)
require.Equal(t, p0, p1)
}

type TestingWriteReadMethodNotMatch struct{}

func (p *TestingWriteReadMethodNotMatch) Read(v bool) error { return nil }
func (p *TestingWriteReadMethodNotMatch) Write(v bool) error { return nil }

func TestThriftWriteReadErr(t *testing.T) {
var err error

// errNotPointer
p := TestingWriteRead{Msg: "hello"}
err = ThriftWrite(BufferTransport{nil}, p)
assert.Same(t, err, errNotPointer)
err = ThriftRead(BufferTransport{nil}, p)
assert.Same(t, err, errNotPointer)

// errNoNewTBinaryProtocol
err = ThriftWrite(BufferTransport{nil}, &p)
assert.Same(t, err, errNoNewTBinaryProtocol)

// Read/Write returns err
fn := func(trans TTransport, b0, b1 bool) *bytes.Buffer { return nil }
RegisterNewTBinaryProtocol(fn)
defer func() { newTBinaryProtocol = reflect.Value{} }()
p.mockErr = errors.New("mock")
err = ThriftWrite(BufferTransport{nil}, &p)
assert.Same(t, err, p.mockErr)
err = ThriftRead(BufferTransport{nil}, &p)
assert.Same(t, err, p.mockErr)

// errMethodType
p1 := TestingWriteReadMethodNotMatch{}
err = ThriftWrite(BufferTransport{nil}, &p1)
assert.ErrorIs(t, err, errMethodType)
err = ThriftRead(BufferTransport{nil}, &p1)
assert.ErrorIs(t, err, errMethodType)
}
45 changes: 45 additions & 0 deletions protocol/thrift/apache/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package apache

import (
"bytes"
"context"
"io"
)

// TTransport is identical with thrift.TTransport.
type TTransport interface {
io.ReadWriteCloser
RemainingBytes() (num_bytes uint64)
Flush(ctx context.Context) (err error)
Open() error
IsOpen() bool
}

// BufferTransport extends bytes.Buffer to support TTransport
type BufferTransport struct {
*bytes.Buffer
}

func (p BufferTransport) IsOpen() bool { return true }
func (p BufferTransport) Open() error { return nil }
func (p BufferTransport) Close() error { p.Reset(); return nil }
func (p BufferTransport) Flush(_ context.Context) error { return nil }
func (p BufferTransport) RemainingBytes() uint64 { return uint64(p.Len()) }

var _ TTransport = BufferTransport{nil}
Loading
Loading