Skip to content

Commit

Permalink
Fix implementation to reflect graypaper
Browse files Browse the repository at this point in the history
  • Loading branch information
carlos-romano committed Nov 27, 2024
1 parent ce25b75 commit d0e50c2
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 47 deletions.
27 changes: 19 additions & 8 deletions internal/state/merkle/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,10 @@ func deserializeServices(state *state.State, serializedState map[crypto.Hash][]b
// Check if this is a service account entry (state key starts with 255)
if isServiceAccountKey(stateKey) {
// Extract service ID from the key
serviceId := extractServiceIdFromKey(stateKey)
serviceId, err := extractServiceIdFromKey(stateKey)
if err != nil {
return err
}

// Deserialize the combined fields (CodeHash, Balance, etc.)
var combined struct {
Expand Down Expand Up @@ -450,11 +453,19 @@ func isServiceAccountKey(stateKey crypto.Hash) bool {
return stateKey[0] == 255
}

func extractServiceIdFromKey(stateKey crypto.Hash) block.ServiceId {
// Since we're using zero-interleaved pattern, we need to reconstruct the service ID
// from bytes at positions 1,3,5,7
return block.ServiceId(uint32(stateKey[1])<<24 |
uint32(stateKey[3])<<16 |
uint32(stateKey[5])<<8 |
uint32(stateKey[7]))
func extractServiceIdFromKey(stateKey crypto.Hash) (block.ServiceId, error) {
// Collect service ID bytes from positions 1,3,5,7 into a slice
encodedServiceId := []byte{
stateKey[1],
stateKey[3],
stateKey[5],
stateKey[7],
}

var serviceId block.ServiceId
if err := jam.Unmarshal(encodedServiceId, &serviceId); err != nil {
return 0, err
}

return serviceId, nil
}
20 changes: 16 additions & 4 deletions internal/state/merkle/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ func serializeServiceAccount(serviceId block.ServiceId, serviceAccount service.S
encodedFootprintSize,
encodedFootprintItems,
)
stateKey := generateStateKeyInterleavedBasic(255, serviceId)
stateKey, err := generateStateKeyInterleavedBasic(255, serviceId)
if err != nil {
return err
}
serializedState[stateKey] = combined

// Serialize storage and preimage items
Expand All @@ -174,7 +177,10 @@ func serializeStorageAndPreimage(serviceId block.ServiceId, serviceAccount servi
var combined [32]byte
copy(combined[:4], encodedMaxUint32)
copy(combined[4:], hash[:28])
stateKey := generateStateKeyInterleaved(serviceId, combined)
stateKey, err := generateStateKeyInterleaved(serviceId, combined)
if err != nil {
return err
}
serializedState[stateKey] = encodedValue
}

Expand All @@ -191,7 +197,10 @@ func serializeStorageAndPreimage(serviceId block.ServiceId, serviceAccount servi
var combined [32]byte
copy(combined[:4], encodedMaxUint32MinusOne)
copy(combined[4:], hash[1:29])
stateKey := generateStateKeyInterleaved(serviceId, combined)
stateKey, err := generateStateKeyInterleaved(serviceId, combined)
if err != nil {
return err
}
serializedState[stateKey] = encodedValue
}

Expand All @@ -209,7 +218,10 @@ func serializeStorageAndPreimage(serviceId block.ServiceId, serviceAccount servi
var combined [32]byte
copy(combined[:4], encodedLength)
copy(combined[4:], hashedPreImageHistoricalTimeslots[2:30])
stateKey := generateStateKeyInterleaved(serviceId, key.Hash)
stateKey, err := generateStateKeyInterleaved(serviceId, key.Hash)
if err != nil {
return err
}
serializedState[stateKey] = encodedPreImageHistoricalTimeslots
}
return nil
Expand Down
3 changes: 2 additions & 1 deletion internal/state/merkle/serialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ func TestSerializeStateServices(t *testing.T) {
require.NoError(t, err)

for serviceId := range state.Services {
stateKey := generateStateKeyInterleavedBasic(255, serviceId)
stateKey, err := generateStateKeyInterleavedBasic(255, serviceId)
require.NoError(t, err)
hashKey := crypto.Hash(stateKey)
assert.Contains(t, serializedState, hashKey)
assert.NotEmpty(t, serializedState[hashKey])
Expand Down
40 changes: 21 additions & 19 deletions internal/state/merkle/serialization_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package state
import (
"bytes"
"crypto/ed25519"
"encoding/binary"
"github.com/eigerco/strawberry/pkg/serialization/codec/jam"
"slices"
"sort"

Expand All @@ -24,42 +24,44 @@ func generateStateKeyBasic(i uint8) [32]byte {
}

// generateStateKeyInterleavedBasic to generate state key based on i and s
func generateStateKeyInterleavedBasic(i uint8, s block.ServiceId) [32]byte {
func generateStateKeyInterleavedBasic(i uint8, s block.ServiceId) ([32]byte, error) {
encodedServiceId, err := jam.Marshal(s)
if err != nil {
return [32]byte{}, err
}

var result [32]byte

// Place i as the first byte
result[0] = i

// Extract individual bytes from s using bit shifting
result[1] = byte(s >> 24) // n0
result[3] = byte(s >> 16) // n1
result[5] = byte(s >> 8) // n2
result[7] = byte(s) // n3
// Place encoded service ID bytes at positions 1,3,5,7
for j := 0; j < 4; j++ {
result[1+j*2] = encodedServiceId[j]
}

// result[2,4,6,8] and the rest are already 0 by default
return result
return result, nil
}

// Function to interleave the first 4 bytes of s and h, then append the rest of h
func generateStateKeyInterleaved(s block.ServiceId, h [32]byte) [32]byte {
var result [32]byte
func generateStateKeyInterleaved(s block.ServiceId, h [32]byte) ([32]byte, error) {
encodedServiceId, err := jam.Marshal(s)
if err != nil {
return [32]byte{}, err
}

// Convert s into a 4-byte buffer
sBuf := make([]byte, 4)
binary.BigEndian.PutUint32(sBuf, uint32(s)) // s is 4 bytes
var result [32]byte

// Interleave the first 4 bytes of s with the first 4 bytes of h
// Interleave the first 4 bytes of encodedServiceId with the first 4 bytes of h
for i := 0; i < 4; i++ {
// Copy the i-th byte from sBuf
result[i*2] = sBuf[i]
// Copy the i-th byte from h
result[i*2] = encodedServiceId[i]
result[i*2+1] = h[i]
}

// Append the rest of h to the result
copy(result[8:], h[4:])

return result
return result, nil
}

// calculateFootprintSize calculates the storage footprint size (al) based on Equation 94.
Expand Down
56 changes: 41 additions & 15 deletions internal/state/merkle/serialization_utils_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package state

import (
"encoding/binary"
"fmt"
"github.com/eigerco/strawberry/pkg/serialization/codec/jam"
"github.com/stretchr/testify/require"
"math"
"testing"

Expand Down Expand Up @@ -39,11 +40,12 @@ func TestGenerateStateKeyInterleavedBasic(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate the state key
stateKey := generateStateKeyInterleavedBasic(tt.i, tt.serviceId)
stateKey, err := generateStateKeyInterleavedBasic(tt.i, tt.serviceId)
require.NoError(t, err)

// Convert serviceId to bytes for verification
serviceIdBytes := make([]byte, 4)
binary.BigEndian.PutUint32(serviceIdBytes, uint32(tt.serviceId))
// Get encoded service ID for verification
encodedServiceId, err := jam.Marshal(tt.serviceId)
require.NoError(t, err)

// Verify length is 32 bytes
assert.Equal(t, 32, len(stateKey), "key length should be 32 bytes")
Expand All @@ -52,21 +54,26 @@ func TestGenerateStateKeyInterleavedBasic(t *testing.T) {
assert.Equal(t, tt.i, stateKey[0], "first byte should be i")

// Verify the interleaved pattern:
// [i, n0, 0, n1, 0, n2, 0, n3, 0, 0, ...]
assert.Equal(t, serviceIdBytes[0], stateKey[1], "n0 should be at position 1")
// [i, s0, 0, s1, 0, s2, 0, s3, 0, 0, ...]
assert.Equal(t, encodedServiceId[0], stateKey[1], "s0 should be at position 1")
assert.Equal(t, byte(0), stateKey[2], "zero should be at position 2")
assert.Equal(t, serviceIdBytes[1], stateKey[3], "n1 should be at position 3")
assert.Equal(t, encodedServiceId[1], stateKey[3], "s1 should be at position 3")
assert.Equal(t, byte(0), stateKey[4], "zero should be at position 4")
assert.Equal(t, serviceIdBytes[2], stateKey[5], "n2 should be at position 5")
assert.Equal(t, encodedServiceId[2], stateKey[5], "s2 should be at position 5")
assert.Equal(t, byte(0), stateKey[6], "zero should be at position 6")
assert.Equal(t, serviceIdBytes[3], stateKey[7], "n3 should be at position 7")
assert.Equal(t, encodedServiceId[3], stateKey[7], "s3 should be at position 7")
assert.Equal(t, byte(0), stateKey[8], "zero should be at position 8")

// Verify remaining bytes are zero
for i := 9; i < 32; i++ {
assert.Equal(t, byte(0), stateKey[i],
fmt.Sprintf("byte at position %d should be zero", i))
}

// Verify we can extract the service ID back
extractedServiceId, err := extractServiceIdFromKey(crypto.Hash(stateKey))
require.NoError(t, err)
assert.Equal(t, tt.serviceId, extractedServiceId)
})
}
}
Expand All @@ -76,17 +83,36 @@ func TestGenerateStateKeyInterleaved(t *testing.T) {
serviceId := block.ServiceId(1234)
hash := crypto.Hash{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}

// Get encoded service ID for verification
encodedServiceId, err := jam.Marshal(serviceId)
require.NoError(t, err)

// Generate the interleaved state key
stateKey := generateStateKeyInterleaved(serviceId, hash)
stateKey, err := generateStateKeyInterleaved(serviceId, hash)
require.NoError(t, err)

// Verify the length is 32 bytes
assert.Equal(t, 32, len(stateKey))

// Verify that the first 8 bytes are interleaved between serviceId and hash
assert.Equal(t, stateKey[0], byte(serviceId>>24))
assert.Equal(t, stateKey[1], hash[0])
assert.Equal(t, stateKey[2], byte(serviceId>>16))
assert.Equal(t, stateKey[3], hash[1])
assert.Equal(t, encodedServiceId[0], stateKey[0])
assert.Equal(t, hash[0], stateKey[1])
assert.Equal(t, encodedServiceId[1], stateKey[2])
assert.Equal(t, hash[1], stateKey[3])
assert.Equal(t, encodedServiceId[2], stateKey[4])
assert.Equal(t, hash[2], stateKey[5])
assert.Equal(t, encodedServiceId[3], stateKey[6])
assert.Equal(t, hash[3], stateKey[7])

// Verify that remaining bytes from hash are copied correctly
rest := stateKey[8:]
for i := 0; i < len(rest); i++ {
if i < len(hash)-4 {
assert.Equal(t, hash[i+4], rest[i], "hash byte mismatch at position %d", i)
} else {
assert.Equal(t, byte(0), rest[i], "should be zero at position %d", i)
}
}
}

// TestCalculateFootprintSize checks if the footprint size calculation is correct.
Expand Down

0 comments on commit d0e50c2

Please sign in to comment.