Skip to content

Commit

Permalink
[x/programs] Improve overflow checks (#600)
Browse files Browse the repository at this point in the history
* runtime: uint64 -> int64

* ensure int32 is within bounds

* Update x/programs/runtime/dependencies.go

* sdk: add alloc overflow check
---------

Signed-off-by: Sam Liokumovich <[email protected]>
Co-authored-by: Sam Batschelet <[email protected]>
  • Loading branch information
samliok and hexfusion authored Nov 17, 2023
1 parent 67bc81b commit 283846b
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 52 deletions.
20 changes: 10 additions & 10 deletions x/programs/examples/counter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ func TestCounterProgram(t *testing.T) {
// create counter for alice on program 1
result, err := rt.Call(ctx, "initialize_address", programIDPtr, alicePtr)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

// validate counter at 0
result, err = rt.Call(ctx, "get_value", programIDPtr, alicePtr)
require.NoError(err)
require.Equal(uint64(0), result[0])
require.Equal(int64(0), result[0])

// initialize second runtime to create second counter program with an empty
// meter.
Expand Down Expand Up @@ -102,16 +102,16 @@ func TestCounterProgram(t *testing.T) {
// initialize counter for alice on runtime 2
result, err = rt2.Call(ctx, "initialize_address", programID2Ptr, alicePtr2)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

// increment alice's counter on program 2 by 10
result, err = rt2.Call(ctx, "inc", programID2Ptr, alicePtr2, 10)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

result, err = rt2.Call(ctx, "get_value", programID2Ptr, alicePtr2)
require.NoError(err)
require.Equal(uint64(10), result[0])
require.Equal(int64(10), result[0])

// stop the runtime to prevent further execution
rt2.Stop()
Expand All @@ -127,13 +127,13 @@ func TestCounterProgram(t *testing.T) {
// increment alice's counter on program 1
result, err = rt.Call(ctx, "inc", programIDPtr, alicePtr, 1)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

result, err = rt.Call(ctx, "get_value", programIDPtr, alicePtr)
require.NoError(err)

log.Debug("count program 1",
zap.Uint64("alice", result[0]),
zap.Int64("alice", result[0]),
)

// write program id 2 to stack of program 1
Expand All @@ -142,17 +142,17 @@ func TestCounterProgram(t *testing.T) {

caller := programIDPtr
target := programID2Ptr
maxUnitsProgramToProgram := uint64(10000)
maxUnitsProgramToProgram := int64(10000)

// increment alice's counter on program 2
result, err = rt.Call(ctx, "inc_external", caller, target, maxUnitsProgramToProgram, alicePtr, 5)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

// expect alice's counter on program 2 to be 15
result, err = rt.Call(ctx, "get_value_external", caller, target, maxUnitsProgramToProgram, alicePtr)
require.NoError(err)
require.Equal(uint64(15), result[0])
require.Equal(int64(15), result[0])

require.Greater(rt.Meter().GetBalance(), uint64(0))
}
6 changes: 3 additions & 3 deletions x/programs/examples/imports/program/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,16 @@ func (i *Import) callProgramFn(
return int64(res[0])
}

func getCallArgs(ctx context.Context, memory runtime.Memory, buffer []byte, invokeProgramID uint64) ([]uint64, error) {
func getCallArgs(ctx context.Context, memory runtime.Memory, buffer []byte, invokeProgramID int64) ([]int64, error) {
// first arg contains id of program to call
args := []uint64{invokeProgramID}
args := []int64{invokeProgramID}
p := codec.NewReader(buffer, len(buffer))
i := 0
for !p.Empty() {
size := p.UnpackInt64(true)
isInt := p.UnpackBool()
if isInt {
valueInt := p.UnpackUint64(true)
valueInt := p.UnpackInt64(true)
args = append(args, valueInt)
} else {
valueBytes := make([]byte, size)
Expand Down
28 changes: 14 additions & 14 deletions x/programs/examples/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ func (t *Token) Run(ctx context.Context) error {
}

t.log.Debug("init response",
zap.Uint64("init", resp[0]),
zap.Int64("init", resp[0]),
)

result, err := rt.Call(ctx, "get_total_supply", programIDPtr)
if err != nil {
return err
}
t.log.Debug("total supply",
zap.Uint64("minted", result[0]),
zap.Int64("minted", result[0]),
)

// generate alice keys
Expand Down Expand Up @@ -116,13 +116,13 @@ func (t *Token) Run(ctx context.Context) error {
)

// mint 100 tokens to alice
mintAlice := uint64(1000)
mintAlice := int64(1000)
_, err = rt.Call(ctx, "mint_to", programIDPtr, alicePtr, mintAlice)
if err != nil {
return err
}
t.log.Debug("minted",
zap.Uint64("alice", mintAlice),
zap.Int64("alice", mintAlice),
)

// check balance of alice
Expand All @@ -131,7 +131,7 @@ func (t *Token) Run(ctx context.Context) error {
return err
}
t.log.Debug("balance",
zap.Uint64("alice", result[0]),
zap.Int64("alice", result[0]),
)

// check balance of bob
Expand All @@ -140,27 +140,27 @@ func (t *Token) Run(ctx context.Context) error {
return err
}
t.log.Debug("balance",
zap.Uint64("bob", result[0]),
zap.Int64("bob", result[0]),
)

// transfer 50 from alice to bob
transferToBob := uint64(50)
transferToBob := int64(50)
_, err = rt.Call(ctx, "transfer", programIDPtr, alicePtr, bobPtr, transferToBob)
if err != nil {
return err
}
t.log.Debug("transferred",
zap.Uint64("alice", transferToBob),
zap.Uint64("to bob", transferToBob),
zap.Int64("alice", transferToBob),
zap.Int64("to bob", transferToBob),
)

_, err = rt.Call(ctx, "transfer", programIDPtr, alicePtr, bobPtr, 1)
if err != nil {
return err
}
t.log.Debug("transferred",
zap.Uint64("alice", transferToBob),
zap.Uint64("to bob", transferToBob),
zap.Int64("alice", transferToBob),
zap.Int64("to bob", transferToBob),
)

// get balance alice
Expand All @@ -169,15 +169,15 @@ func (t *Token) Run(ctx context.Context) error {
return err
}
t.log.Debug("balance",
zap.Uint64("alice", result[0]),
zap.Int64("alice", result[0]),
)

// get balance bob
result, err = rt.Call(ctx, "get_balance", programIDPtr, bobPtr)
if err != nil {
return err
}
t.log.Debug("balance", zap.Uint64("bob", result[0]))
t.log.Debug("balance", zap.Int64("bob", result[0]))

t.log.Debug("remaining balance",
zap.Uint64("unit", rt.Meter().GetBalance()),
Expand Down Expand Up @@ -221,7 +221,7 @@ func (t *Token) RunShort(ctx context.Context) error {
}

t.log.Debug("init response",
zap.Uint64("init", resp[0]),
zap.Int64("init", resp[0]),
)
return nil
}
8 changes: 4 additions & 4 deletions x/programs/examples/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ import (
"github.com/ava-labs/hypersdk/x/programs/runtime"
)

func newKeyPtr(ctx context.Context, key ed25519.PublicKey, runtime runtime.Runtime) (uint64, error) {
ptr, err := runtime.Memory().Alloc(ed25519.PublicKeyLen)
func newKeyPtr(ctx context.Context, key ed25519.PublicKey, rt runtime.Runtime) (int64, error) {
ptr, err := rt.Memory().Alloc(ed25519.PublicKeyLen)
if err != nil {
return 0, err
}

// write programID to memory which we will later pass to the program
err = runtime.Memory().Write(ptr, key[:])
err = rt.Memory().Write(ptr, key[:])
if err != nil {
return 0, err
}

return ptr, err
return int64(ptr), err
}

func newKey() (ed25519.PrivateKey, ed25519.PublicKey, error) {
Expand Down
4 changes: 3 additions & 1 deletion x/programs/runtime/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

package runtime

import "github.com/ava-labs/avalanchego/utils/units"
import (
"github.com/ava-labs/avalanchego/utils/units"
)

const (
AllocFnName = "alloc"
Expand Down
6 changes: 4 additions & 2 deletions x/programs/runtime/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ type Runtime interface {
// units. The engine will handle the compile strategy and instantiate the
// module with the given imports. Initialize should only be called once.
Initialize(context.Context, []byte, uint64) error
// Call invokes the an exported guest function with the given parameters.
Call(context.Context, string, ...uint64) ([]uint64, error)
// Call invokes an exported guest function with the given parameters.
// Returns the results of the call or an error if the call failed.
// If the function called does not return a result this value is set to nil.
Call(context.Context, string, ...int64) ([]int64, error)
// Memory returns the runtime memory.
Memory() Memory
// Meter returns the runtime meter.
Expand Down
1 change: 1 addition & 0 deletions x/programs/runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ var (
ErrInsufficientUnits = errors.New("insufficient units")
ErrRuntimeStoreSet = errors.New("runtime store has already been set")
ErrNegativeValue = errors.New("negative value")
ErrIntegerConversionOverflow = errors.New("integer overflow during conversion")

// Trap errors
ErrTrapStackOverflow = errors.New("the current stack space was exhausted")
Expand Down
21 changes: 15 additions & 6 deletions x/programs/runtime/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package runtime

import (
"fmt"
"math"
"runtime"
)

Expand Down Expand Up @@ -78,6 +79,11 @@ func (m *memory) Alloc(length uint64) (uint64, error) {
if err != nil {
return 0, err
}

if length > math.MaxInt32 {
return 0, fmt.Errorf("failed to allocate memory: %w", ErrIntegerConversionOverflow)
}

result, err := fn.Call(m.client.Store(), int32(length))
if err != nil {
return 0, handleTrapError(err)
Expand Down Expand Up @@ -111,7 +117,7 @@ func (m *memory) Len() (uint64, error) {

// WriteBytes is a helper function that allocates memory and writes the given
// bytes to the memory returning the offset.
func WriteBytes(m Memory, buf []byte) (uint64, error) {
func WriteBytes(m Memory, buf []byte) (int64, error) {
offset, err := m.Alloc(uint64(len(buf)))
if err != nil {
return 0, err
Expand All @@ -121,7 +127,7 @@ func WriteBytes(m Memory, buf []byte) (uint64, error) {
return 0, err
}

return offset, nil
return int64(offset), nil
}

// CallParam defines a value to be passed to a guest function.
Expand All @@ -131,8 +137,8 @@ type CallParam struct {

// WriteParams is a helper function that writes the given params to memory if non integer.
// Supported types include int, uint64 and string.
func WriteParams(m Memory, p []CallParam) ([]uint64, error) {
params := []uint64{}
func WriteParams(m Memory, p []CallParam) ([]int64, error) {
params := []int64{}
for _, param := range p {
switch v := param.Value.(type) {
case string:
Expand All @@ -145,9 +151,12 @@ func WriteParams(m Memory, p []CallParam) ([]uint64, error) {
if v < 0 {
return nil, fmt.Errorf("failed to write param: %w", ErrNegativeValue)
}
params = append(params, uint64(v))
params = append(params, int64(v))
case uint64:
params = append(params, v)
if v > math.MaxInt64 {
return nil, fmt.Errorf("failed to write param: %w", ErrIntegerConversionOverflow)
}
params = append(params, int64(v))
default:
return nil, fmt.Errorf("%w: support types int, uint64 and string", ErrInvalidParamType)
}
Expand Down
21 changes: 14 additions & 7 deletions x/programs/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func getRegisteredImportModules(importTypes []*wasmtime.ImportType) []string {
return imports
}

func (r *WasmRuntime) Call(_ context.Context, name string, params ...uint64) ([]uint64, error) {
func (r *WasmRuntime) Call(_ context.Context, name string, params ...int64) ([]int64, error) {
var fnName string
switch name {
case AllocFnName, DeallocFnName, MemoryFnName:
Expand Down Expand Up @@ -184,11 +184,14 @@ func (r *WasmRuntime) Call(_ context.Context, name string, params ...uint64) ([]

switch v := result.(type) {
case int32:
value := uint64(result.(int32))
return []uint64{value}, nil
value := int64(result.(int32))
return []int64{value}, nil
case int64:
value := uint64(result.(int64))
return []uint64{value}, nil
value := result.(int64)
return []int64{value}, nil
case nil:
// the function had no return values
return nil, nil
default:
return nil, fmt.Errorf("invalid result type: %v", v)
}
Expand Down Expand Up @@ -239,14 +242,18 @@ func PreCompileWasmBytes(programBytes []byte, cfg *Config) ([]byte, error) {
}

// mapFunctionParams maps call input to the expected wasm function params.
func mapFunctionParams(input []uint64, values []*wasmtime.ValType) ([]interface{}, error) {
func mapFunctionParams(input []int64, values []*wasmtime.ValType) ([]interface{}, error) {
params := make([]interface{}, len(values))
for i, v := range values {
switch v.Kind() {
case wasmtime.KindI32:
// ensure this value is within the range of an int32
if !EnsureInt64ToInt32(input[i]) {
return nil, fmt.Errorf("%w: %d", ErrIntegerConversionOverflow, input[i])
}
params[i] = int32(input[i])
case wasmtime.KindI64:
params[i] = int64(input[i])
params[i] = input[i]
default:
return nil, fmt.Errorf("%w: %v", ErrInvalidParamType, v.Kind())
}
Expand Down
6 changes: 3 additions & 3 deletions x/programs/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ func TestCallParams(t *testing.T) {
err = runtime.Initialize(ctx, wasm, maxUnits)
require.NoError(err)

resp, err := runtime.Call(ctx, "add", uint64(10), uint64(10))
resp, err := runtime.Call(ctx, "add", 10, 10)
require.NoError(err)
require.Equal(uint64(20), resp[0])
require.Equal(int64(20), resp[0])

// pass 3 params when 2 are expected.
_, err = runtime.Call(ctx, "add", uint64(10), uint64(10), uint64(10))
_, err = runtime.Call(ctx, "add", 10, 10, 10)
require.ErrorIs(err, ErrInvalidParamCount)
}
10 changes: 10 additions & 0 deletions x/programs/runtime/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package runtime

import "math"

func EnsureInt64ToInt32(v int64) bool {
return v >= math.MinInt32 && v <= math.MaxInt32
}
Loading

0 comments on commit 283846b

Please sign in to comment.