From 283846bac51de111d18c225c8b0a7b92b6cf6ec8 Mon Sep 17 00:00:00 2001 From: Sam Liokumovich <65994425+samliok@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:15:33 -0500 Subject: [PATCH] [x/programs] Improve overflow checks (#600) * runtime: uint64 -> int64 * ensure int32 is within bounds * Update x/programs/runtime/dependencies.go * sdk: add alloc overflow check --------- Signed-off-by: Sam Liokumovich <65994425+samliok@users.noreply.github.com> Co-authored-by: Sam Batschelet --- x/programs/examples/counter_test.go | 20 ++++++------- .../examples/imports/program/program.go | 6 ++-- x/programs/examples/token.go | 28 +++++++++---------- x/programs/examples/util.go | 8 +++--- x/programs/runtime/consts.go | 4 ++- x/programs/runtime/dependencies.go | 6 ++-- x/programs/runtime/errors.go | 1 + x/programs/runtime/memory.go | 21 ++++++++++---- x/programs/runtime/runtime.go | 21 +++++++++----- x/programs/runtime/runtime_test.go | 6 ++-- x/programs/runtime/util.go | 10 +++++++ x/programs/rust/wasmlanche_sdk/src/memory.rs | 8 ++++-- 12 files changed, 87 insertions(+), 52 deletions(-) create mode 100644 x/programs/runtime/util.go diff --git a/x/programs/examples/counter_test.go b/x/programs/examples/counter_test.go index bf6b9e5773..2dd9ac1ce6 100644 --- a/x/programs/examples/counter_test.go +++ b/x/programs/examples/counter_test.go @@ -66,12 +66,12 @@ func TestCounterProgram(t *testing.T) { // create counter for alice on program 1 result, err := rt.Call(ctx, "initialize_address", programIDPtr, alicePtr) require.NoError(err) - require.Equal(uint64(1), result[0]) + require.Equal(int64(1), result[0]) // validate counter at 0 result, err = rt.Call(ctx, "get_value", programIDPtr, alicePtr) require.NoError(err) - require.Equal(uint64(0), result[0]) + require.Equal(int64(0), result[0]) // initialize second runtime to create second counter program with an empty // meter. @@ -102,16 +102,16 @@ func TestCounterProgram(t *testing.T) { // initialize counter for alice on runtime 2 result, err = rt2.Call(ctx, "initialize_address", programID2Ptr, alicePtr2) require.NoError(err) - require.Equal(uint64(1), result[0]) + require.Equal(int64(1), result[0]) // increment alice's counter on program 2 by 10 result, err = rt2.Call(ctx, "inc", programID2Ptr, alicePtr2, 10) require.NoError(err) - require.Equal(uint64(1), result[0]) + require.Equal(int64(1), result[0]) result, err = rt2.Call(ctx, "get_value", programID2Ptr, alicePtr2) require.NoError(err) - require.Equal(uint64(10), result[0]) + require.Equal(int64(10), result[0]) // stop the runtime to prevent further execution rt2.Stop() @@ -127,13 +127,13 @@ func TestCounterProgram(t *testing.T) { // increment alice's counter on program 1 result, err = rt.Call(ctx, "inc", programIDPtr, alicePtr, 1) require.NoError(err) - require.Equal(uint64(1), result[0]) + require.Equal(int64(1), result[0]) result, err = rt.Call(ctx, "get_value", programIDPtr, alicePtr) require.NoError(err) log.Debug("count program 1", - zap.Uint64("alice", result[0]), + zap.Int64("alice", result[0]), ) // write program id 2 to stack of program 1 @@ -142,17 +142,17 @@ func TestCounterProgram(t *testing.T) { caller := programIDPtr target := programID2Ptr - maxUnitsProgramToProgram := uint64(10000) + maxUnitsProgramToProgram := int64(10000) // increment alice's counter on program 2 result, err = rt.Call(ctx, "inc_external", caller, target, maxUnitsProgramToProgram, alicePtr, 5) require.NoError(err) - require.Equal(uint64(1), result[0]) + require.Equal(int64(1), result[0]) // expect alice's counter on program 2 to be 15 result, err = rt.Call(ctx, "get_value_external", caller, target, maxUnitsProgramToProgram, alicePtr) require.NoError(err) - require.Equal(uint64(15), result[0]) + require.Equal(int64(15), result[0]) require.Greater(rt.Meter().GetBalance(), uint64(0)) } diff --git a/x/programs/examples/imports/program/program.go b/x/programs/examples/imports/program/program.go index b33ac22bd5..a5f349ea24 100644 --- a/x/programs/examples/imports/program/program.go +++ b/x/programs/examples/imports/program/program.go @@ -170,16 +170,16 @@ func (i *Import) callProgramFn( return int64(res[0]) } -func getCallArgs(ctx context.Context, memory runtime.Memory, buffer []byte, invokeProgramID uint64) ([]uint64, error) { +func getCallArgs(ctx context.Context, memory runtime.Memory, buffer []byte, invokeProgramID int64) ([]int64, error) { // first arg contains id of program to call - args := []uint64{invokeProgramID} + args := []int64{invokeProgramID} p := codec.NewReader(buffer, len(buffer)) i := 0 for !p.Empty() { size := p.UnpackInt64(true) isInt := p.UnpackBool() if isInt { - valueInt := p.UnpackUint64(true) + valueInt := p.UnpackInt64(true) args = append(args, valueInt) } else { valueBytes := make([]byte, size) diff --git a/x/programs/examples/token.go b/x/programs/examples/token.go index ebfa212227..5c41e39e04 100644 --- a/x/programs/examples/token.go +++ b/x/programs/examples/token.go @@ -71,7 +71,7 @@ func (t *Token) Run(ctx context.Context) error { } t.log.Debug("init response", - zap.Uint64("init", resp[0]), + zap.Int64("init", resp[0]), ) result, err := rt.Call(ctx, "get_total_supply", programIDPtr) @@ -79,7 +79,7 @@ func (t *Token) Run(ctx context.Context) error { return err } t.log.Debug("total supply", - zap.Uint64("minted", result[0]), + zap.Int64("minted", result[0]), ) // generate alice keys @@ -116,13 +116,13 @@ func (t *Token) Run(ctx context.Context) error { ) // mint 100 tokens to alice - mintAlice := uint64(1000) + mintAlice := int64(1000) _, err = rt.Call(ctx, "mint_to", programIDPtr, alicePtr, mintAlice) if err != nil { return err } t.log.Debug("minted", - zap.Uint64("alice", mintAlice), + zap.Int64("alice", mintAlice), ) // check balance of alice @@ -131,7 +131,7 @@ func (t *Token) Run(ctx context.Context) error { return err } t.log.Debug("balance", - zap.Uint64("alice", result[0]), + zap.Int64("alice", result[0]), ) // check balance of bob @@ -140,18 +140,18 @@ func (t *Token) Run(ctx context.Context) error { return err } t.log.Debug("balance", - zap.Uint64("bob", result[0]), + zap.Int64("bob", result[0]), ) // transfer 50 from alice to bob - transferToBob := uint64(50) + transferToBob := int64(50) _, err = rt.Call(ctx, "transfer", programIDPtr, alicePtr, bobPtr, transferToBob) if err != nil { return err } t.log.Debug("transferred", - zap.Uint64("alice", transferToBob), - zap.Uint64("to bob", transferToBob), + zap.Int64("alice", transferToBob), + zap.Int64("to bob", transferToBob), ) _, err = rt.Call(ctx, "transfer", programIDPtr, alicePtr, bobPtr, 1) @@ -159,8 +159,8 @@ func (t *Token) Run(ctx context.Context) error { return err } t.log.Debug("transferred", - zap.Uint64("alice", transferToBob), - zap.Uint64("to bob", transferToBob), + zap.Int64("alice", transferToBob), + zap.Int64("to bob", transferToBob), ) // get balance alice @@ -169,7 +169,7 @@ func (t *Token) Run(ctx context.Context) error { return err } t.log.Debug("balance", - zap.Uint64("alice", result[0]), + zap.Int64("alice", result[0]), ) // get balance bob @@ -177,7 +177,7 @@ func (t *Token) Run(ctx context.Context) error { if err != nil { return err } - t.log.Debug("balance", zap.Uint64("bob", result[0])) + t.log.Debug("balance", zap.Int64("bob", result[0])) t.log.Debug("remaining balance", zap.Uint64("unit", rt.Meter().GetBalance()), @@ -221,7 +221,7 @@ func (t *Token) RunShort(ctx context.Context) error { } t.log.Debug("init response", - zap.Uint64("init", resp[0]), + zap.Int64("init", resp[0]), ) return nil } diff --git a/x/programs/examples/util.go b/x/programs/examples/util.go index 2fa1d84ace..547d2b3e52 100644 --- a/x/programs/examples/util.go +++ b/x/programs/examples/util.go @@ -13,19 +13,19 @@ import ( "github.com/ava-labs/hypersdk/x/programs/runtime" ) -func newKeyPtr(ctx context.Context, key ed25519.PublicKey, runtime runtime.Runtime) (uint64, error) { - ptr, err := runtime.Memory().Alloc(ed25519.PublicKeyLen) +func newKeyPtr(ctx context.Context, key ed25519.PublicKey, rt runtime.Runtime) (int64, error) { + ptr, err := rt.Memory().Alloc(ed25519.PublicKeyLen) if err != nil { return 0, err } // write programID to memory which we will later pass to the program - err = runtime.Memory().Write(ptr, key[:]) + err = rt.Memory().Write(ptr, key[:]) if err != nil { return 0, err } - return ptr, err + return int64(ptr), err } func newKey() (ed25519.PrivateKey, ed25519.PublicKey, error) { diff --git a/x/programs/runtime/consts.go b/x/programs/runtime/consts.go index 1e696b78d3..5f0886b214 100644 --- a/x/programs/runtime/consts.go +++ b/x/programs/runtime/consts.go @@ -3,7 +3,9 @@ package runtime -import "github.com/ava-labs/avalanchego/utils/units" +import ( + "github.com/ava-labs/avalanchego/utils/units" +) const ( AllocFnName = "alloc" diff --git a/x/programs/runtime/dependencies.go b/x/programs/runtime/dependencies.go index 5e0800135e..1145b5f6ec 100644 --- a/x/programs/runtime/dependencies.go +++ b/x/programs/runtime/dependencies.go @@ -29,8 +29,10 @@ type Runtime interface { // units. The engine will handle the compile strategy and instantiate the // module with the given imports. Initialize should only be called once. Initialize(context.Context, []byte, uint64) error - // Call invokes the an exported guest function with the given parameters. - Call(context.Context, string, ...uint64) ([]uint64, error) + // Call invokes an exported guest function with the given parameters. + // Returns the results of the call or an error if the call failed. + // If the function called does not return a result this value is set to nil. + Call(context.Context, string, ...int64) ([]int64, error) // Memory returns the runtime memory. Memory() Memory // Meter returns the runtime meter. diff --git a/x/programs/runtime/errors.go b/x/programs/runtime/errors.go index 01606d3015..0e446a06a3 100644 --- a/x/programs/runtime/errors.go +++ b/x/programs/runtime/errors.go @@ -21,6 +21,7 @@ var ( ErrInsufficientUnits = errors.New("insufficient units") ErrRuntimeStoreSet = errors.New("runtime store has already been set") ErrNegativeValue = errors.New("negative value") + ErrIntegerConversionOverflow = errors.New("integer overflow during conversion") // Trap errors ErrTrapStackOverflow = errors.New("the current stack space was exhausted") diff --git a/x/programs/runtime/memory.go b/x/programs/runtime/memory.go index a30dc161bd..997b4790a1 100644 --- a/x/programs/runtime/memory.go +++ b/x/programs/runtime/memory.go @@ -5,6 +5,7 @@ package runtime import ( "fmt" + "math" "runtime" ) @@ -78,6 +79,11 @@ func (m *memory) Alloc(length uint64) (uint64, error) { if err != nil { return 0, err } + + if length > math.MaxInt32 { + return 0, fmt.Errorf("failed to allocate memory: %w", ErrIntegerConversionOverflow) + } + result, err := fn.Call(m.client.Store(), int32(length)) if err != nil { return 0, handleTrapError(err) @@ -111,7 +117,7 @@ func (m *memory) Len() (uint64, error) { // WriteBytes is a helper function that allocates memory and writes the given // bytes to the memory returning the offset. -func WriteBytes(m Memory, buf []byte) (uint64, error) { +func WriteBytes(m Memory, buf []byte) (int64, error) { offset, err := m.Alloc(uint64(len(buf))) if err != nil { return 0, err @@ -121,7 +127,7 @@ func WriteBytes(m Memory, buf []byte) (uint64, error) { return 0, err } - return offset, nil + return int64(offset), nil } // CallParam defines a value to be passed to a guest function. @@ -131,8 +137,8 @@ type CallParam struct { // WriteParams is a helper function that writes the given params to memory if non integer. // Supported types include int, uint64 and string. -func WriteParams(m Memory, p []CallParam) ([]uint64, error) { - params := []uint64{} +func WriteParams(m Memory, p []CallParam) ([]int64, error) { + params := []int64{} for _, param := range p { switch v := param.Value.(type) { case string: @@ -145,9 +151,12 @@ func WriteParams(m Memory, p []CallParam) ([]uint64, error) { if v < 0 { return nil, fmt.Errorf("failed to write param: %w", ErrNegativeValue) } - params = append(params, uint64(v)) + params = append(params, int64(v)) case uint64: - params = append(params, v) + if v > math.MaxInt64 { + return nil, fmt.Errorf("failed to write param: %w", ErrIntegerConversionOverflow) + } + params = append(params, int64(v)) default: return nil, fmt.Errorf("%w: support types int, uint64 and string", ErrInvalidParamType) } diff --git a/x/programs/runtime/runtime.go b/x/programs/runtime/runtime.go index 4a6fbb41d6..9525b55719 100644 --- a/x/programs/runtime/runtime.go +++ b/x/programs/runtime/runtime.go @@ -152,7 +152,7 @@ func getRegisteredImportModules(importTypes []*wasmtime.ImportType) []string { return imports } -func (r *WasmRuntime) Call(_ context.Context, name string, params ...uint64) ([]uint64, error) { +func (r *WasmRuntime) Call(_ context.Context, name string, params ...int64) ([]int64, error) { var fnName string switch name { case AllocFnName, DeallocFnName, MemoryFnName: @@ -184,11 +184,14 @@ func (r *WasmRuntime) Call(_ context.Context, name string, params ...uint64) ([] switch v := result.(type) { case int32: - value := uint64(result.(int32)) - return []uint64{value}, nil + value := int64(result.(int32)) + return []int64{value}, nil case int64: - value := uint64(result.(int64)) - return []uint64{value}, nil + value := result.(int64) + return []int64{value}, nil + case nil: + // the function had no return values + return nil, nil default: return nil, fmt.Errorf("invalid result type: %v", v) } @@ -239,14 +242,18 @@ func PreCompileWasmBytes(programBytes []byte, cfg *Config) ([]byte, error) { } // mapFunctionParams maps call input to the expected wasm function params. -func mapFunctionParams(input []uint64, values []*wasmtime.ValType) ([]interface{}, error) { +func mapFunctionParams(input []int64, values []*wasmtime.ValType) ([]interface{}, error) { params := make([]interface{}, len(values)) for i, v := range values { switch v.Kind() { case wasmtime.KindI32: + // ensure this value is within the range of an int32 + if !EnsureInt64ToInt32(input[i]) { + return nil, fmt.Errorf("%w: %d", ErrIntegerConversionOverflow, input[i]) + } params[i] = int32(input[i]) case wasmtime.KindI64: - params[i] = int64(input[i]) + params[i] = input[i] default: return nil, fmt.Errorf("%w: %v", ErrInvalidParamType, v.Kind()) } diff --git a/x/programs/runtime/runtime_test.go b/x/programs/runtime/runtime_test.go index 4324638720..5d77746484 100644 --- a/x/programs/runtime/runtime_test.go +++ b/x/programs/runtime/runtime_test.go @@ -70,11 +70,11 @@ func TestCallParams(t *testing.T) { err = runtime.Initialize(ctx, wasm, maxUnits) require.NoError(err) - resp, err := runtime.Call(ctx, "add", uint64(10), uint64(10)) + resp, err := runtime.Call(ctx, "add", 10, 10) require.NoError(err) - require.Equal(uint64(20), resp[0]) + require.Equal(int64(20), resp[0]) // pass 3 params when 2 are expected. - _, err = runtime.Call(ctx, "add", uint64(10), uint64(10), uint64(10)) + _, err = runtime.Call(ctx, "add", 10, 10, 10) require.ErrorIs(err, ErrInvalidParamCount) } diff --git a/x/programs/runtime/util.go b/x/programs/runtime/util.go new file mode 100644 index 0000000000..3693b50976 --- /dev/null +++ b/x/programs/runtime/util.go @@ -0,0 +1,10 @@ +// Copyright (C) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package runtime + +import "math" + +func EnsureInt64ToInt32(v int64) bool { + return v >= math.MinInt32 && v <= math.MaxInt32 +} diff --git a/x/programs/rust/wasmlanche_sdk/src/memory.rs b/x/programs/rust/wasmlanche_sdk/src/memory.rs index d515281cbf..69d94b4b92 100644 --- a/x/programs/rust/wasmlanche_sdk/src/memory.rs +++ b/x/programs/rust/wasmlanche_sdk/src/memory.rs @@ -84,16 +84,20 @@ pub unsafe fn deallocate(ptr: *mut u8, capacity: usize) { } /* memory functions ------------------------------------------- */ -// https://radu-matei.com/blog/practical-guide-to-wasm-memory/ - /// Allocate memory into the instance of Program and return the offset to the /// start of the block. +/// # Panics +/// Panics if the pointer exceeds the maximum size of an i64. #[no_mangle] pub extern "C" fn alloc(len: usize) -> *mut u8 { // create a new mutable buffer with capacity `len` let mut buf = Vec::with_capacity(len); // take a mutable pointer to the buffer let ptr = buf.as_mut_ptr(); + // ensure memory pointer is fits in an i64 + // to avoid potential issues when passing + // across wasm boundary + assert!(i64::try_from(ptr as u64).is_ok()); // take ownership of the memory block and // ensure that its destructor is not // called when the object goes out of scope