Skip to content

Commit

Permalink
State serialization to 0.5.0 (#147)
Browse files Browse the repository at this point in the history
* Adapt state serialization to 0.5.0

* Improve comments

* Rename generate state key function

* Improve encoding

* Fix implementation to reflect graypaper

* Merge branch 'main' into feat/state-serialization-0.5.0
  • Loading branch information
carlos-romano authored Nov 27, 2024
1 parent 83ff6e8 commit d791611
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 172 deletions.
27 changes: 21 additions & 6 deletions internal/state/merkle/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package state

import (
"crypto/ed25519"
"encoding/binary"
"errors"
"fmt"
"github.com/eigerco/strawberry/internal/state"
Expand Down Expand Up @@ -282,7 +281,8 @@ func RandomSafroleStateWithEpochKeys(t *testing.T) safrole.State {
func RandomState(t *testing.T) state.State {
services := make(service.ServiceState)
for i := 0; i < 10; i++ {
services[block.ServiceId(789)] = RandomServiceAccount(t)
// Use different service IDs for each iteration
services[block.ServiceId(uint32(i+789))] = RandomServiceAccount(t)
}

return state.State{
Expand Down Expand Up @@ -412,7 +412,10 @@ func deserializeServices(state *state.State, serializedState map[crypto.Hash][]b
// Check if this is a service account entry (state key starts with 255)
if isServiceAccountKey(stateKey) {
// Extract service ID from the key
serviceId := extractServiceIdFromKey(stateKey)
serviceId, err := extractServiceIdFromKey(stateKey)
if err != nil {
return err
}

// Deserialize the combined fields (CodeHash, Balance, etc.)
var combined struct {
Expand Down Expand Up @@ -450,7 +453,19 @@ func isServiceAccountKey(stateKey crypto.Hash) bool {
return stateKey[0] == 255
}

func extractServiceIdFromKey(stateKey crypto.Hash) block.ServiceId {
// Assuming that the service ID is embedded in bytes 1-4 of the key
return block.ServiceId(binary.BigEndian.Uint32(stateKey[1:5]))
func extractServiceIdFromKey(stateKey crypto.Hash) (block.ServiceId, error) {
// Collect service ID bytes from positions 1,3,5,7 into a slice
encodedServiceId := []byte{
stateKey[1],
stateKey[3],
stateKey[5],
stateKey[7],
}

var serviceId block.ServiceId
if err := jam.Unmarshal(encodedServiceId, &serviceId); err != nil {
return 0, err
}

return serviceId, nil
}
41 changes: 35 additions & 6 deletions internal/state/merkle/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/eigerco/strawberry/internal/service"
"github.com/eigerco/strawberry/internal/state"
"github.com/eigerco/strawberry/pkg/serialization/codec/jam"
"math"
)

// SerializeState serializes the given state into a map of crypto.Hash to byte arrays, for merklization.
Expand Down Expand Up @@ -148,7 +149,10 @@ func serializeServiceAccount(serviceId block.ServiceId, serviceAccount service.S
encodedFootprintSize,
encodedFootprintItems,
)
stateKey := generateStateKey(255, serviceId)
stateKey, err := generateStateKeyInterleavedBasic(255, serviceId)
if err != nil {
return err
}
serializedState[stateKey] = combined

// Serialize storage and preimage items
Expand All @@ -160,21 +164,43 @@ func serializeServiceAccount(serviceId block.ServiceId, serviceAccount service.S
}

func serializeStorageAndPreimage(serviceId block.ServiceId, serviceAccount service.ServiceAccount, serializedState map[crypto.Hash][]byte) error {
encodedMaxUint32, err := jam.Marshal(math.MaxUint32)
if err != nil {
return err
}
for hash, value := range serviceAccount.Storage {
encodedValue, err := jam.Marshal(value)
if err != nil {
return err
}
stateKey := generateStateKeyInterleaved(serviceId, hash)

var combined [32]byte
copy(combined[:4], encodedMaxUint32)
copy(combined[4:], hash[:28])
stateKey, err := generateStateKeyInterleaved(serviceId, combined)
if err != nil {
return err
}
serializedState[stateKey] = encodedValue
}

encodedMaxUint32MinusOne, err := jam.Marshal(math.MaxUint32 - 1)
if err != nil {
return err
}
for hash, value := range serviceAccount.PreimageLookup {
encodedValue, err := jam.Marshal(value)
if err != nil {
return err
}
stateKey := generateStateKeyInterleaved(serviceId, hash)

var combined [32]byte
copy(combined[:4], encodedMaxUint32MinusOne)
copy(combined[4:], hash[1:29])
stateKey, err := generateStateKeyInterleaved(serviceId, combined)
if err != nil {
return err
}
serializedState[stateKey] = encodedValue
}

Expand All @@ -187,12 +213,15 @@ func serializeStorageAndPreimage(serviceId block.ServiceId, serviceAccount servi
if err != nil {
return err
}
hashedPreImageHistoricalTimeslots := crypto.HashData(encodedPreImageHistoricalTimeslots)

var combined [32]byte
copy(combined[:4], encodedLength)
hashNotFirst4Bytes := bitwiseNotExceptFirst4Bytes(key.Hash)
copy(combined[4:], hashNotFirst4Bytes[:])
stateKey := generateStateKeyInterleaved(serviceId, key.Hash)
copy(combined[4:], hashedPreImageHistoricalTimeslots[2:30])
stateKey, err := generateStateKeyInterleaved(serviceId, key.Hash)
if err != nil {
return err
}
serializedState[stateKey] = encodedPreImageHistoricalTimeslots
}
return nil
Expand Down
110 changes: 26 additions & 84 deletions internal/state/merkle/serialization_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package state

import (
"fmt"
"github.com/eigerco/strawberry/internal/crypto"
"github.com/eigerco/strawberry/internal/safrole"
"github.com/eigerco/strawberry/pkg/serialization/codec/jam"
Expand All @@ -11,67 +10,41 @@ import (
)

func TestSerializeState(t *testing.T) {
// Step 1: Generate random state and serialize it
// Generate random state
state := RandomState(t)

// Serialize and log serialized keys
encodedState, err := SerializeState(state)
require.NoError(t, err)

// Step 2: Deserialize the serialized state
// Deserialize and check results
decodedState, err := DeserializeState(encodedState)
assert.NoError(t, err)
assert.NotEmpty(t, decodedState)

// Step 3: Compare the deserialized state with the original state

// Compare CoreAuthorizersPool
assert.Equal(t, state.CoreAuthorizersPool, decodedState.CoreAuthorizersPool, "CoreAuthorizersPool mismatch")

// Compare PendingAuthorizersQueues
assert.Equal(t, state.PendingAuthorizersQueues, decodedState.PendingAuthorizersQueues, "PendingAuthorizersQueues mismatch")

// Compare RecentBlocks
assert.Equal(t, state.RecentBlocks, decodedState.RecentBlocks, "RecentBlocks mismatch")

// Compare ValidatorState fields
assert.Equal(t, state.ValidatorState.SafroleState.NextValidators, decodedState.ValidatorState.SafroleState.NextValidators, "NextValidators mismatch")
assert.Equal(t, state.ValidatorState.CurrentValidators, decodedState.ValidatorState.CurrentValidators, "CurrentValidators mismatch")
assert.Equal(t, state.ValidatorState.QueuedValidators, decodedState.ValidatorState.QueuedValidators, "FutureValidators mismatch")
assert.Equal(t, state.ValidatorState.ArchivedValidators, decodedState.ValidatorState.ArchivedValidators, "PreviousValidators mismatch")
assert.Equal(t, state.ValidatorState.SafroleState.RingCommitment, decodedState.ValidatorState.SafroleState.RingCommitment, "RingCommitment mismatch")

// Ensure SealingKeySeries is correctly deserialized
assert.Equal(t, state.ValidatorState.SafroleState.SealingKeySeries, decodedState.ValidatorState.SafroleState.SealingKeySeries, "SealingKeySeries mismatch")

// Compare TicketAccumulator
assert.Equal(t, state.ValidatorState.SafroleState.TicketAccumulator, decodedState.ValidatorState.SafroleState.TicketAccumulator, "TicketAccumulator mismatch")

// Compare EntropyPool
assert.Equal(t, state.EntropyPool, decodedState.EntropyPool, "EntropyPool mismatch")

// Compare CoreAssignments
assert.Equal(t, state.CoreAssignments, decodedState.CoreAssignments, "CoreAssignments mismatch")

// Compare TimeslotIndex
assert.Equal(t, state.TimeslotIndex, decodedState.TimeslotIndex, "TimeslotIndex mismatch")

// Compare PrivilegedServices
assert.Equal(t, state.PrivilegedServices, decodedState.PrivilegedServices, "PrivilegedServices mismatch")
require.NoError(t, err)

// Compare ValidatorStatistics
assert.Equal(t, state.ValidatorStatistics, decodedState.ValidatorStatistics, "ValidatorStatistics mismatch")
// Compare services
assert.Equal(t, len(state.Services), len(decodedState.Services),
"Service map length mismatch (Original: %d, Decoded: %d)",
len(state.Services), len(decodedState.Services))

// Compare Services
assert.Equal(t, len(state.Services), len(decodedState.Services), "Service map length mismatch")
for serviceID, originalService := range state.Services {
decodedService, exists := decodedState.Services[serviceID]
require.True(t, exists, fmt.Sprintf("ServiceID %d missing in decoded state", serviceID))
if !exists {
t.Errorf("Service ID %d missing in decoded state. Original service details: %+v",
serviceID, originalService)
continue
}

// Compare individual fields in ServiceAccount
assert.Equal(t, originalService.CodeHash, decodedService.CodeHash, fmt.Sprintf("Mismatch in CodeHash for ServiceID %d", serviceID))
assert.Equal(t, originalService.Balance, decodedService.Balance, fmt.Sprintf("Mismatch in Balance for ServiceID %d", serviceID))
assert.Equal(t, originalService.GasLimitForAccumulator, decodedService.GasLimitForAccumulator, fmt.Sprintf("Mismatch in GasLimitForAccumulator for ServiceID %d", serviceID))
assert.Equal(t, originalService.GasLimitOnTransfer, decodedService.GasLimitOnTransfer, fmt.Sprintf("Mismatch in GasLimitOnTransfer for ServiceID %d", serviceID))
assert.Equal(t, originalService.CodeHash, decodedService.CodeHash)
assert.Equal(t, originalService.Balance, decodedService.Balance)
assert.Equal(t, originalService.GasLimitForAccumulator, decodedService.GasLimitForAccumulator)
assert.Equal(t, originalService.GasLimitOnTransfer, decodedService.GasLimitOnTransfer)
}

// Check for extra services in decoded state
for serviceID := range decodedState.Services {
if _, exists := state.Services[serviceID]; !exists {
t.Errorf("Extra service ID %d found in decoded state", serviceID)
}
}

// Compare Past Judgements
Expand Down Expand Up @@ -305,41 +278,10 @@ func TestSerializeStateServices(t *testing.T) {
require.NoError(t, err)

for serviceId := range state.Services {
stateKey := generateStateKey(255, serviceId)
stateKey, err := generateStateKeyInterleavedBasic(255, serviceId)
require.NoError(t, err)
hashKey := crypto.Hash(stateKey)
assert.Contains(t, serializedState, hashKey)
assert.NotEmpty(t, serializedState[hashKey])
}
}

// TestSerializeStateStorage checks the serialization of storage items within services.
func TestSerializeStateStorage(t *testing.T) {
state := RandomState(t)
serializedState, err := SerializeState(state)
require.NoError(t, err)

for serviceId, serviceAccount := range state.Services {
for hash := range serviceAccount.Storage {
stateKey := generateStateKeyInterleaved(serviceId, hash)
hashKey := crypto.Hash(stateKey)
assert.Contains(t, serializedState, hashKey)
assert.NotEmpty(t, serializedState[hashKey])
}
}
}

// TestSerializeStatePreimageMeta checks the serialization of the preimage metadata within services.
func TestSerializeStatePreimageMeta(t *testing.T) {
state := RandomState(t)
serializedState, err := SerializeState(state)
require.NoError(t, err)

for serviceId, serviceAccount := range state.Services {
for key := range serviceAccount.PreimageMeta {
stateKey := generateStateKeyInterleaved(serviceId, key.Hash)
hashKey := crypto.Hash(stateKey)
assert.Contains(t, serializedState, hashKey)
assert.NotEmpty(t, serializedState[hashKey])
}
}
}
57 changes: 22 additions & 35 deletions internal/state/merkle/serialization_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package state
import (
"bytes"
"crypto/ed25519"
"encoding/binary"
"github.com/eigerco/strawberry/pkg/serialization/codec/jam"
"slices"
"sort"

Expand All @@ -23,44 +23,45 @@ func generateStateKeyBasic(i uint8) [32]byte {
return result
}

// generateStateKey to generate state key based on i and s
func generateStateKey(i uint8, s block.ServiceId) [32]byte {
// generateStateKeyInterleavedBasic to generate state key based on i and s
func generateStateKeyInterleavedBasic(i uint8, s block.ServiceId) ([32]byte, error) {
encodedServiceId, err := jam.Marshal(s)
if err != nil {
return [32]byte{}, err
}

var result [32]byte

// Place i as the first byte
result[0] = i

// Convert s into a 4-byte buffer and place it starting at result[1]
sBuf := make([]byte, 4)
binary.BigEndian.PutUint32(sBuf, uint32(s)) // s is 4 bytes in BigEndian format

// Copy the 4-byte sBuf to result starting at index 1
copy(result[1:], sBuf)
// Place encoded service ID bytes at positions 1,3,5,7
for j := 0; j < 4; j++ {
result[1+j*2] = encodedServiceId[j]
}

// The rest of result is already zero-padded by default
return result
return result, nil
}

// Function to interleave the first 4 bytes of s and h, then append the rest of h
func generateStateKeyInterleaved(s block.ServiceId, h [32]byte) [32]byte {
var result [32]byte
func generateStateKeyInterleaved(s block.ServiceId, h [32]byte) ([32]byte, error) {
encodedServiceId, err := jam.Marshal(s)
if err != nil {
return [32]byte{}, err
}

// Convert s into a 4-byte buffer
sBuf := make([]byte, 4)
binary.BigEndian.PutUint32(sBuf, uint32(s)) // s is 4 bytes
var result [32]byte

// Interleave the first 4 bytes of s with the first 4 bytes of h
// Interleave the first 4 bytes of encodedServiceId with the first 4 bytes of h
for i := 0; i < 4; i++ {
// Copy the i-th byte from sBuf
result[i*2] = sBuf[i]
// Copy the i-th byte from h
result[i*2] = encodedServiceId[i]
result[i*2+1] = h[i]
}

// Append the rest of h to the result
copy(result[8:], h[4:])

return result
return result, nil
}

// calculateFootprintSize calculates the storage footprint size (al) based on Equation 94.
Expand Down Expand Up @@ -112,17 +113,3 @@ func sortByteSlicesCopy(slice interface{}) interface{} {
panic("unsupported type for sorting")
}
}

// bitwiseNotExceptFirst4Bytes to apply bitwise NOT to all bytes except the first 4
func bitwiseNotExceptFirst4Bytes(h crypto.Hash) [28]byte {
// Clone the original array into a new one
var result [28]byte
copy(result[:], h[:])

// Apply bitwise NOT to all bytes except the first 4
for i := 4; i < len(result); i++ {
result[i] = ^result[i]
}

return result
}
Loading

0 comments on commit d791611

Please sign in to comment.