Skip to content

Commit

Permalink
Merge pull request #123 from zama-ai/davidk/implement-ciphertext-cache
Browse files Browse the repository at this point in the history
feat: implement global ciphertext cache
  • Loading branch information
david-zk authored Nov 14, 2024
2 parents ee83ddc + d36a1fc commit bc78218
Showing 1 changed file with 119 additions and 4 deletions.
123 changes: 119 additions & 4 deletions fhevm-engine/fhevm-go-native/fhevm/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"math/big"
"os"
"sort"
"sync"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -121,7 +122,7 @@ type ExecutorApi interface {
CreateSession(blockNumber int64, storage ChainStorageApi) ExecutorSession
// Insert existing fhe operations to the state from inside the state
// storage queue. This should be called at the end of every block.
FlushFheResultsToState(blockNumber int64, storage ChainStorageApi) ExecutorSession
FlushFheResultsToState(blockNumber int64, storage ChainStorageApi) error
}

type SegmentId int
Expand All @@ -147,10 +148,24 @@ type ComputationStore interface {
InsertComputationBatch(ciphertexts []ComputationToInsert) error
}

type CacheBlockData struct {
// store ciphertexts by handles
materializedCiphertexts map[string][]byte
// allow inserting many at once
blockEnqueuedCiphertext map[string]bool
enqueuedCiphertexts []*ComputationToInsert
}

type CiphertextCache struct {
lock sync.RWMutex
blocksCiphertexts map[int64]*CacheBlockData
}

type ApiImpl struct {
address common.Address
aclContractAddress common.Address
contractStorageAddress common.Address
cache *CiphertextCache
}

type SessionImpl struct {
Expand Down Expand Up @@ -189,6 +204,7 @@ type EvmStorageComputationStore struct {
evmStorage ChainStorageApi
currentBlockNumber int64
contractStorageAddress common.Address
cache *CiphertextCache
}

type handleOffset struct {
Expand Down Expand Up @@ -217,6 +233,7 @@ func (executorApi *ApiImpl) CreateSession(blockNumber int64, api ChainStorageApi
evmStorage: api,
contractStorageAddress: executorApi.contractStorageAddress,
currentBlockNumber: blockNumber,
cache: executorApi.cache,
},
},
}
Expand Down Expand Up @@ -494,29 +511,64 @@ func (dbApi *EvmStorageComputationStore) InsertComputationBatch(computations []C
dbApi.evmStorage.SetState(dbApi.contractStorageAddress, countAddress, common.BigToHash(ciphertextsInBlock))
}

// enqueue items to cache, we do this in the
// end because it requires locking, so lock for minimal time
dbApi.cache.lock.Lock()
defer dbApi.cache.lock.Unlock()

// TODO: implement cache warmup algorithm, when we restart blockchain
// we want to scan storage queue for computations to be completed

for _, key := range allKeys {
queueBlockNumber := int64(key)
bucket := buckets[queueBlockNumber]
ctsStorage := dbApi.cache.blocksCiphertexts[queueBlockNumber]
if ctsStorage == nil {
ctsStorage = &CacheBlockData{
materializedCiphertexts: make(map[string][]byte),
blockEnqueuedCiphertext: make(map[string]bool),
enqueuedCiphertexts: make([]*ComputationToInsert, 0),
}
dbApi.cache.blocksCiphertexts[queueBlockNumber] = ctsStorage
}

for _, comp := range bucket {
// don't have duplicates, from possibly evaluating multiple trie caches
if !ctsStorage.blockEnqueuedCiphertext[common.Bytes2Hex(comp.OutputHandle)] {
ctsStorage.enqueuedCiphertexts = append(ctsStorage.enqueuedCiphertexts, comp)
}
}
}

return nil
}

