diff --git a/internal/polkavm/host_call/accumulate_functions.go b/internal/polkavm/host_call/accumulate_functions.go index 7f04d6f7..57a60931 100644 --- a/internal/polkavm/host_call/accumulate_functions.go +++ b/internal/polkavm/host_call/accumulate_functions.go @@ -13,12 +13,12 @@ import ( "github.com/eigerco/strawberry/internal/state" ) -// Empower ΩE(ϱ, ω, μ, (x, y)) -func Empower(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Gas, Registers, Memory, AccumulateContextPair, error) { - if gas < EmpowerCost { +// Bless ΩB(ϱ, ω, μ, (x, y)) +func Bless(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Gas, Registers, Memory, AccumulateContextPair, error) { + if gas < BlessCost { return gas, regs, mem, ctxPair, ErrOutOfGas } - gas -= EmpowerCost + gas -= BlessCost // let [m, a, v, o, n] = ω7...12 managerServiceId, assignServiceId, designateServiceId, addr, servicesNr := regs[A0], regs[A1], regs[A2], regs[A3], regs[A4] @@ -108,9 +108,8 @@ func Checkpoint(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPa ctxPair.ExceptionalCtx = ctxPair.RegularCtx - // Split the new ϱ' value into its lower and upper parts. + // Set the new ϱ' value into ω′7 regs[A0] = uint32(gas & ((1 << 32) - 1)) - regs[A1] = uint32(gas >> 32) return gas, regs, mem, ctxPair, nil } @@ -122,19 +121,14 @@ func New(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Ga } gas -= NewCost - // let [o, l, gl, gh, ml, mh] = ω7..13 - addr, preimageLength, gl, gh, ml, mh := regs[A0], regs[A1], regs[A2], regs[A3], regs[A4], regs[A5] + // let [o, l, g, m] = ω7..11 + addr, preimageLength, gasLimitAccumulator, gasLimitTransfer := regs[A0], regs[A1], regs[A2], regs[A3] // c = μo⋅⋅⋅+32 if No⋅⋅⋅+32 ⊂ Vμ otherwise ∇ codeHashBytes := make([]byte, 32) if err := mem.Read(addr, codeHashBytes); err != nil { return gas, withCode(regs, OOB), mem, ctxPair, nil } - // let g = 2^32 ⋅ gh + gl - gasLimitAccumulator := uint64(gh)<<32 | uint64(gl) - - // let m = 2^32 ⋅ mh + ml - gasLimitTransfer := uint64(mh)<<32 | uint64(ml) codeHash := crypto.Hash(codeHashBytes) @@ -145,8 +139,8 @@ func New(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Ga {Hash: codeHash, Length: service.PreimageLength(preimageLength)}: {}, }, CodeHash: codeHash, - GasLimitForAccumulator: gasLimitAccumulator, - GasLimitOnTransfer: gasLimitTransfer, + GasLimitForAccumulator: uint64(gasLimitAccumulator), + GasLimitOnTransfer: uint64(gasLimitTransfer), } account.Balance = account.ThresholdBalance() @@ -177,8 +171,8 @@ func Upgrade(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) return gas, regs, mem, ctxPair, ErrOutOfGas } gas -= UpgradeCost - // let [o, gh, gl, mh, ml] = ω7...12 - addr, gl, gh, ml, mh := regs[A0], regs[A1], regs[A2], regs[A3], regs[A4] + // let [o, g, m] = ω7...10 + addr, gasLimitAccumulator, gasLimitTransfer := regs[A0], regs[A1], regs[A2] // c = μo⋅⋅⋅+32 if No⋅⋅⋅+32 ⊂ Vμ otherwise ∇ codeHash := make([]byte, 32) @@ -186,28 +180,19 @@ func Upgrade(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) return gas, withCode(regs, OOB), mem, ctxPair, nil } - // let g = 2^32 ⋅ gh + gl - gasLimitAccumulator := uint64(gh)<<32 | uint64(gl) - - // let m = 2^32 ⋅ mh + ml - gasLimitTransfer := uint64(mh)<<32 | uint64(ml) - // (ω′7, (X′s)c, (X′s)g , (X′s)m) = (OK, c, g, m) if c ≠ ∇ currentService := ctxPair.RegularCtx.ServiceAccount() currentService.CodeHash = crypto.Hash(codeHash) - currentService.GasLimitForAccumulator = gasLimitAccumulator - currentService.GasLimitOnTransfer = gasLimitTransfer + currentService.GasLimitForAccumulator = uint64(gasLimitAccumulator) + currentService.GasLimitOnTransfer = uint64(gasLimitTransfer) ctxPair.RegularCtx.ServiceState[ctxPair.RegularCtx.ServiceId] = currentService return gas, withCode(regs, OK), mem, ctxPair, nil } // Transfer ΩT(ϱ, ω, μ, (x, y)) func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Gas, Registers, Memory, AccumulateContextPair, error) { - // let (d, al, ah, gl, gh, o) = ω7..13 - receiverId, al, ah, gl, gh, o := regs[A0], regs[A1], regs[A2], regs[A3], regs[A4], regs[A5] - - // let a = 2^32 ⋅ ah + al - newBalance := uint64(ah)<<32 | uint64(al) + // let (d, a, g, o) = ω7..11 + receiverId, newBalance, gasLimit, o := regs[A0], regs[A1], regs[A2], regs[A3] transferCost := TransferBaseCost + Gas(newBalance) if gas < transferCost { @@ -215,9 +200,6 @@ func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair } gas -= transferCost - // let g = 2^32 ⋅ gh + gl - gasLimit := uint64(gh)<<32 | uint64(gl) - // m = μo⋅⋅⋅+M if No⋅⋅⋅+M ⊂ Vμ otherwise ∇ m := make([]byte, service.TransferMemoSizeBytes) if err := mem.Read(o, m); err != nil { @@ -228,9 +210,9 @@ func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair deferredTransfer := service.DeferredTransfer{ SenderServiceIndex: ctxPair.RegularCtx.ServiceId, ReceiverServiceIndex: block.ServiceId(receiverId), - Balance: newBalance, + Balance: uint64(newBalance), Memo: service.Memo(m), - GasLimit: gasLimit, + GasLimit: uint64(gasLimit), } // let d = xd ∪ (xu)d @@ -244,7 +226,7 @@ func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair } // if g < (δ ∪ xn)[d]m - if gasLimit < receiverService.GasLimitOnTransfer { + if uint64(gasLimit) < receiverService.GasLimitOnTransfer { return gas, withCode(regs, LOW), mem, ctxPair, nil } @@ -255,7 +237,7 @@ func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair // let b = (xs)b − a // if b < (xs)t - if ctxPair.RegularCtx.ServiceAccount().Balance-newBalance < ctxPair.RegularCtx.ServiceAccount().ThresholdBalance() { + if ctxPair.RegularCtx.ServiceAccount().Balance-uint64(newBalance) < ctxPair.RegularCtx.ServiceAccount().ThresholdBalance() { return gas, withCode(regs, CASH), mem, ctxPair, nil } @@ -286,7 +268,7 @@ func Quit(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (G } // if d ∈ {s, 2^32 − 1} - if block.ServiceId(receiverId) == ctxPair.RegularCtx.ServiceId || receiverId == math.MaxUint32 { + if block.ServiceId(receiverId) == ctxPair.RegularCtx.ServiceId || uint64(receiverId) == math.MaxUint64 { delete(ctxPair.RegularCtx.AccumulationState.ServiceState, ctxPair.RegularCtx.ServiceId) return gas, withCode(regs, OK), mem, ctxPair, ErrHalt } diff --git a/internal/polkavm/host_call/accumulate_functions_test.go b/internal/polkavm/host_call/accumulate_functions_test.go index e0933bde..b91021a6 100644 --- a/internal/polkavm/host_call/accumulate_functions_test.go +++ b/internal/polkavm/host_call/accumulate_functions_test.go @@ -1,11 +1,9 @@ package host_call import ( - "maps" "math" "slices" "testing" - "unsafe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -75,15 +73,15 @@ func TestAccumulate(t *testing.T) { }{ { name: "empower", - fn: fnStd(Empower), + fn: fnStd(Bless), alloc: alloc{ A3: slices.Concat( - encodeNumber(uint32(123)), - encodeNumber(uint64(12341234)), - encodeNumber(uint32(234)), - encodeNumber(uint64(23452345)), - encodeNumber(uint32(345)), - encodeNumber(uint64(34563456)), + encodeNumber(t, uint32(123)), + encodeNumber(t, uint64(12341234)), + encodeNumber(t, uint32(234)), + encodeNumber(t, uint64(23452345)), + encodeNumber(t, uint32(345)), + encodeNumber(t, uint64(34563456)), ), }, initialRegs: deltaRegs{ @@ -159,17 +157,14 @@ func TestAccumulate(t *testing.T) { expectedDeltaRegs: checkUint64(t, 89), expectedX: checkpointCtx, expectedY: checkpointCtx, - }, { + }, + { name: "new", fn: fnStd(New), alloc: alloc{ A0: hash2bytes(randomHash), }, - initialRegs: merge( - deltaRegs{A1: 123123}, - storeUint64(123124123, A2, A3), - storeUint64(756846353, A4, A5), - ), + initialRegs: deltaRegs{A1: 123123, A2: 123124123, A3: 756846353}, expectedDeltaRegs: deltaRegs{ A0: uint32(currentServiceID), }, @@ -210,16 +205,14 @@ func TestAccumulate(t *testing.T) { }, }, }, - }, { + }, + { name: "upgrade", fn: fnStd(Upgrade), alloc: alloc{ A0: hash2bytes(randomHash), }, - initialRegs: merge( - storeUint64(345345345345, A1, A2), - storeUint64(456456456456, A3, A4), - ), + initialRegs: deltaRegs{A1: 3453453453, A2: 456456456}, expectedDeltaRegs: deltaRegs{ A0: uint32(OK), }, @@ -233,33 +226,31 @@ func TestAccumulate(t *testing.T) { ServiceId: currentServiceID, ServiceState: service.ServiceState{currentServiceID: { CodeHash: randomHash, - GasLimitForAccumulator: 345345345345, - GasLimitOnTransfer: 456456456456, + GasLimitForAccumulator: 3453453453, + GasLimitOnTransfer: 456456456, }}, }, }, { name: "transfer", fn: fnStd(Transfer), alloc: alloc{ - A5: fixedSizeBytes(service.TransferMemoSizeBytes, []byte("memo message")), + A3: fixedSizeBytes(service.TransferMemoSizeBytes, []byte("memo message")), + }, + initialRegs: deltaRegs{ + A0: 1234, // d: receiver + A1: 1000000000, // a + A2: 80, // g }, - initialRegs: merge( - deltaRegs{ - A0: 1234, // d: receiver - }, - storeUint64(100000000000, A1, A2), // a - storeUint64(80, A3, A4), // g - ), expectedDeltaRegs: deltaRegs{ A0: uint32(OK), }, - initialGas: 100000000100, + initialGas: 1000000100, expectedGas: 88, X: AccumulateContext{ ServiceId: block.ServiceId(123123123), ServiceState: service.ServiceState{ block.ServiceId(123123123): { - Balance: 100000000100, + Balance: 1000000100, }, }, AccumulationState: state.AccumulationState{ @@ -281,13 +272,13 @@ func TestAccumulate(t *testing.T) { ServiceId: block.ServiceId(123123123), ServiceState: service.ServiceState{ block.ServiceId(123123123): { - Balance: 100000000100, + Balance: 1000000100, }, }, DeferredTransfers: []service.DeferredTransfer{{ SenderServiceIndex: block.ServiceId(123123123), ReceiverServiceIndex: 1234, - Balance: 100000000000, + Balance: 1000000000, Memo: service.Memo(fixedSizeBytes(service.TransferMemoSizeBytes, []byte("memo message"))), GasLimit: 80, }}, @@ -724,21 +715,6 @@ func checkUint64(t *testing.T, gas uint64) deltaRegs { } } -func storeUint64(i uint64, reg1, reg2 Reg) deltaRegs { - return deltaRegs{ - reg1: uint32(math.Mod(float64(i), 1<<32)), - reg2: uint32(math.Floor(float64(i) / (1 << 32))), - } -} - -func merge[M ~map[K]V, K comparable, V any](dd ...M) M { - result := make(M) - for _, d := range dd { - maps.Copy(result, d) - } - return result -} - func fnStd(fn func(Gas, Registers, Memory, AccumulateContextPair) (Gas, Registers, Memory, AccumulateContextPair, error)) hostCall { return func(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair, timeslot jamtime.Timeslot) (Gas, Registers, Memory, AccumulateContextPair, error) { return fn(gas, regs, mem, ctxPair) @@ -775,6 +751,8 @@ func transform[S, S2 any](slice1 []S, fn func(S) S2) (slice []S2) { return slice } -func encodeNumber[T ~uint8 | ~uint16 | ~uint32 | ~uint64](v T) []byte { - return jam.SerializeTrivialNatural(v, uint8(unsafe.Sizeof(v))) +func encodeNumber[T ~uint8 | ~uint16 | ~uint32 | ~uint64](t *testing.T, v T) []byte { + res, err := jam.Marshal(v) + require.NoError(t, err) + return res } diff --git a/internal/polkavm/host_call/common.go b/internal/polkavm/host_call/common.go index 8f09b3d5..545552d5 100644 --- a/internal/polkavm/host_call/common.go +++ b/internal/polkavm/host_call/common.go @@ -13,7 +13,7 @@ const ( ReadCost WriteCost InfoCost - EmpowerCost + BlessCost AssignCost DesignateCost CheckpointCost @@ -31,7 +31,7 @@ const ( ReadID = 2 WriteID = 3 InfoID = 4 - EmpowerID = 5 + BlessID = 5 AssignID = 6 DesignateID = 7 CheckpointID = 8 @@ -93,7 +93,7 @@ func readNumber[U interface{ ~uint32 | ~uint64 }](mem Memory, addr uint32, lengt return } - jam.DeserializeTrivialNatural(b, &u) + err = jam.Unmarshal(b, &u) return } diff --git a/internal/polkavm/host_call/general_functions.go b/internal/polkavm/host_call/general_functions.go index aa42a8b7..8b711d0d 100644 --- a/internal/polkavm/host_call/general_functions.go +++ b/internal/polkavm/host_call/general_functions.go @@ -1,15 +1,13 @@ package host_call import ( - "github.com/eigerco/strawberry/pkg/serialization/codec/jam" "math" - "golang.org/x/crypto/blake2b" - "github.com/eigerco/strawberry/internal/block" "github.com/eigerco/strawberry/internal/crypto" "github.com/eigerco/strawberry/internal/polkavm" "github.com/eigerco/strawberry/internal/service" + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" ) type AccountInfo struct { @@ -29,9 +27,8 @@ func GasRemaining(gas polkavm.Gas, regs polkavm.Registers) (polkavm.Gas, polkavm } gas -= GasRemainingCost - // Split the new ϱ' value into its lower and upper parts. - regs[polkavm.A0] = uint32(gas & ((1 << 32) - 1)) - regs[polkavm.A1] = uint32(gas >> 32) + // Set the new ϱ' value into ω′7 + regs[polkavm.A0] = uint32(gas) return gas, regs, nil } @@ -43,11 +40,11 @@ func Lookup(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s servi } gas -= LookupCost - sID := regs[polkavm.A0] + omega7 := regs[polkavm.A0] // Determine the lookup key 'a' a := s - if sID != math.MaxUint32 && sID != uint32(serviceId) { + if uint64(omega7) != math.MaxUint64 && omega7 != uint32(serviceId) { var exists bool // Lookup service account by serviceId in the serviceState a, exists = serviceState[serviceId] @@ -68,7 +65,7 @@ func Lookup(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s servi } // Compute the hash H(µho..ho+32) - hash := blake2b.Sum256(memorySlice) + hash := crypto.HashData(memorySlice) // Lookup value in storage (v) using the hash v, exists := a.Storage[hash] @@ -102,16 +99,16 @@ func Read(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s service } gas -= ReadCost - sID := regs[polkavm.A0] + omega7 := regs[polkavm.A0] ko := regs[polkavm.A1] kz := regs[polkavm.A2] bo := regs[polkavm.A3] bz := regs[polkavm.A4] a := s - if sID != math.MaxUint32 && sID != uint32(serviceId) { + if uint64(omega7) != math.MaxUint64 && omega7 != uint32(serviceId) { var exists bool - a, exists = serviceState[block.ServiceId(sID)] + a, exists = serviceState[block.ServiceId(omega7)] if !exists { return gas, regs, mem, polkavm.ErrAccountNotFound } @@ -125,7 +122,7 @@ func Read(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s service return gas, regs, mem, nil } - serviceIdBytes, err := jam.Marshal(sID) + serviceIdBytes, err := jam.Marshal(omega7) if err != nil { return gas, regs, mem, polkavm.ErrPanicf(err.Error()) } @@ -136,7 +133,7 @@ func Read(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s service hashInput = append(hashInput, keyData...) // Compute the hash H(E4(s) + keyData) - k := blake2b.Sum256(hashInput) + k := crypto.HashData(hashInput) v, exists := a.Storage[k] if !exists { @@ -183,7 +180,7 @@ func Write(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s servic return gas, regs, mem, s, err } hashInput := append(serviceIdBytes, keyData...) - k := blake2b.Sum256(hashInput) + k := crypto.HashData(hashInput) a := s if vz == 0 { @@ -222,12 +219,12 @@ func Info(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, serviceId } gas -= InfoCost - sID := regs[polkavm.A0] - omega1 := regs[polkavm.A1] + omega7 := regs[polkavm.A0] + omega8 := regs[polkavm.A1] t, exists := serviceState[serviceId] - if sID != math.MaxUint32 { - t, exists = serviceState[block.ServiceId(sID)] + if uint64(omega7) != math.MaxUint64 { + t, exists = serviceState[block.ServiceId(omega7)] } if !exists { return gas, withCode(regs, NONE), mem, nil @@ -249,7 +246,7 @@ func Info(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, serviceId return gas, regs, mem, polkavm.ErrPanicf(err.Error()) } - if err := mem.Write(omega1, m); err != nil { + if err := mem.Write(omega8, m); err != nil { regs[polkavm.A0] = uint32(OOB) return gas, regs, mem, nil } diff --git a/internal/polkavm/host_call/general_functions_test.go b/internal/polkavm/host_call/general_functions_test.go index d0044f0d..bd0c3cad 100644 --- a/internal/polkavm/host_call/general_functions_test.go +++ b/internal/polkavm/host_call/general_functions_test.go @@ -5,7 +5,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/crypto/blake2b" "github.com/eigerco/strawberry/internal/block" "github.com/eigerco/strawberry/internal/crypto" @@ -31,7 +30,6 @@ func TestGasRemaining(t *testing.T) { initialRegs := polkavm.Registers{ polkavm.RA: polkavm.VmAddressReturnToHost, - polkavm.SP: memoryMap.StackAddressHigh, } initialGas := uint64(100) hostCall := func(hostCall uint32, gasCounter polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, x struct{}) (polkavm.Gas, polkavm.Registers, polkavm.Memory, struct{}, error) { @@ -92,7 +90,7 @@ func TestLookup(t *testing.T) { bo := memoryMap.RWDataAddress + 100 dataToHash := make([]byte, 32) copy(dataToHash, "hash") - hash := blake2b.Sum256(dataToHash) + hash := crypto.HashData(dataToHash) err := mem.Write(ho, dataToHash) require.NoError(t, err) @@ -153,7 +151,7 @@ func TestRead(t *testing.T) { hashInput := make([]byte, 0, len(serviceIdBytes)+len(keyData)) hashInput = append(hashInput, serviceIdBytes...) hashInput = append(hashInput, keyData...) - k := blake2b.Sum256(hashInput) + k := crypto.HashData(hashInput) sa := service.ServiceAccount{ Storage: map[crypto.Hash][]byte{ @@ -221,7 +219,7 @@ func TestWrite(t *testing.T) { require.NoError(t, err) hashInput := append(serviceIdBytes, keyData...) - k := blake2b.Sum256(hashInput) + k := crypto.HashData(hashInput) sa := service.ServiceAccount{ Balance: 200, diff --git a/internal/state/merkle/helpers_test.go b/internal/state/merkle/helpers_test.go index cb5c8232..acfa21d9 100644 --- a/internal/state/merkle/helpers_test.go +++ b/internal/state/merkle/helpers_test.go @@ -2,7 +2,6 @@ package state import ( "crypto/ed25519" - "encoding/binary" "errors" "fmt" "github.com/eigerco/strawberry/internal/state" @@ -282,7 +281,8 @@ func RandomSafroleStateWithEpochKeys(t *testing.T) safrole.State { func RandomState(t *testing.T) state.State { services := make(service.ServiceState) for i := 0; i < 10; i++ { - services[block.ServiceId(789)] = RandomServiceAccount(t) + // Use different service IDs for each iteration + services[block.ServiceId(uint32(i+789))] = RandomServiceAccount(t) } return state.State{ @@ -412,7 +412,10 @@ func deserializeServices(state *state.State, serializedState map[crypto.Hash][]b // Check if this is a service account entry (state key starts with 255) if isServiceAccountKey(stateKey) { // Extract service ID from the key - serviceId := extractServiceIdFromKey(stateKey) + serviceId, err := extractServiceIdFromKey(stateKey) + if err != nil { + return err + } // Deserialize the combined fields (CodeHash, Balance, etc.) var combined struct { @@ -450,7 +453,19 @@ func isServiceAccountKey(stateKey crypto.Hash) bool { return stateKey[0] == 255 } -func extractServiceIdFromKey(stateKey crypto.Hash) block.ServiceId { - // Assuming that the service ID is embedded in bytes 1-4 of the key - return block.ServiceId(binary.BigEndian.Uint32(stateKey[1:5])) +func extractServiceIdFromKey(stateKey crypto.Hash) (block.ServiceId, error) { + // Collect service ID bytes from positions 1,3,5,7 into a slice + encodedServiceId := []byte{ + stateKey[1], + stateKey[3], + stateKey[5], + stateKey[7], + } + + var serviceId block.ServiceId + if err := jam.Unmarshal(encodedServiceId, &serviceId); err != nil { + return 0, err + } + + return serviceId, nil } diff --git a/internal/state/merkle/serialization.go b/internal/state/merkle/serialization.go index c2bcc4e2..00c749d6 100644 --- a/internal/state/merkle/serialization.go +++ b/internal/state/merkle/serialization.go @@ -6,6 +6,7 @@ import ( "github.com/eigerco/strawberry/internal/service" "github.com/eigerco/strawberry/internal/state" "github.com/eigerco/strawberry/pkg/serialization/codec/jam" + "math" ) // SerializeState serializes the given state into a map of crypto.Hash to byte arrays, for merklization. @@ -148,7 +149,10 @@ func serializeServiceAccount(serviceId block.ServiceId, serviceAccount service.S encodedFootprintSize, encodedFootprintItems, ) - stateKey := generateStateKey(255, serviceId) + stateKey, err := generateStateKeyInterleavedBasic(255, serviceId) + if err != nil { + return err + } serializedState[stateKey] = combined // Serialize storage and preimage items @@ -160,21 +164,43 @@ func serializeServiceAccount(serviceId block.ServiceId, serviceAccount service.S } func serializeStorageAndPreimage(serviceId block.ServiceId, serviceAccount service.ServiceAccount, serializedState map[crypto.Hash][]byte) error { + encodedMaxUint32, err := jam.Marshal(math.MaxUint32) + if err != nil { + return err + } for hash, value := range serviceAccount.Storage { encodedValue, err := jam.Marshal(value) if err != nil { return err } - stateKey := generateStateKeyInterleaved(serviceId, hash) + + var combined [32]byte + copy(combined[:4], encodedMaxUint32) + copy(combined[4:], hash[:28]) + stateKey, err := generateStateKeyInterleaved(serviceId, combined) + if err != nil { + return err + } serializedState[stateKey] = encodedValue } + encodedMaxUint32MinusOne, err := jam.Marshal(math.MaxUint32 - 1) + if err != nil { + return err + } for hash, value := range serviceAccount.PreimageLookup { encodedValue, err := jam.Marshal(value) if err != nil { return err } - stateKey := generateStateKeyInterleaved(serviceId, hash) + + var combined [32]byte + copy(combined[:4], encodedMaxUint32MinusOne) + copy(combined[4:], hash[1:29]) + stateKey, err := generateStateKeyInterleaved(serviceId, combined) + if err != nil { + return err + } serializedState[stateKey] = encodedValue } @@ -187,12 +213,15 @@ func serializeStorageAndPreimage(serviceId block.ServiceId, serviceAccount servi if err != nil { return err } + hashedPreImageHistoricalTimeslots := crypto.HashData(encodedPreImageHistoricalTimeslots) var combined [32]byte copy(combined[:4], encodedLength) - hashNotFirst4Bytes := bitwiseNotExceptFirst4Bytes(key.Hash) - copy(combined[4:], hashNotFirst4Bytes[:]) - stateKey := generateStateKeyInterleaved(serviceId, key.Hash) + copy(combined[4:], hashedPreImageHistoricalTimeslots[2:30]) + stateKey, err := generateStateKeyInterleaved(serviceId, key.Hash) + if err != nil { + return err + } serializedState[stateKey] = encodedPreImageHistoricalTimeslots } return nil diff --git a/internal/state/merkle/serialization_test.go b/internal/state/merkle/serialization_test.go index 62b32b50..8b81f7cd 100644 --- a/internal/state/merkle/serialization_test.go +++ b/internal/state/merkle/serialization_test.go @@ -1,7 +1,6 @@ package state import ( - "fmt" "github.com/eigerco/strawberry/internal/crypto" "github.com/eigerco/strawberry/internal/safrole" "github.com/eigerco/strawberry/pkg/serialization/codec/jam" @@ -11,67 +10,41 @@ import ( ) func TestSerializeState(t *testing.T) { - // Step 1: Generate random state and serialize it + // Generate random state state := RandomState(t) + + // Serialize and log serialized keys encodedState, err := SerializeState(state) require.NoError(t, err) - // Step 2: Deserialize the serialized state + // Deserialize and check results decodedState, err := DeserializeState(encodedState) - assert.NoError(t, err) - assert.NotEmpty(t, decodedState) - - // Step 3: Compare the deserialized state with the original state - - // Compare CoreAuthorizersPool - assert.Equal(t, state.CoreAuthorizersPool, decodedState.CoreAuthorizersPool, "CoreAuthorizersPool mismatch") - - // Compare PendingAuthorizersQueues - assert.Equal(t, state.PendingAuthorizersQueues, decodedState.PendingAuthorizersQueues, "PendingAuthorizersQueues mismatch") - - // Compare RecentBlocks - assert.Equal(t, state.RecentBlocks, decodedState.RecentBlocks, "RecentBlocks mismatch") - - // Compare ValidatorState fields - assert.Equal(t, state.ValidatorState.SafroleState.NextValidators, decodedState.ValidatorState.SafroleState.NextValidators, "NextValidators mismatch") - assert.Equal(t, state.ValidatorState.CurrentValidators, decodedState.ValidatorState.CurrentValidators, "CurrentValidators mismatch") - assert.Equal(t, state.ValidatorState.QueuedValidators, decodedState.ValidatorState.QueuedValidators, "FutureValidators mismatch") - assert.Equal(t, state.ValidatorState.ArchivedValidators, decodedState.ValidatorState.ArchivedValidators, "PreviousValidators mismatch") - assert.Equal(t, state.ValidatorState.SafroleState.RingCommitment, decodedState.ValidatorState.SafroleState.RingCommitment, "RingCommitment mismatch") - - // Ensure SealingKeySeries is correctly deserialized - assert.Equal(t, state.ValidatorState.SafroleState.SealingKeySeries, decodedState.ValidatorState.SafroleState.SealingKeySeries, "SealingKeySeries mismatch") - - // Compare TicketAccumulator - assert.Equal(t, state.ValidatorState.SafroleState.TicketAccumulator, decodedState.ValidatorState.SafroleState.TicketAccumulator, "TicketAccumulator mismatch") - - // Compare EntropyPool - assert.Equal(t, state.EntropyPool, decodedState.EntropyPool, "EntropyPool mismatch") - - // Compare CoreAssignments - assert.Equal(t, state.CoreAssignments, decodedState.CoreAssignments, "CoreAssignments mismatch") - - // Compare TimeslotIndex - assert.Equal(t, state.TimeslotIndex, decodedState.TimeslotIndex, "TimeslotIndex mismatch") - - // Compare PrivilegedServices - assert.Equal(t, state.PrivilegedServices, decodedState.PrivilegedServices, "PrivilegedServices mismatch") + require.NoError(t, err) - // Compare ValidatorStatistics - assert.Equal(t, state.ValidatorStatistics, decodedState.ValidatorStatistics, "ValidatorStatistics mismatch") + // Compare services + assert.Equal(t, len(state.Services), len(decodedState.Services), + "Service map length mismatch (Original: %d, Decoded: %d)", + len(state.Services), len(decodedState.Services)) - // Compare Services - assert.Equal(t, len(state.Services), len(decodedState.Services), "Service map length mismatch") for serviceID, originalService := range state.Services { decodedService, exists := decodedState.Services[serviceID] - require.True(t, exists, fmt.Sprintf("ServiceID %d missing in decoded state", serviceID)) + if !exists { + t.Errorf("Service ID %d missing in decoded state. Original service details: %+v", + serviceID, originalService) + continue + } - // Compare individual fields in ServiceAccount - assert.Equal(t, originalService.CodeHash, decodedService.CodeHash, fmt.Sprintf("Mismatch in CodeHash for ServiceID %d", serviceID)) - assert.Equal(t, originalService.Balance, decodedService.Balance, fmt.Sprintf("Mismatch in Balance for ServiceID %d", serviceID)) - assert.Equal(t, originalService.GasLimitForAccumulator, decodedService.GasLimitForAccumulator, fmt.Sprintf("Mismatch in GasLimitForAccumulator for ServiceID %d", serviceID)) - assert.Equal(t, originalService.GasLimitOnTransfer, decodedService.GasLimitOnTransfer, fmt.Sprintf("Mismatch in GasLimitOnTransfer for ServiceID %d", serviceID)) + assert.Equal(t, originalService.CodeHash, decodedService.CodeHash) + assert.Equal(t, originalService.Balance, decodedService.Balance) + assert.Equal(t, originalService.GasLimitForAccumulator, decodedService.GasLimitForAccumulator) + assert.Equal(t, originalService.GasLimitOnTransfer, decodedService.GasLimitOnTransfer) + } + // Check for extra services in decoded state + for serviceID := range decodedState.Services { + if _, exists := state.Services[serviceID]; !exists { + t.Errorf("Extra service ID %d found in decoded state", serviceID) + } } // Compare Past Judgements @@ -305,41 +278,10 @@ func TestSerializeStateServices(t *testing.T) { require.NoError(t, err) for serviceId := range state.Services { - stateKey := generateStateKey(255, serviceId) + stateKey, err := generateStateKeyInterleavedBasic(255, serviceId) + require.NoError(t, err) hashKey := crypto.Hash(stateKey) assert.Contains(t, serializedState, hashKey) assert.NotEmpty(t, serializedState[hashKey]) } } - -// TestSerializeStateStorage checks the serialization of storage items within services. -func TestSerializeStateStorage(t *testing.T) { - state := RandomState(t) - serializedState, err := SerializeState(state) - require.NoError(t, err) - - for serviceId, serviceAccount := range state.Services { - for hash := range serviceAccount.Storage { - stateKey := generateStateKeyInterleaved(serviceId, hash) - hashKey := crypto.Hash(stateKey) - assert.Contains(t, serializedState, hashKey) - assert.NotEmpty(t, serializedState[hashKey]) - } - } -} - -// TestSerializeStatePreimageMeta checks the serialization of the preimage metadata within services. -func TestSerializeStatePreimageMeta(t *testing.T) { - state := RandomState(t) - serializedState, err := SerializeState(state) - require.NoError(t, err) - - for serviceId, serviceAccount := range state.Services { - for key := range serviceAccount.PreimageMeta { - stateKey := generateStateKeyInterleaved(serviceId, key.Hash) - hashKey := crypto.Hash(stateKey) - assert.Contains(t, serializedState, hashKey) - assert.NotEmpty(t, serializedState[hashKey]) - } - } -} diff --git a/internal/state/merkle/serialization_utils.go b/internal/state/merkle/serialization_utils.go index ff3eb9be..b1350610 100644 --- a/internal/state/merkle/serialization_utils.go +++ b/internal/state/merkle/serialization_utils.go @@ -3,7 +3,7 @@ package state import ( "bytes" "crypto/ed25519" - "encoding/binary" + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" "slices" "sort" @@ -23,44 +23,45 @@ func generateStateKeyBasic(i uint8) [32]byte { return result } -// generateStateKey to generate state key based on i and s -func generateStateKey(i uint8, s block.ServiceId) [32]byte { +// generateStateKeyInterleavedBasic to generate state key based on i and s +func generateStateKeyInterleavedBasic(i uint8, s block.ServiceId) ([32]byte, error) { + encodedServiceId, err := jam.Marshal(s) + if err != nil { + return [32]byte{}, err + } + var result [32]byte // Place i as the first byte result[0] = i - // Convert s into a 4-byte buffer and place it starting at result[1] - sBuf := make([]byte, 4) - binary.BigEndian.PutUint32(sBuf, uint32(s)) // s is 4 bytes in BigEndian format - - // Copy the 4-byte sBuf to result starting at index 1 - copy(result[1:], sBuf) + // Place encoded service ID bytes at positions 1,3,5,7 + for j := 0; j < 4; j++ { + result[1+j*2] = encodedServiceId[j] + } - // The rest of result is already zero-padded by default - return result + return result, nil } // Function to interleave the first 4 bytes of s and h, then append the rest of h -func generateStateKeyInterleaved(s block.ServiceId, h [32]byte) [32]byte { - var result [32]byte +func generateStateKeyInterleaved(s block.ServiceId, h [32]byte) ([32]byte, error) { + encodedServiceId, err := jam.Marshal(s) + if err != nil { + return [32]byte{}, err + } - // Convert s into a 4-byte buffer - sBuf := make([]byte, 4) - binary.BigEndian.PutUint32(sBuf, uint32(s)) // s is 4 bytes + var result [32]byte - // Interleave the first 4 bytes of s with the first 4 bytes of h + // Interleave the first 4 bytes of encodedServiceId with the first 4 bytes of h for i := 0; i < 4; i++ { - // Copy the i-th byte from sBuf - result[i*2] = sBuf[i] - // Copy the i-th byte from h + result[i*2] = encodedServiceId[i] result[i*2+1] = h[i] } // Append the rest of h to the result copy(result[8:], h[4:]) - return result + return result, nil } // calculateFootprintSize calculates the storage footprint size (al) based on Equation 94. @@ -112,17 +113,3 @@ func sortByteSlicesCopy(slice interface{}) interface{} { panic("unsupported type for sorting") } } - -// bitwiseNotExceptFirst4Bytes to apply bitwise NOT to all bytes except the first 4 -func bitwiseNotExceptFirst4Bytes(h crypto.Hash) [28]byte { - // Clone the original array into a new one - var result [28]byte - copy(result[:], h[:]) - - // Apply bitwise NOT to all bytes except the first 4 - for i := 4; i < len(result); i++ { - result[i] = ^result[i] - } - - return result -} diff --git a/internal/state/merkle/serialization_utils_test.go b/internal/state/merkle/serialization_utils_test.go index 3f0eaa05..ad58a5c3 100644 --- a/internal/state/merkle/serialization_utils_test.go +++ b/internal/state/merkle/serialization_utils_test.go @@ -1,7 +1,10 @@ package state import ( - "encoding/binary" + "fmt" + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" + "github.com/stretchr/testify/require" + "math" "testing" "github.com/eigerco/strawberry/internal/block" @@ -10,25 +13,69 @@ import ( "github.com/stretchr/testify/assert" ) -// TestGenerateStateKey verifies that the state key generation works as expected. -func TestGenerateStateKey(t *testing.T) { - // Test with i and serviceId - i := uint8(1) - serviceId := block.ServiceId(100) - - // Generate the state key - stateKey := generateStateKey(i, serviceId) - - // Verify the length is 32 bytes - assert.Equal(t, 32, len(stateKey)) - - // Verify that the first byte matches i - assert.Equal(t, i, stateKey[0]) +// TestGenerateStateKeyInterleavedBasic verifies that the state key generation works as expected. +func TestGenerateStateKeyInterleavedBasic(t *testing.T) { + tests := []struct { + name string + i uint8 + serviceId block.ServiceId + }{ + { + name: "basic case", + i: 1, + serviceId: 100, + }, + { + name: "max values", + i: 255, + serviceId: block.ServiceId(math.MaxUint32), + }, + { + name: "zero values", + i: 0, + serviceId: 0, + }, + } - // Optionally, verify that the encoded serviceId is in the key - expectedEncodedServiceId := make([]byte, 4) - binary.BigEndian.PutUint32(expectedEncodedServiceId, uint32(serviceId)) - assert.Equal(t, expectedEncodedServiceId, stateKey[1:5]) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate the state key + stateKey, err := generateStateKeyInterleavedBasic(tt.i, tt.serviceId) + require.NoError(t, err) + + // Get encoded service ID for verification + encodedServiceId, err := jam.Marshal(tt.serviceId) + require.NoError(t, err) + + // Verify length is 32 bytes + assert.Equal(t, 32, len(stateKey), "key length should be 32 bytes") + + // Verify first byte is i + assert.Equal(t, tt.i, stateKey[0], "first byte should be i") + + // Verify the interleaved pattern: + // [i, s0, 0, s1, 0, s2, 0, s3, 0, 0, ...] + assert.Equal(t, encodedServiceId[0], stateKey[1], "s0 should be at position 1") + assert.Equal(t, byte(0), stateKey[2], "zero should be at position 2") + assert.Equal(t, encodedServiceId[1], stateKey[3], "s1 should be at position 3") + assert.Equal(t, byte(0), stateKey[4], "zero should be at position 4") + assert.Equal(t, encodedServiceId[2], stateKey[5], "s2 should be at position 5") + assert.Equal(t, byte(0), stateKey[6], "zero should be at position 6") + assert.Equal(t, encodedServiceId[3], stateKey[7], "s3 should be at position 7") + assert.Equal(t, byte(0), stateKey[8], "zero should be at position 8") + + // Verify remaining bytes are zero + for i := 9; i < 32; i++ { + assert.Equal(t, byte(0), stateKey[i], + fmt.Sprintf("byte at position %d should be zero", i)) + } + + // Verify we can extract the service ID back + extractedServiceId, err := extractServiceIdFromKey(crypto.Hash(stateKey)) + require.NoError(t, err) + assert.Equal(t, tt.serviceId, extractedServiceId) + }) + } } // TestGenerateStateKeyInterleaved verifies that the interleaving function works as expected. @@ -36,17 +83,36 @@ func TestGenerateStateKeyInterleaved(t *testing.T) { serviceId := block.ServiceId(1234) hash := crypto.Hash{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} + // Get encoded service ID for verification + encodedServiceId, err := jam.Marshal(serviceId) + require.NoError(t, err) + // Generate the interleaved state key - stateKey := generateStateKeyInterleaved(serviceId, hash) + stateKey, err := generateStateKeyInterleaved(serviceId, hash) + require.NoError(t, err) // Verify the length is 32 bytes assert.Equal(t, 32, len(stateKey)) // Verify that the first 8 bytes are interleaved between serviceId and hash - assert.Equal(t, stateKey[0], byte(serviceId>>24)) - assert.Equal(t, stateKey[1], hash[0]) - assert.Equal(t, stateKey[2], byte(serviceId>>16)) - assert.Equal(t, stateKey[3], hash[1]) + assert.Equal(t, encodedServiceId[0], stateKey[0]) + assert.Equal(t, hash[0], stateKey[1]) + assert.Equal(t, encodedServiceId[1], stateKey[2]) + assert.Equal(t, hash[1], stateKey[3]) + assert.Equal(t, encodedServiceId[2], stateKey[4]) + assert.Equal(t, hash[2], stateKey[5]) + assert.Equal(t, encodedServiceId[3], stateKey[6]) + assert.Equal(t, hash[3], stateKey[7]) + + // Verify that remaining bytes from hash are copied correctly + rest := stateKey[8:] + for i := 0; i < len(rest); i++ { + if i < len(hash)-4 { + assert.Equal(t, hash[i+4], rest[i], "hash byte mismatch at position %d", i) + } else { + assert.Equal(t, byte(0), rest[i], "should be zero at position %d", i) + } + } } // TestCalculateFootprintSize checks if the footprint size calculation is correct. @@ -77,20 +143,3 @@ func TestCombineEncoded(t *testing.T) { // Verify the combined result assert.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, combined) } - -// TestBitwiseNotExceptFirst4Bytes checks that the bitwise NOT is applied correctly except the first 4 bytes. -func TestBitwiseNotExceptFirst4Bytes(t *testing.T) { - // Example input hash - inputHash := crypto.Hash{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} - - // Apply the bitwise NOT except the first 4 bytes - result := bitwiseNotExceptFirst4Bytes(inputHash) - - // Verify that the first 4 bytes are unchanged - assert.Equal(t, inputHash[0:4], result[0:4]) - - // Verify that the rest of the bytes are bitwise NOT applied - for i := 4; i < len(result); i++ { - assert.Equal(t, ^inputHash[i], result[i]) - } -} diff --git a/internal/statetransition/accumulate.go b/internal/statetransition/accumulate.go index 5df6613e..97b4464c 100644 --- a/internal/statetransition/accumulate.go +++ b/internal/statetransition/accumulate.go @@ -32,9 +32,9 @@ type Accumulator struct { state *state.State } -// InvokePVM ΨA(U, N_S , N_G, ⟦O⟧) → (U, ⟦T⟧, H?, N_G) Equation 280 +// InvokePVM ΨA(U, N_S , N_G, ⟦O⟧) → (U, ⟦T⟧, H?, N_G) Equation (B.8) func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex block.ServiceId, gas uint64, accOperand []state.AccumulationOperand) (state.AccumulationState, []service.DeferredTransfer, *crypto.Hash, uint64) { - // if d[s]c = ∅ + // if ud[s]c = ∅ if accState.ServiceState[serviceIndex].Code() == nil { ctx, err := a.newCtx(accState, serviceIndex) if err != nil { @@ -62,7 +62,7 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b return ctx.AccumulationState, []service.DeferredTransfer{}, nil, 0 } - // F (equation 283) + // F (equation B.10) hostCallFunc := func(hostCall uint32, gasCounter polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, ctx polkavm.AccumulateContextPair) (polkavm.Gas, polkavm.Registers, polkavm.Memory, polkavm.AccumulateContextPair, error) { // s currentService := accState.ServiceState[serviceIndex] @@ -70,7 +70,6 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b switch hostCall { case host_call.GasID: gasCounter, regs, err = host_call.GasRemaining(gasCounter, regs) - ctx.RegularCtx.AccumulationState.ServiceState[ctx.RegularCtx.ServiceId] = currentService case host_call.LookupID: gasCounter, regs, mem, err = host_call.Lookup(gasCounter, regs, mem, currentService, serviceIndex, ctx.RegularCtx.AccumulationState.ServiceState) ctx.RegularCtx.AccumulationState.ServiceState[ctx.RegularCtx.ServiceId] = currentService @@ -82,9 +81,8 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b ctx.RegularCtx.AccumulationState.ServiceState[ctx.RegularCtx.ServiceId] = currentService case host_call.InfoID: gasCounter, regs, mem, err = host_call.Info(gasCounter, regs, mem, serviceIndex, ctx.RegularCtx.AccumulationState.ServiceState) - ctx.RegularCtx.AccumulationState.ServiceState[ctx.RegularCtx.ServiceId] = currentService - case host_call.EmpowerID: - gasCounter, regs, mem, ctx, err = host_call.Empower(gasCounter, regs, mem, ctx) + case host_call.BlessID: + gasCounter, regs, mem, ctx, err = host_call.Bless(gasCounter, regs, mem, ctx) case host_call.AssignID: gasCounter, regs, mem, ctx, err = host_call.Assign(gasCounter, regs, mem, ctx) case host_call.DesignateID: @@ -110,7 +108,7 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b return gasCounter, regs, mem, ctx, err } - remainingGas, ret, newCtxPair, err := interpreter.InvokeWholeProgram(accState.ServiceState[serviceIndex].Code(), 10, gas, args, hostCallFunc, newCtxPair) + remainingGas, ret, newCtxPair, err := interpreter.InvokeWholeProgram(accState.ServiceState[serviceIndex].Code(), 5, gas, args, hostCallFunc, newCtxPair) if err != nil { errPanic := &polkavm.ErrPanic{} if errors.Is(err, polkavm.ErrOutOfGas) || errors.As(err, &errPanic) { @@ -128,7 +126,7 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b return newCtxPair.RegularCtx.AccumulationState, newCtxPair.RegularCtx.DeferredTransfers, nil, uint64(remainingGas) } -// newCtx (281) +// newCtx (B.9) func (a *Accumulator) newCtx(u state.AccumulationState, serviceIndex block.ServiceId) (polkavm.AccumulateContext, error) { serviceState := maps.Clone(u.ServiceState) delete(serviceState, serviceIndex) @@ -176,6 +174,10 @@ func (a *Accumulator) newServiceID(serviceIndex block.ServiceId) (block.ServiceI hashData := crypto.HashData(hashBytes) newId := block.ServiceId(0) - jam.DeserializeTrivialNatural(hashData[:], &newId) + err = jam.Unmarshal(hashData[:], &newId) + if err != nil { + return 0, err + } + return newId, nil } diff --git a/internal/statetransition/on_transfer.go b/internal/statetransition/on_transfer.go index 29be45d1..dcf867f1 100644 --- a/internal/statetransition/on_transfer.go +++ b/internal/statetransition/on_transfer.go @@ -1,13 +1,14 @@ package statetransition import ( + "log" + "github.com/eigerco/strawberry/internal/block" "github.com/eigerco/strawberry/internal/polkavm" "github.com/eigerco/strawberry/internal/polkavm/host_call" "github.com/eigerco/strawberry/internal/polkavm/interpreter" "github.com/eigerco/strawberry/internal/service" "github.com/eigerco/strawberry/pkg/serialization/codec/jam" - "log" ) // InvokePVMOnTransfer On-Transfer service-account invocation (ΨT). @@ -48,7 +49,7 @@ func InvokePVMOnTransfer(serviceState service.ServiceState, serviceIndex block.S return gasCounter, regs, mem, serviceAccount, err } - _, _, newServiceAccount, err := interpreter.InvokeWholeProgram(serviceCode, 15, gas, args, hostCallFunc, serviceAccount) + _, _, newServiceAccount, err := interpreter.InvokeWholeProgram(serviceCode, 10, gas, args, hostCallFunc, serviceAccount) if err != nil { // TODO handle errors appropriately log.Println("the virtual machine exited with an error", err) diff --git a/pkg/serialization/codec/jam/decode.go b/pkg/serialization/codec/jam/decode.go index ef7bc7d7..9a762ad9 100644 --- a/pkg/serialization/codec/jam/decode.go +++ b/pkg/serialization/codec/jam/decode.go @@ -341,7 +341,7 @@ func (br *byteReader) decodeUint(value reflect.Value) error { } var v uint64 - err = DeserializeUint64WithLength(serialized, l, &v) + err = deserializeUint64WithLength(serialized, l, &v) if err != nil { return fmt.Errorf(ErrDecodingUint, err) } @@ -419,19 +419,19 @@ func (br *byteReader) decodeFixedWidthInt(dstv reflect.Value) error { switch in.(type) { case uint8: var temp uint8 - DeserializeTrivialNatural(buf, &temp) + deserializeTrivialNatural(buf, &temp) dstv.Set(reflect.ValueOf(temp)) case uint16: var temp uint16 - DeserializeTrivialNatural(buf, &temp) + deserializeTrivialNatural(buf, &temp) dstv.Set(reflect.ValueOf(temp)) case uint32: var temp uint32 - DeserializeTrivialNatural(buf, &temp) + deserializeTrivialNatural(buf, &temp) dstv.Set(reflect.ValueOf(temp)) case uint64: var temp uint64 - DeserializeTrivialNatural(buf, &temp) + deserializeTrivialNatural(buf, &temp) dstv.Set(reflect.ValueOf(temp)) } diff --git a/pkg/serialization/codec/jam/encode.go b/pkg/serialization/codec/jam/encode.go index 7001d06f..17a08d49 100644 --- a/pkg/serialization/codec/jam/encode.go +++ b/pkg/serialization/codec/jam/encode.go @@ -292,13 +292,13 @@ func (bw *byteWriter) encodeFixedWidthUint(i interface{}) error { switch v := i.(type) { case uint8: - data = SerializeTrivialNatural(v, 1) + data = serializeTrivialNatural(v, 1) case uint16: - data = SerializeTrivialNatural(v, 2) + data = serializeTrivialNatural(v, 2) case uint32: - data = SerializeTrivialNatural(v, 4) + data = serializeTrivialNatural(v, 4) case uint64: - data = SerializeTrivialNatural(v, 8) + data = serializeTrivialNatural(v, 8) default: return fmt.Errorf(ErrUnsupportedType, i) } @@ -341,7 +341,7 @@ func (bw *byteWriter) encodeLength(l int) error { } func (bw *byteWriter) encodeUint(i uint) error { - encodedBytes := SerializeUint64(uint64(i)) + encodedBytes := serializeUint64(uint64(i)) _, err := bw.Write(encodedBytes) diff --git a/pkg/serialization/codec/jam/general_natural.go b/pkg/serialization/codec/jam/general_natural.go index cff50ac5..72d28dcb 100644 --- a/pkg/serialization/codec/jam/general_natural.go +++ b/pkg/serialization/codec/jam/general_natural.go @@ -5,8 +5,8 @@ import ( "math" ) -// SerializeUint64 implements the general formula (able to encode naturals of up to 2^64) -func SerializeUint64(x uint64) []byte { +// serializeUint64 implements the general formula (able to encode naturals of up to 2^64) +func serializeUint64(x uint64) []byte { var l uint8 // Determine the length needed to represent the value for l = 0; l < 8; l++ { @@ -30,8 +30,8 @@ func SerializeUint64(x uint64) []byte { return bytes } -// DeserializeUint64WithLength deserializes a byte slice into a uint64 value, with length `l`. -func DeserializeUint64WithLength(serialized []byte, l uint8, u *uint64) error { +// deserializeUint64WithLength deserializes a byte slice into a uint64 value, with length `l`. +func deserializeUint64WithLength(serialized []byte, l uint8, u *uint64) error { *u = 0 n := len(serialized) diff --git a/pkg/serialization/codec/jam/general_natural_test.go b/pkg/serialization/codec/jam/general_natural_test.go index 958824b8..218660d6 100644 --- a/pkg/serialization/codec/jam/general_natural_test.go +++ b/pkg/serialization/codec/jam/general_natural_test.go @@ -2,11 +2,12 @@ package jam import ( "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "math" "math/bits" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEncodeDecodeUint64(t *testing.T) { @@ -53,7 +54,7 @@ func TestEncodeDecodeUint64(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("uint64(%d)", tc.input), func(t *testing.T) { // Marshal the x value - serialized := SerializeUint64(tc.input) + serialized := serializeUint64(tc.input) // Check if the serialized output matches the expected output assert.Equal(t, tc.expected, serialized, "serialized output mismatch for x %d", tc.input) @@ -64,7 +65,7 @@ func TestEncodeDecodeUint64(t *testing.T) { } // Unmarshal the serialized data back into a uint64 var deserialized uint64 - err := DeserializeUint64WithLength(serialized, l, &deserialized) + err := deserializeUint64WithLength(serialized, l, &deserialized) require.NoError(t, err, "unmarshal(%v) returned an unexpected error", serialized) // Check if the deserialized value matches the original x diff --git a/pkg/serialization/codec/jam/trivial_natural.go b/pkg/serialization/codec/jam/trivial_natural.go index aa982504..1b3b6fe2 100644 --- a/pkg/serialization/codec/jam/trivial_natural.go +++ b/pkg/serialization/codec/jam/trivial_natural.go @@ -4,7 +4,7 @@ import ( "math" ) -func SerializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint8) []byte { +func serializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint8) []byte { bytes := make([]byte, l) for i := uint8(0); i < l; i++ { bytes[i] = byte((x >> (8 * i)) & T(math.MaxUint8)) @@ -12,7 +12,7 @@ func SerializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint return bytes } -func DeserializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](serialized []byte, u *T) { +func deserializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](serialized []byte, u *T) { *u = 0 // Iterate over each byte in the serialized array diff --git a/pkg/serialization/codec/jam/trivial_natural_test.go b/pkg/serialization/codec/jam/trivial_natural_test.go index 10a77e9d..9f8e78d6 100644 --- a/pkg/serialization/codec/jam/trivial_natural_test.go +++ b/pkg/serialization/codec/jam/trivial_natural_test.go @@ -49,13 +49,13 @@ func TestSerializationTrivialNatural(t *testing.T) { var serialized []byte switch v := tc.x.(type) { case uint8: - serialized = SerializeTrivialNatural(v, tc.l) + serialized = serializeTrivialNatural(v, tc.l) case uint16: - serialized = SerializeTrivialNatural(v, tc.l) + serialized = serializeTrivialNatural(v, tc.l) case uint32: - serialized = SerializeTrivialNatural(v, tc.l) + serialized = serializeTrivialNatural(v, tc.l) case uint64: - serialized = SerializeTrivialNatural(v, tc.l) + serialized = serializeTrivialNatural(v, tc.l) } assert.Equal(t, tc.expected, serialized, "serialized output mismatch") @@ -63,19 +63,19 @@ func TestSerializationTrivialNatural(t *testing.T) { switch v := tc.x.(type) { case uint8: var deserialized uint8 - DeserializeTrivialNatural(serialized, &deserialized) + deserializeTrivialNatural(serialized, &deserialized) assert.Equal(t, v, deserialized, "deserialized value mismatch") case uint16: var deserialized uint16 - DeserializeTrivialNatural(serialized, &deserialized) + deserializeTrivialNatural(serialized, &deserialized) assert.Equal(t, v, deserialized, "deserialized value mismatch") case uint32: var deserialized uint32 - DeserializeTrivialNatural(serialized, &deserialized) + deserializeTrivialNatural(serialized, &deserialized) assert.Equal(t, v, deserialized, "deserialized value mismatch") case uint64: var deserialized uint64 - DeserializeTrivialNatural(serialized, &deserialized) + deserializeTrivialNatural(serialized, &deserialized) assert.Equal(t, v, deserialized, "deserialized value mismatch") } })