Skip to content

Commit

Permalink
Removed edge cases and errors from sqlite logdb
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpert committed Sep 7, 2022
1 parent 5aca1f7 commit 8b5e508
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 51 deletions.
4 changes: 2 additions & 2 deletions lib/log_db_script.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
CREATE TABLE IF NOT EXISTS raft_info(
entry_index UNSIGNED BIG INT,
node_id UNSIGNED BIG INT NOT NULL,
cluster_id UNSIGNED BIG INT,
entry_type INTEGER,
cluster_id UNSIGNED BIG INT NOT NULL,
entry_type INTEGER NOT NULL,
payload BLOB
);

Expand Down
173 changes: 124 additions & 49 deletions lib/sqlite_log_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package lib
import (
"database/sql"
"fmt"
"math"

_ "embed"

Expand All @@ -11,6 +12,7 @@ import (
"github.com/lni/dragonboat/v3/raftio"
"github.com/lni/dragonboat/v3/raftpb"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"marmot/db"
)
Expand Down Expand Up @@ -167,7 +169,7 @@ func (s *SQLiteLogDB) GetBootstrapInfo(clusterID uint64, nodeID uint64) (raftpb.
return bs, nil
}

func (s *SQLiteLogDB) SaveRaftState(updates []raftpb.Update, shardID uint64) error {
func (s *SQLiteLogDB) SaveRaftState(updates []raftpb.Update, _ uint64) error {
return s.db.WithTx(func(tx *goqu.TxDatabase) error {
for _, upd := range updates {
if !raftpb.IsEmptyState(upd.State) {
Expand All @@ -191,13 +193,10 @@ func (s *SQLiteLogDB) SaveRaftState(updates []raftpb.Update, shardID uint64) err
}
}

if len(upd.EntriesToSave) > 0 {
for _, entry := range upd.EntriesToSave {
// nodeID, clusterID, entry.Index
err := saveInfoTuple(tx, &entry.Index, upd.NodeID, upd.ClusterID, Entry, entry.Marshal)
if err != nil {
return err
}
for _, entry := range upd.EntriesToSave {
err := saveInfoTuple(tx, &entry.Index, upd.NodeID, upd.ClusterID, Entry, entry.Marshal)
if err != nil {
return err
}
}
}
Expand Down Expand Up @@ -232,6 +231,28 @@ func (s *SQLiteLogDB) IterateEntries(
high uint64,
maxSize uint64,
) ([]raftpb.Entry, uint64, error) {
logger := log.With().
Uint64("low", low).
Uint64("size", size).
Uint64("high", high).
Uint64("node_id", nodeID).
Uint64("max_size", maxSize).
Uint64("cluster_id", clusterID).
Uint64("entries", uint64(len(entries))).
Logger()

min, count, err := s.getEntryRange(nodeID, clusterID)
if err == raftio.ErrNoSavedLog {
logger.Warn().Msg("No entries...")
return entries, size, nil
}

if err != nil {
logger.Warn().Err(err).Msg("Range error")
return entries, size, err
}

logger = logger.With().Uint64("min", min).Uint64("max", min+count-1).Logger()
rows, err := s.db.
Select("payload").
From(raftInfoTable).
Expand All @@ -251,8 +272,7 @@ func (s *SQLiteLogDB) IterateEntries(
eRow := &db.EnhancedRows{Rows: rows}
defer eRow.Finalize()

currentSize := uint64(0)
ret := make([]raftpb.Entry, 0)
expectedIndex := low
for eRow.Next() {
bts := make([]byte, 0)
err = eRow.Scan(&bts)
Expand All @@ -265,31 +285,40 @@ func (s *SQLiteLogDB) IterateEntries(
if err != nil {
return entries, size, err
}

ret = append(ret, e)
currentSize += uint64(e.SizeUpperLimit())
if currentSize > maxSize {
if e.Index != expectedIndex {
logger.Warn().Msg(fmt.Sprintf("Index mismatch %d != %d", e.Index, expectedIndex))
break
}
}

if len(ret) == 0 {
return entries, size, nil
size += uint64(e.SizeUpperLimit())
entries = append(entries, e)
expectedIndex++

if size >= maxSize {
logger.Trace().Msg(fmt.Sprintf("Size mismatch %d != %d", size, maxSize))
break
}
}

return ret, currentSize, nil
logger.Trace().Msg("Scan complete")
return entries, size, nil
}

func (s *SQLiteLogDB) ReadRaftState(clusterID uint64, nodeID uint64, snapshotIndex uint64) (raftio.RaftState, error) {
entry := raftInfoEntry{}
ret := raftio.RaftState{}
log.Trace().Msg(fmt.Sprintf("ReadRaftState %d %d %d", clusterID, nodeID, snapshotIndex))

firstIndex, entriesCount, err := s.getEntryRange(nodeID, clusterID)
if err != nil {
return ret, err
}

ok, err := s.db.From(raftInfoTable).
Where(goqu.Ex{
"entry_index": snapshotIndex,
"node_id": nodeID,
"cluster_id": clusterID,
"entry_type": State,
"node_id": nodeID,
"cluster_id": clusterID,
"entry_type": State,
}).ScanStruct(&entry)

if err != nil {
Expand All @@ -300,11 +329,6 @@ func (s *SQLiteLogDB) ReadRaftState(clusterID uint64, nodeID uint64, snapshotInd
return ret, raftio.ErrNoSavedLog
}

firstIndex, entriesCount, err := s.getRange(nodeID, clusterID)
if err != nil {
return ret, err
}

err = ret.State.Unmarshal(entry.Payload)
if err != nil {
return ret, err
Expand All @@ -313,31 +337,34 @@ func (s *SQLiteLogDB) ReadRaftState(clusterID uint64, nodeID uint64, snapshotInd
ret.FirstIndex = firstIndex
ret.EntryCount = entriesCount

// Have to investigate but existing code in dragonboat suggests with 1 entry it returns 0
if snapshotIndex == (firstIndex + entriesCount - 1) {
ret.EntryCount = 0
}

return ret, nil
}

// RemoveEntriesTo removes entries associated with the specified Raft node up
// to the specified index.
func (s *SQLiteLogDB) RemoveEntriesTo(clusterID uint64, nodeID uint64, index uint64) error {
return s.db.WithTx(func(tx *goqu.TxDatabase) error {
log.Trace().Msg(fmt.Sprintf("RemoveEntriesTo c: %d n: %d i: %d", clusterID, nodeID, index))
return deleteInfoTuple(tx, nodeID, clusterID, Entry, []goqu.Expression{
goqu.C("entry_index").Lte(index),
goqu.C("entry_index").Lt(index),
})
})
}

func (s *SQLiteLogDB) CompactEntriesTo(clusterID uint64, nodeID uint64, index uint64) (<-chan struct{}, error) {
ch := make(chan struct{})
err := s.DeleteSnapshot(clusterID, nodeID, index)
if err != nil {
return nil, err
}
log.Trace().Msg(fmt.Sprintf("CompactEntriesTo c: %d n: %d i: %d", clusterID, nodeID, index))

defer func() {
close(ch)
}()

return nil, err
return ch, nil
}

func (s *SQLiteLogDB) SaveSnapshots(updates []raftpb.Update) error {
Expand All @@ -363,13 +390,18 @@ func (s *SQLiteLogDB) DeleteSnapshot(clusterID uint64, nodeID uint64, index uint

func (s *SQLiteLogDB) ListSnapshots(clusterID uint64, nodeID uint64, index uint64) ([]raftpb.Snapshot, error) {
entries := make([]raftInfoEntry, 0)
exps := []goqu.Expression{
goqu.C("node_id").Eq(nodeID),
goqu.C("cluster_id").Eq(clusterID),
goqu.C("entry_type").Eq(Snapshot),
}
if index != math.MaxUint64 {
exps = append(exps, goqu.C("entry_index").Lte(index))
}

err := s.db.
From(raftInfoTable).
Where(
goqu.C("node_id").Eq(nodeID),
goqu.C("cluster_id").Eq(clusterID),
goqu.C("entry_type").Eq(Snapshot),
goqu.C("entry_index").Lte(index)).
Where(exps...).
Order(goqu.C("entry_index").Asc()).
Prepared(true).
ScanStructs(&entries)
Expand All @@ -385,6 +417,8 @@ func (s *SQLiteLogDB) ListSnapshots(clusterID uint64, nodeID uint64, index uint6
if err != nil {
return nil, err
}

ret = append(ret, snp)
}

return ret, nil
Expand All @@ -394,12 +428,12 @@ func (s *SQLiteLogDB) ImportSnapshot(_ raftpb.Snapshot, _ uint64) error {
return nil
}

func (s *SQLiteLogDB) getRange(nodeID, clusterID uint64) (uint64, uint64, error) {
func (s *SQLiteLogDB) getEntryRange(nodeID, clusterID uint64) (uint64, uint64, error) {
count := uint64(0)
ok, err := s.db.From(raftInfoTable).Select(
goqu.COUNT(goqu.C("entry_index")),
goqu.COUNT("entry_index"),
).
Where(goqu.Ex{"node_id": nodeID, "cluster_id": clusterID}).
Where(goqu.Ex{"node_id": nodeID, "cluster_id": clusterID, "entry_type": Entry}).
Prepared(true).
Executor().
ScanVal(&count)
Expand All @@ -418,9 +452,9 @@ func (s *SQLiteLogDB) getRange(nodeID, clusterID uint64) (uint64, uint64, error)

min := uint64(0)
ok, err = s.db.From(raftInfoTable).Select(
goqu.MIN(goqu.C("entry_index")),
goqu.MIN("entry_index"),
).
Where(goqu.Ex{"node_id": nodeID, "cluster_id": clusterID}).
Where(goqu.Ex{"node_id": nodeID, "cluster_id": clusterID, "entry_type": Entry}).
Prepared(true).
Executor().
ScanVal(&min)
Expand All @@ -443,6 +477,17 @@ func deleteInfoTuple(
entryType raftInfoEntryType,
additionalExpressions []goqu.Expression,
) error {

logger := log.With().
Uint64("node_id", nodeID).
Uint64("cluster_id", clusterID).
Uint64("entry_type", uint64(entryType)).
Logger()

for i, x := range additionalExpressions {
logger = logger.With().Str(fmt.Sprintf("exp[%d]", i), fmt.Sprintf("%v", x)).Logger()
}

exps := []goqu.Expression{
goqu.C("node_id").Eq(nodeID),
goqu.C("cluster_id").Eq(clusterID),
Expand All @@ -451,6 +496,7 @@ func deleteInfoTuple(

exps = append(exps, additionalExpressions...)

logger.Trace().Msg("Deleted rows")
_, err := db.Delete(raftInfoTable).
Where(exps...).
Prepared(true).
Expand All @@ -460,12 +506,11 @@ func deleteInfoTuple(
if err != nil {
return err
}

return nil
}

func saveInfoTuple(
db *goqu.TxDatabase,
tx *goqu.TxDatabase,
index *uint64,
nodeID uint64,
clusterID uint64,
Expand All @@ -477,9 +522,39 @@ func saveInfoTuple(
return err
}

query := fmt.Sprintf(`
INSERT OR REPLACE INTO %s(entry_index, node_id, cluster_id, entry_type, payload)
VALUES(?, ?, ?, ?, ?);`, raftInfoTable)
_, err = db.Exec(query, index, nodeID, clusterID, entryType, data)
exps := []goqu.Expression{
goqu.C("node_id").Eq(nodeID),
goqu.C("cluster_id").Eq(clusterID),
goqu.C("entry_type").Eq(entryType),
}

logger := log.With().
Uint64("node_id", nodeID).
Uint64("cluster_id", clusterID).
Uint64("entry_type", uint64(entryType)).
Logger()

if index != nil {
exps = append(exps, goqu.C("entry_index").Eq(*index))
logger = logger.With().Uint64("index", *index).Logger()
} else {
exps = append(exps, goqu.C("entry_index").IsNull())
logger = logger.With().Int("index", -1).Logger()
}

_, err = tx.Delete(raftInfoTable).Where(exps...).Prepared(true).Executor().Exec()
if err != nil {
return err
}

_, err = tx.Insert(raftInfoTable).Rows(goqu.Record{
"node_id": nodeID,
"entry_index": index,
"cluster_id": clusterID,
"entry_type": entryType,
"payload": data,
}).Prepared(true).Executor().Exec()

logger.Trace().Err(err).Msg("Saved")
return err
}

0 comments on commit 8b5e508

Please sign in to comment.