Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement ciphertext cache preload upon restart #187

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 127 additions & 9 deletions fhevm-engine/fhevm-go-native/fhevm/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ type ExecutorApi interface {
// We pass current block number to know at which
// block ciphertext should be materialized inside blockchain state.
CreateSession(blockNumber int64) ExecutorSession
// Preload ciphertexts into cache and perform initial computations,
// should be called once after blockchain node initialization
PreloadCiphertexts(blockNumber int64, api ChainStorageApi) error
}

type SegmentId int
Expand Down Expand Up @@ -230,6 +233,114 @@ func (executorApi *ApiImpl) CreateSession(blockNumber int64) ExecutorSession {
}
}

func (executorApi *ApiImpl) PreloadCiphertexts(blockNumber int64, api ChainStorageApi) error {
computations := executorApi.loadComputationsFromStateToCache(blockNumber, api)
if computations > 0 {
return executorProcessPendingComputations(executorApi)
}

return nil
}

func (executorApi *ApiImpl) loadComputationsFromStateToCache(startBlockNumber int64, api ChainStorageApi) int {
loadStartTime := time.Now()
computations := 0
defer func() {
duration := time.Since(loadStartTime)
fmt.Printf("ciphertext cache preloaded with %d ciphertexts in %dms\n", computations, duration.Milliseconds())
}()

// TODO: figure out the limit how long in future blocks we should preload
lastBlockToPreload := startBlockNumber + 30

executorApi.cache.lock.Lock()
defer executorApi.cache.lock.Unlock()

for block := startBlockNumber; block < lastBlockToPreload; block++ {
countAddress := blockNumberToQueueItemCountAddress(block)
ciphertextsInBlock := api.GetState(executorApi.contractStorageAddress, countAddress).Big()
inBlock := ciphertextsInBlock.Int64()
queue := make([]*ComputationToInsert, 0)
enqueuedCiphertext := make(map[string]bool)

if inBlock == 0 {
continue
}

computations += int(inBlock)

for ctNum := 0; ctNum < int(inBlock); ctNum++ {
layout := blockQueueStorageLayout(block, int64(ctNum))
metadata := bytesToMetadata(api.GetState(executorApi.contractStorageAddress, layout.metadata))
outputHandle := api.GetState(executorApi.contractStorageAddress, layout.outputHandle)
computation := &ComputationToInsert{
segmentId: 0,
Operation: metadata.Operation,
OutputHandle: outputHandle[:],
CommitBlockId: block,
}

if isBinaryOp(metadata.Operation) {
firstOpHandle := api.GetState(executorApi.contractStorageAddress, layout.firstOperand)
firstOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, firstOpHandle)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: firstOpHandle[:],
CompressedCiphertext: firstOpCt,
FheUintType: handleType(firstOpHandle[:]),
})

if metadata.IsBigScalar {
// TODO: implement big scalar
} else if metadata.IsScalar {
secondOpHandle := api.GetState(executorApi.contractStorageAddress, layout.secondOperand)
computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: true,
Handle: secondOpHandle[:],
FheUintType: handleType(firstOpHandle[:]),
})
} else {
secondOpHandle := api.GetState(executorApi.contractStorageAddress, layout.secondOperand)
secondOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, secondOpHandle)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: secondOpHandle[:],
CompressedCiphertext: secondOpCt,
FheUintType: handleType(secondOpHandle[:]),
})
}
} else if isUnaryOp(metadata.Operation) {
firstOpAddress := api.GetState(executorApi.contractStorageAddress, layout.firstOperand)
firstOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, firstOpAddress)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: firstOpAddress[:],
CompressedCiphertext: firstOpCt,
FheUintType: handleType(firstOpAddress[:]),
})
} else {
// TODO: handle all special functions to load their ciphertext arguments
}

if !enqueuedCiphertext[string(computation.OutputHandle)] {
queue = append(queue, computation)
enqueuedCiphertext[string(computation.OutputHandle)] = true
}
}

ctsToCompute := &BlockCiphertextQueue{
queue: queue,
enqueuedCiphertext: enqueuedCiphertext,
}
executorApi.cache.ciphertextsToCompute[block] = ctsToCompute
}

return computations
}

