Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove-Get concurrency fix #27

Draft
wants to merge 8 commits into
base: rc/v1.7.next1
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions leveldb/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package leveldb

// PutBatch is a test-only hook that invokes the unexported putBatch method so tests can
// trigger a batch flush on demand. The error returned by putBatch is discarded.
func (s *SerialDB) PutBatch() {
	_ = s.putBatch()
}
67 changes: 53 additions & 14 deletions leveldb/leveldbSerial.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/multiversx/mx-chain-core-go/core"
"github.com/multiversx/mx-chain-core-go/core/check"
"github.com/multiversx/mx-chain-core-go/core/closing"
"github.com/multiversx/mx-chain-storage-go/common"
"github.com/multiversx/mx-chain-storage-go/types"
Expand All @@ -25,11 +26,14 @@ type SerialDB struct {
maxBatchSize int
batchDelaySeconds int
sizeBatch int
batch types.Batcher
mutBatch sync.RWMutex
dbAccess chan serialQueryer
cancel context.CancelFunc
closer core.SafeCloser

accessBatch types.Batcher
writingBatch types.Batcher
mutBatch sync.RWMutex

dbAccess chan serialQueryer
cancel context.CancelFunc
closer core.SafeCloser
}

// NewSerialDB is a constructor for the leveldb persister
Expand Down Expand Up @@ -80,7 +84,7 @@ func NewSerialDB(path string, batchDelaySeconds int, maxBatchSize int, maxOpenFi
closer: closing.NewSafeChanCloser(),
}

dbStore.batch = NewBatch()
dbStore.accessBatch = NewBatch()

go dbStore.batchTimeoutHandle(ctx)
go dbStore.processLoop(ctx)
Expand Down Expand Up @@ -142,7 +146,7 @@ func (s *SerialDB) Put(key, val []byte) error {
}