func (executorApi *ApiImpl) FlushFheResultsToState(blockNumber int64, api ChainStorageApi) ExecutorSession {
func (executorApi *ApiImpl) FlushFheResultsToState(blockNumber int64, api ChainStorageApi) error {
// cleanup the queue for the block number
countAddress := blockNumberToQueueItemCountAddress(blockNumber)
ciphertextsInBlock := api.GetState(executorApi.contractStorageAddress, countAddress).Big()
ctCount := ciphertextsInBlock.Int64()
zero := common.BigToHash(big.NewInt(0))
one := big.NewInt(1)

// make sure handles are materialized in storage in deterministic
// order, first come first serve basis in the queue
handlesToMaterialize := make([]common.Hash, 0)

// zero out queue ciphertexts
for i := 0; i < int(ctCount); i++ {
ctAddr := blockQueueStorageLayout(blockNumber, int64(i))
metadata := bytesToMetadata(api.GetState(executorApi.contractStorageAddress, ctAddr.metadata))
outputHandle := api.GetState(executorApi.contractStorageAddress, ctAddr.outputHandle)
handlesToMaterialize = append(handlesToMaterialize, outputHandle)
api.SetState(executorApi.contractStorageAddress, ctAddr.metadata, zero)
api.SetState(executorApi.contractStorageAddress, ctAddr.outputHandle, zero)
api.SetState(executorApi.contractStorageAddress, ctAddr.firstOperand, zero)
api.SetState(executorApi.contractStorageAddress, ctAddr.secondOperand, zero)
if metadata.IsBigScalar {
counter := new(big.Int)
counter.SetBytes(ctAddr.bigScalarOperand[:])
// max supporter number 2048 is 2048
// max supported number 2048 is 2048
for i := 0; i < 2048/256; i++ {
api.SetState(executorApi.contractStorageAddress, common.BigToHash(counter), zero)
counter.Add(counter, one)
Expand All @@ -527,7 +579,65 @@ func (executorApi *ApiImpl) FlushFheResultsToState(blockNumber int64, api ChainS
// set 0 as count
api.SetState(executorApi.contractStorageAddress, countAddress, zero)

panic("TODO: implement flushing of ciphertext data to the blockchain state")
// materialize handles in storage assuming they exist in the cache
return executorApi.materializeHandlesInStorage(blockNumber, handlesToMaterialize, api)
}

func (executorApi *ApiImpl) materializeHandlesInStorage(blockNumber int64, handles []common.Hash, api ChainStorageApi) error {
// no one did fhe computations in the block
if len(handles) == 0 {
return nil
}

executorApi.cache.lock.Lock()
defer executorApi.cache.lock.Unlock()

contractAddr := executorApi.contractStorageAddress

blockData, ok := executorApi.cache.blocksCiphertexts[blockNumber]
if !ok {
return errors.New("block number not found for materialized ciphertexts")
}

for _, handle := range handles {
hexStr := common.Bytes2Hex(handle[:])
ciphertext, ok := blockData.materializedCiphertexts[hexStr]
if !ok {
return errors.New("ciphertext not found in cache")
}

ctLength := big.NewInt(int64(len(ciphertext)))

startAddress := new(big.Int)
startAddress.SetBytes(handle[:])
wordAddress := func(word int64) common.Hash {
res := big.NewInt(word)
res.Add(res, startAddress)
return common.BigToHash(res)
}

// write ciphertext length first
api.SetState(contractAddr, handle, common.BigToHash(ctLength))

// write the ciphertext by uint256 chunks
wholeBlocks := len(ciphertext) / 32
tailBlockSize := len(ciphertext) % 32

// first block starts at handle + 1
wordOffset := int64(1)
for i := 0; i < wholeBlocks; i++ {
ctSlice := common.BytesToHash(ciphertext[i*32 : i*32+32])
api.SetState(contractAddr, wordAddress(wordOffset), ctSlice)
wordOffset += 1
}
// write the last partial block if it exists
if tailBlockSize > 0 {
ctSlice := common.BytesToHash(ciphertext[wholeBlocks*32 : wholeBlocks*32+tailBlockSize])
api.SetState(contractAddr, wordAddress(wordOffset), ctSlice)
}
}

return nil
}

func (dbApi *EvmStorageComputationStore) InsertComputation(computation ComputationToInsert) error {
Expand Down Expand Up @@ -555,10 +665,15 @@ func InitExecutor() (ExecutorApi, error) {

// pick hardcoded value in the beginning, we can change later
storageAddress := common.HexToAddress("0x0000000000000000000000000000000000000070")

apiImpl := ApiImpl{
address: fhevmContractAddress,
aclContractAddress: aclContractAddress,
contractStorageAddress: storageAddress,
cache: &CiphertextCache{
lock: sync.RWMutex{},
blocksCiphertexts: make(map[int64]CacheBlockData),
},
}

return &apiImpl, nil
Expand Down

0 comments on commit bc78218

Please sign in to comment.