func (sessionApi *SessionImpl) Commit(blockNumber int64, storage ChainStorageApi) error {
err := sessionApi.sessionStore.Commit(storage)
if err != nil {
Expand Down Expand Up @@ -530,12 +641,13 @@ func (dbApi *EvmStorageComputationStore) InsertComputationBatch(evmStorage Chain

for _, comp := range bucket {
// don't have duplicates, from possibly evaluating multiple trie caches
if !ctsStorage.enqueuedCiphertext[common.Bytes2Hex(comp.OutputHandle)] {
if !ctsStorage.enqueuedCiphertext[string(comp.OutputHandle)] {
// we must fill the raw ciphertext values here from storage so cache
// would have ciphertexts to compute on, as cache doesn't have easy
// access to the evm state
dbApi.hydrateComputationFromEvmState(evmStorage, comp)
ctsStorage.queue = append(ctsStorage.queue, comp)
ctsStorage.enqueuedCiphertext[string(comp.OutputHandle)] = true
}
}
}
Expand Down Expand Up @@ -766,18 +878,20 @@ func InitExecutor() (ExecutorApi, error) {

workAvailableChan := make(chan bool, 10)

cache := &CiphertextCache{
lock: sync.RWMutex{},
blocksCiphertexts: make(map[int64]*CacheBlockData),
ciphertextsToCompute: make(map[int64]*BlockCiphertextQueue),
workAvailableChan: workAvailableChan,
lastCacheGc: time.Now(),
}

apiImpl := &ApiImpl{
address: fhevmContractAddress,
aclContractAddress: aclContractAddress,
contractStorageAddress: storageAddress,
executorUrl: executorUrl,
cache: &CiphertextCache{
lock: sync.RWMutex{},
blocksCiphertexts: make(map[int64]*CacheBlockData),
ciphertextsToCompute: make(map[int64]*BlockCiphertextQueue),
workAvailableChan: workAvailableChan,
lastCacheGc: time.Now(),
},
cache: cache,
}

// run executor worker in the background
Expand Down Expand Up @@ -885,8 +999,12 @@ func executorProcessPendingComputations(impl *ApiImpl) error {
if err != nil {
return err
}
ciphertexts := response.GetResultCiphertexts()
if ciphertexts == nil {
return errors.New(response.GetError().String())
}

outCts := response.GetResultCiphertexts().Ciphertexts
outCts := ciphertexts.Ciphertexts
fmt.Printf("got %d ciphertext responses from the executor\n", len(outCts))
for _, ct := range outCts {
theBlock, exists := ctToBlockIndex[string(ct.Handle)]
Expand Down
22 changes: 22 additions & 0 deletions fhevm-engine/fhevm-go-native/fhevm/fhelib_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -1688,3 +1688,25 @@ func getThreeFheOperands(sess ExecutorSession, input []byte) (first []byte, seco

return input[0:32], input[32:64], input[64:96], nil
}

func isBinaryOp(op FheOp) bool {
switch op {
case FheAdd, FheBitAnd, FheBitOr, FheBitXor, FheDiv, FheEq, FheGe, FheGt, FheLe, FheLt, FheMax, FheMin, FheMul, FheNe, FheRem, FheRotl, FheRotr, FheShl, FheShr, FheSub:
return true
case FheCast, FheNeg, FheNot, FheRand, FheRandBounded, FheIfThenElse, TrivialEncrypt:
return false
default:
return false
}
}

func isUnaryOp(op FheOp) bool {
switch op {
case FheNeg, FheNot:
return true
case FheAdd, FheBitAnd, FheBitOr, FheBitXor, FheDiv, FheEq, FheGe, FheGt, FheLe, FheLt, FheMax, FheMin, FheMul, FheNe, FheRem, FheRotl, FheRotr, FheShl, FheShr, FheSub, FheCast, FheRand, FheRandBounded, FheIfThenElse, TrivialEncrypt:
return false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for nitpicking - but might even drop the false branch and leave it to default :)

If we don't have ternary ops we might otherwise just implement binary as !unary, but that's not essential.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish it was more like rust, where we would have to list all the cases or its a compile time error, we might forget an operand 🤔

default:
return false
}
}
Loading