Skip to content

Commit

Permalink
[bcs / type tag] Fix some bugs, add error handling, full testing (#28)
Browse files Browse the repository at this point in the history
* [bcs] Add deserialization errors

This adds errors when deserializing a type isn't possible due to too few bytes

* [bcs] Add full test coverage and bugfixes for BCS package

* [typetag] Add documentation, missing types, and cut down extra data

There were some holder pieces of data in the TypeTags that weren't used
since TypeTags are just a representation of a type.  Those were removed
and tests added to properly cover to 89% of lines in the TypeTags.

Additionally, some helper functions were added for option, string and
object types.
  • Loading branch information
gregnazario authored May 17, 2024
1 parent 1cb7125 commit 4862fec
Show file tree
Hide file tree
Showing 7 changed files with 449 additions and 128 deletions.
174 changes: 171 additions & 3 deletions bcs/bcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ func (st *TestStruct) UnmarshalBCS(bcs *Deserializer) {
st.b = bcs.Bool()
}

type TestStruct2 struct {
num uint8
b bool
}

func (st TestStruct2) MarshalBCS(bcs *Serializer) {
bcs.U8(st.num)
bcs.Bool(st.b)
}
func (st TestStruct2) UnmarshalBCS(bcs *Deserializer) {
st.num = bcs.U8()
st.b = bcs.Bool()
}

func Test_U8(t *testing.T) {
serialized := []string{"00", "01", "ff"}
deserialized := []uint8{0, 1, 0xff}
Expand Down Expand Up @@ -173,14 +187,168 @@ func Test_Struct(t *testing.T) {
// Deserializer
for i, input := range serialized {
bytes, _ := hex.DecodeString(input)
deserializer := &Deserializer{source: bytes}
deserializer := NewDeserializer(bytes)
st := TestStruct{}
deserializer.Struct(&st)
assert.Equal(t, deserialized[i], st)
assert.NoError(t, deserializer.Error())
}
}

func Test_DeserializeSequence(t *testing.T) {
deserialized := []TestStruct{{0, false}, {5, true}, {255, true}}
serialized := []byte{0x03, 0x00, 0x00, 0x05, 0x01, 0xFF, 0x01}

ser := &Serializer{}
SerializeSequence(deserialized, ser)
assert.NoError(t, ser.Error())
actualSerialized := ser.ToBytes()
assert.Equal(t, serialized, actualSerialized)

des := NewDeserializer(actualSerialized)
actualDeserialized := DeserializeSequence[TestStruct](des)
assert.NoError(t, ser.Error())
assert.Equal(t, deserialized, actualDeserialized)
}

func Test_InvalidBool(t *testing.T) {
des := NewDeserializer([]byte{0x02})
des.Bool()
assert.Error(t, des.Error())
}

func Test_InvalidBytes(t *testing.T) {
des := NewDeserializer([]byte{0x02})
des.ReadBytes()
assert.Error(t, des.Error())
}

func Test_InvalidFixedBytesInto(t *testing.T) {
des := NewDeserializer([]byte{0x02})
bytes := make([]byte, 2)
des.ReadFixedBytesInto(bytes)
assert.Error(t, des.Error())
}

func Test_DoubleSetError(t *testing.T) {
des := NewDeserializer([]byte{0x02})
des.setError("first error")
des.setError("second error")
assert.Equal(t, "first error", des.Error().Error())
}

func Test_SerializeSequence(t *testing.T) {
// Test not implementing Marshal
ser := Serializer{}
SerializeSequence([]byte{0x00}, &ser)
assert.Error(t, ser.Error())

// Test by reference
testStruct := TestStruct{
num: 22,
b: true,
}
data := []TestStruct{testStruct}
ser = Serializer{}
SerializeSequence(data, &ser)
assert.NoError(t, ser.Error())
assert.True(t, len(ser.ToBytes()) != 0)

// Test reset
ser.Reset()
assert.True(t, len(ser.ToBytes()) == 0)

// Test by value
testStruct2 := TestStruct2{
num: 52,
b: false,
}
data2 := []TestStruct2{testStruct2}
SerializeSequence(data2, &ser)
assert.NoError(t, ser.Error())
}

func Test_DeserializeSequenceError(t *testing.T) {
// Test no leading size byte
des := NewDeserializer([]byte{})
DeserializeSequence[TestStruct](des)
assert.Error(t, des.Error())

// Test no bytes for struct
des = NewDeserializer([]byte{0x01})
DeserializeSequence[TestStruct](des)
assert.Error(t, des.Error())

// Test not a struct type to deserialize
des = NewDeserializer([]byte{0x01})
DeserializeSequence[uint8](des)
assert.Error(t, des.Error())
}

func Test_DeserializerErrors(t *testing.T) {
serialized, _ := hex.DecodeString("000100FF")
des := NewDeserializer(serialized)
assert.Equal(t, 4, des.Remaining())
assert.Equal(t, uint8(0), des.U8())
assert.Equal(t, 3, des.Remaining())
assert.Equal(t, uint16(1), des.U16())
assert.Equal(t, 1, des.Remaining())
des.U16()
assert.Error(t, des.Error())
des.SetError(nil)
assert.Equal(t, uint8(0xff), des.U8())
assert.NoError(t, des.Error())

des.Bool()
assert.Error(t, des.Error())
des.SetError(nil)
des.ReadFixedBytes(2)
assert.Error(t, des.Error())
des.SetError(nil)
des.U16()
assert.Error(t, des.Error())
des.SetError(nil)
des.U32()
assert.Error(t, des.Error())
des.SetError(nil)
des.U64()
assert.Error(t, des.Error())
des.SetError(nil)
des.U128()
assert.Error(t, des.Error())
des.SetError(nil)
des.U256()
assert.Error(t, des.Error())
des.SetError(nil)
des.U256()
assert.Error(t, des.Error())
des.SetError(nil)
des.Uleb128()
assert.Error(t, des.Error())
des.SetError(nil)
des.ReadBytes()
assert.Error(t, des.Error())
des.SetError(nil)
des.U8()
assert.Error(t, des.Error())
}

func Test_ConvenienceFunctions(t *testing.T) {
str := TestStruct{
num: 10,
b: true,
}

bytes, err := Serialize(&str)
assert.NoError(t, err)

str2 := TestStruct{}
err = Deserialize(&str2, bytes)
assert.NoError(t, err)

assert.Equal(t, str, str2)
}

func helper[TYPE uint8 | uint16 | uint32 | uint64 | bool | []byte | string](t *testing.T, serialized []string, deserialized []TYPE, serialize func(serializer *Serializer, val TYPE), deserialize func(deserializer *Deserializer) TYPE) {

// Serializer
Expand All @@ -195,7 +363,7 @@ func helper[TYPE uint8 | uint16 | uint32 | uint64 | bool | []byte | string](t *t
// Deserializer
for i, input := range serialized {
bytes, _ := hex.DecodeString(input)
deserializer := &Deserializer{source: bytes}
deserializer := NewDeserializer(bytes)
assert.Equal(t, deserialized[i], deserialize(deserializer))
assert.NoError(t, deserializer.Error())
}
Expand All @@ -216,7 +384,7 @@ func helperBigInt(t *testing.T, serialized []string, deserialized []*big.Int, se
// Deserializer
for i, input := range serialized {
bytes, _ := hex.DecodeString(input)
deserializer := &Deserializer{source: bytes}
deserializer := NewDeserializer(bytes)
actual := deserialize(deserializer)
assert.NoError(t, deserializer.Error())
assert.Equal(t, 0, deserialized[i].Cmp(&actual))
Expand Down
55 changes: 50 additions & 5 deletions bcs/deserializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,47 +52,72 @@ func (des *Deserializer) Remaining() int {
// Bool deserializes a single byte as a bool
func (des *Deserializer) Bool() bool {
out := false
switch des.source[des.pos] {
if des.pos >= len(des.source) {
des.setError("not enough bytes remaining to deserialize bool")
return out
}

switch des.U8() {
case 0:
out = false
case 1:
out = true
default:
des.setError("bad bool at [%des]: %x", des.pos, des.source[des.pos])
des.setError("bad bool at [%des]: %x", des.pos-1, des.source[des.pos-1])
}
return out
}

// U8 deserializes a single unsigned 8-bit integer
func (des *Deserializer) U8() uint8 {
if des.pos >= len(des.source) {
des.setError("not enough bytes remaining to deserialize u8")
return 0
}
out := des.source[des.pos]
des.pos++
return out
}

// U16 deserializes a single unsigned 16-bit integer
func (des *Deserializer) U16() uint16 {
if des.pos+1 >= len(des.source) {
des.setError("not enough bytes remaining to deserialize u16")
return 0
}
out := binary.LittleEndian.Uint16(des.source[des.pos : des.pos+2])
des.pos += 2
return out
}

// U32 deserializes a single unsigned 32-bit integer
func (des *Deserializer) U32() uint32 {
if des.pos+3 >= len(des.source) {
des.setError("not enough bytes remaining to deserialize u32")
return 0
}
out := binary.LittleEndian.Uint32(des.source[des.pos : des.pos+4])
des.pos += 4
return out
}

// U64 deserializes a single unsigned 64-bit integer
func (des *Deserializer) U64() uint64 {
if des.pos+7 >= len(des.source) {
des.setError("not enough bytes remaining to deserialize u64")
return 0
}
out := binary.LittleEndian.Uint64(des.source[des.pos : des.pos+8])
des.pos += 8
return out
}

// U128 deserializes a single unsigned 128-bit integer
func (des *Deserializer) U128() big.Int {
if des.pos+15 >= len(des.source) {
des.setError("not enough bytes remaining to deserialize u128")
return *big.NewInt(-1)
}
var bytesBigEndian [16]byte
copy(bytesBigEndian[:], des.source[des.pos:des.pos+16])
des.pos += 16
Expand All @@ -104,6 +129,10 @@ func (des *Deserializer) U128() big.Int {

// U256 deserializes a single unsigned 256-bit integer
func (des *Deserializer) U256() big.Int {
if des.pos+31 >= len(des.source) {
des.setError("not enough bytes remaining to deserialize u256")
return *big.NewInt(-1)
}
var bytesBigEndian [32]byte
copy(bytesBigEndian[:], des.source[des.pos:des.pos+32])
des.pos += 32
Expand All @@ -119,6 +148,11 @@ func (des *Deserializer) Uleb128() uint32 {
shift := 0

for {
if des.pos >= len(des.source) {
des.setError("not enough bytes remaining to deserialize uleb128")
return 0
}

val := des.source[des.pos]
out = out | (uint32(val&0x7f) << shift)
des.pos++
Expand All @@ -138,6 +172,10 @@ func (des *Deserializer) ReadBytes() []byte {
if des.err != nil {
return nil
}
if des.pos+int(length)-1 >= len(des.source) {
des.setError("not enough bytes remaining to deserialize bytes")
return nil
}
out := make([]byte, length)
copy(out, des.source[des.pos:des.pos+int(length)])
des.pos += int(length)
Expand All @@ -151,15 +189,22 @@ func (des *Deserializer) ReadString() string {

// ReadFixedBytes reads bytes not-prefixed with a length
func (des *Deserializer) ReadFixedBytes(length int) []byte {
if des.pos+length-1 >= len(des.source) {
des.setError("not enough bytes remaining to deserialize fixedBytes")
return nil
}
out := make([]byte, length)
copy(out, des.source[des.pos:des.pos+length])
des.pos += length
des.ReadFixedBytesInto(out)
return out
}

// ReadFixedBytesInto reads bytes not-prefixed with a length into a byte array
func (des *Deserializer) ReadFixedBytesInto(dest []byte) {
length := len(dest)
if des.pos+length-1 >= len(des.source) {
des.setError("not enough bytes remaining to deserialize fixedBytes")
return
}
copy(dest, des.source[des.pos:des.pos+length])
des.pos += length
}
Expand All @@ -182,7 +227,7 @@ func DeserializeSequence[T any](des *Deserializer) []T {
if ok {
mv.UnmarshalBCS(des)
} else {
des.SetError(fmt.Errorf("could not deserialize sequence[%d] member of %T", i, v))
des.setError("could not deserialize sequence[%d] member of %T", i, v)
return nil
}
}
Expand Down
4 changes: 2 additions & 2 deletions bcs/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ func (ser *Serializer) SetError(err error) {
// Bool serialize a bool into a single byte
func (ser *Serializer) Bool(v bool) {
if v {
ser.out.WriteByte(1)
ser.U8(1)
} else {
ser.out.WriteByte(0)
ser.U8(0)
}
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/goclient/goclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func main() {
},
Function: "transfer",
// ArgTypes: []aptos.TypeTag{
// aptos.TypeTag{Value: &aptos.AccountAddressTag{Value: dest}},
// aptos.TypeTag{Value: &aptos.AddressTag{Value: dest}},
// aptos.TypeTag{Value: &aptos.U64Tag{Value: amount}},
// },
ArgTypes: []aptos.TypeTag{},
Expand Down
2 changes: 1 addition & 1 deletion nodeClient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ func TestPollForTransaction(t *testing.T) {
dt := time.Now().Sub(start)

assert.GreaterOrEqual(t, dt, 9*time.Millisecond)
assert.Less(t, dt, 15*time.Millisecond)
assert.Less(t, dt, 20*time.Millisecond)
assert.Error(t, err)
}
Loading

0 comments on commit 4862fec

Please sign in to comment.