s.mutBatch.RLock()
err := s.batch.Put(key, val)
err := s.accessBatch.Put(key, val)
s.mutBatch.RUnlock()
if err != nil {
return err
Expand All @@ -158,12 +162,12 @@ func (s *SerialDB) Get(key []byte) ([]byte, error) {
}

s.mutBatch.RLock()
if s.batch.IsRemoved(key) {
if s.isRemoved(key) {
s.mutBatch.RUnlock()
return nil, common.ErrKeyNotFound
}

data := s.batch.Get(key)
data := s.getFromBatches(key)
s.mutBatch.RUnlock()

if data != nil {
Expand Down Expand Up @@ -200,12 +204,12 @@ func (s *SerialDB) Has(key []byte) error {
}

s.mutBatch.RLock()
if s.batch.IsRemoved(key) {
if s.isRemoved(key) {
s.mutBatch.RUnlock()
return common.ErrKeyNotFound
}

data := s.batch.Get(key)
data := s.getFromBatches(key)
s.mutBatch.RUnlock()

if data != nil {
Expand All @@ -228,6 +232,30 @@ func (s *SerialDB) Has(key []byte) error {
return result
}

// isRemoved reports whether the latest pending operation on key is a removal.
//
// The access batch holds the newest operations, so it is consulted first: a removal recorded
// there wins, and a value present there means the key was (re-)added after any removal that
// may still be recorded in the writing batch. Only when the access batch knows nothing about
// the key is the writing batch (the one currently being flushed, if any) consulted.
//
// The method takes no lock itself; Get and Has call it while holding the read lock on mutBatch.
func (s *SerialDB) isRemoved(key []byte) bool {
	if s.accessBatch.IsRemoved(key) {
		return true
	}

	// A value in the access batch is newer than anything in the writing batch, so a removal
	// marker left in the writing batch must not hide it. This relies on the same convention
	// Get and Has use: a non-nil result from the batch means the batch holds the key.
	if s.accessBatch.Get(key) != nil {
		return false
	}

	if check.IfNil(s.writingBatch) {
		return false
	}

	return s.writingBatch.IsRemoved(key)
}

// getFromBatches looks the key up in the pending batches and returns the stored value, or nil
// when neither batch holds one. The access batch is queried before the writing batch because
// it carries the most recent variant of the data.
func (s *SerialDB) getFromBatches(key []byte) []byte {
	if fresh := s.accessBatch.Get(key); fresh != nil {
		return fresh
	}

	// fall back to the batch that is being written, if there is one
	if !check.IfNil(s.writingBatch) {
		return s.writingBatch.Get(key)
	}

	return nil
}

func (s *SerialDB) tryWriteInDbAccessChan(req serialQueryer) error {
select {
case s.dbAccess <- req:
Expand All @@ -240,13 +268,20 @@ func (s *SerialDB) tryWriteInDbAccessChan(req serialQueryer) error {
// putBatch writes the Batch data into the database
func (s *SerialDB) putBatch() error {
s.mutBatch.Lock()
dbBatch, ok := s.batch.(*batch)
if !check.IfNil(s.writingBatch) {
s.mutBatch.Unlock()
return nil
}

s.writingBatch = s.accessBatch

dbBatch, ok := s.writingBatch.(*batch)
if !ok {
s.mutBatch.Unlock()
return common.ErrInvalidBatch
}
s.sizeBatch = 0
s.batch = NewBatch()
s.accessBatch = NewBatch()
s.mutBatch.Unlock()

ch := make(chan error)
Expand All @@ -262,6 +297,10 @@ func (s *SerialDB) putBatch() error {
result := <-ch
close(ch)

s.mutBatch.Lock()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a great solution: this can set the writing batch to nil even though another goroutine has called putBatch in the meantime and managed to set writingBatch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, reverted

s.writingBatch = nil
s.mutBatch.Unlock()

return result
}

Expand All @@ -287,7 +326,7 @@ func (s *SerialDB) Remove(key []byte) error {
}

s.mutBatch.Lock()
_ = s.batch.Delete(key)
_ = s.accessBatch.Delete(key)
s.mutBatch.Unlock()

return s.updateBatchWithIncrement()
Expand Down
58 changes: 58 additions & 0 deletions leveldb/leveldbSerial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"math/big"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -358,3 +359,60 @@ func TestSerialDB_ConcurrentOperations(t *testing.T) {

wg.Wait()
}

// TestSerialDB_PutRemoveGet checks that a key removed while putBatch runs concurrently is not
// returned by a Get issued shortly after the Remove.
func TestSerialDB_PutRemoveGet(t *testing.T) {
	if testing.Short() {
		t.Skip("this is not a short test")
	}

	ldb := createSerialLevelDb(t, 100000, 1000000, 10)
	// Deferred so the database is closed even when a require call aborts the test.
	// The close error is not relevant to what this test verifies.
	defer func() {
		_ = ldb.Close()
	}()

	numKeys := 10000
	for i := 0; i < numKeys; i++ {
		err := ldb.Put([]byte(fmt.Sprintf("key %d", i)), []byte("val"))
		require.NoError(t, err)
	}

	time.Sleep(time.Second * 2)

	numErr := uint32(0)

	for i := 0; i < numKeys; i++ {
		key := []byte(fmt.Sprintf("key %d", i))

		recoveredVal, err := ldb.Get(key)
		require.NoError(t, err)
		assert.NotEmpty(t, recoveredVal)

		wg := &sync.WaitGroup{}
		wg.Add(2)

		// emulate the following scenario:
		// the sequence Remove(key) -> Get(key) is executed while putBatch is called, so the edge case is
		// go routine 1: Remove(key) -----------------> Get(key)
		// go routine 2:                putBatch()

		go func() {
			defer wg.Done()

			time.Sleep(time.Millisecond * 1)
			ldb.PutBatch()
		}()
		go func() {
			defer wg.Done()

			// the Remove error is ignored here: a failed removal surfaces through the Get below
			_ = ldb.Remove(key)

			time.Sleep(time.Millisecond * 1)

			// only the returned value matters: a not-found error is the expected outcome
			recoveredVal2, _ := ldb.Get(key)
			if len(recoveredVal2) > 0 {
				// the key-value was not removed
				atomic.AddUint32(&numErr, 1)
			}
		}()

		wg.Wait()

		require.Zero(t, atomic.LoadUint32(&numErr), "iteration %d out of %d", i, numKeys)
	}
}