From 85f9eb5337ee924f438bcd161e65c03da8006ff6 Mon Sep 17 00:00:00 2001 From: Zohaib Date: Sun, 18 Sep 2022 09:49:56 -0700 Subject: [PATCH] Finalizing save/restore for SQLiteStateMachine --- lib/raft.go | 13 +- lib/sqlite_log_db.go | 45 +++++- ...ica_machine.go => sqlite_state_machine.go} | 147 ++++++++++++------ lib/utils.go | 28 ++++ 4 files changed, 182 insertions(+), 51 deletions(-) rename lib/{replica_machine.go => sqlite_state_machine.go} (67%) create mode 100644 lib/utils.go diff --git a/lib/raft.go b/lib/raft.go index 6b51a40..e51d65e 100644 --- a/lib/raft.go +++ b/lib/raft.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/rand" + "os" "strconv" "strings" "sync" @@ -69,16 +70,19 @@ func (r *RaftServer) Init() error { defer r.lock.Unlock() metaAbsPath := fmt.Sprintf("%s/node-%d", r.metaPath, r.nodeID) - factory := NewSQLiteLogDBFactory(r.metaPath, r.nodeID) hostConfig := config.NodeHostConfig{ WALDir: metaAbsPath, NodeHostDir: metaAbsPath, RTTMillisecond: 300, RaftAddress: r.bindAddress, RaftEventListener: r, - Expert: config.ExpertConfig{ + } + + if strings.ToLower(os.Getenv("SQLITE_LOG_STORE")) == "true" { + factory := NewSQLiteLogDBFactory(r.metaPath, r.nodeID) + hostConfig.Expert = config.ExpertConfig{ LogDBFactory: factory, - }, + } } nodeHost, err := dragonboat.NewNodeHost(hostConfig) @@ -304,10 +308,9 @@ func (r *RaftServer) stateMachineFactory(clusterID uint64, nodeID uint64) statem r.lock.Lock() defer r.lock.Unlock() - firstNode := len(r.clusterStateMachine) == 0 sm, ok := r.clusterStateMachine[clusterID] if !ok { - sm = NewDBStateMachine(clusterID, nodeID, r.database, r.metaPath, firstNode) + sm = NewDBStateMachine(clusterID, nodeID, r.database, r.metaPath, clusterID == 1) r.clusterStateMachine[clusterID] = sm } diff --git a/lib/sqlite_log_db.go b/lib/sqlite_log_db.go index 3ace6e8..308142c 100644 --- a/lib/sqlite_log_db.go +++ b/lib/sqlite_log_db.go @@ -475,10 +475,53 @@ func (s *SQLiteLogDB) ListSnapshots(clusterID uint64, nodeID uint64, index uint6 func (s *SQLiteLogDB) ImportSnapshot(snp raftpb.Snapshot, nodeID uint64) error { return s.db.WithTx(func(tx *goqu.TxDatabase) error { - err := saveInfoTuple(tx, &snp.Index, nodeID, snp.ClusterId, Snapshot, snp.Marshal) + if raftpb.IsEmptySnapshot(snp) { + return nil + } + + // Replace Bootstrap + err := deleteInfoTuple(tx, nodeID, snp.ClusterId, Bootstrap, []goqu.Expression{}) + if err != nil { + return err + } + + bootstrap := raftpb.Bootstrap{ + Join: true, + Type: snp.Type, + } + err = saveInfoTuple(tx, nil, nodeID, snp.ClusterId, Bootstrap, bootstrap.Marshal) + if err != nil { + return err + } + + // Replace state + err = deleteInfoTuple(tx, nodeID, snp.ClusterId, State, []goqu.Expression{}) + if err != nil { + return err + } + + state := raftpb.State{ + Term: snp.Term, + Commit: snp.Index, + } + err = saveInfoTuple(tx, nil, nodeID, snp.ClusterId, State, state.Marshal) if err != nil { return err } + + // Delete snapshot log entries ahead of index + err = deleteInfoTuple(tx, nodeID, snp.ClusterId, Snapshot, []goqu.Expression{ + goqu.C("entry_index").Gte(snp.Index), + }) + if err != nil { + return err + } + + err = saveInfoTuple(tx, &snp.Index, nodeID, snp.ClusterId, Snapshot, snp.Marshal) + if err != nil { + return err + } + return nil }) } diff --git a/lib/replica_machine.go b/lib/sqlite_state_machine.go similarity index 67% rename from lib/replica_machine.go rename to lib/sqlite_state_machine.go index 232817e..ccce934 100644 --- a/lib/replica_machine.go +++ b/lib/sqlite_state_machine.go @@ -15,10 +15,15 @@ import ( type snapshotState = uint8 -type indexState struct { +type appliedIndexInfo struct { Index uint64 } +type stateSaveInfo struct { + appliedIndex appliedIndexInfo + dbPath string +} + type SQLiteStateMachine struct { NodeID uint64 ClusterID uint64 @@ -28,7 +33,7 @@ type SQLiteStateMachine struct { enableSnapshots bool snapshotLock *sync.Mutex snapshotState snapshotState - indexState *indexState + applied *appliedIndexInfo } type ReplicationEvent[T any] struct { @@ -64,8 +69,8 @@ func NewDBStateMachine( enableSnapshots: enableSnapshots, snapshotLock: &sync.Mutex{}, - snapshotState: 0, - indexState: &indexState{Index: 0}, + snapshotState: snapshotNotInitialized, + applied: &appliedIndexInfo{Index: 0}, } } @@ -75,7 +80,7 @@ func (ssm *SQLiteStateMachine) Open(_ <-chan struct{}) (uint64, error) { return 0, err } - return ssm.indexState.Index, nil + return ssm.applied.Index, nil } func (ssm *SQLiteStateMachine) Update(entries []sm.Entry) ([]sm.Entry, error) { @@ -98,12 +103,12 @@ func (ssm *SQLiteStateMachine) Update(entries []sm.Entry) ([]sm.Entry, error) { return nil, err } - ssm.indexState.Index = entry.Index + ssm.applied.Index = entry.Index if err := ssm.saveIndex(); err != nil { return nil, err } - entry.Result = sm.Result{Value: 0} + entry.Result = sm.Result{Value: entry.Index} } return entries, nil @@ -114,11 +119,16 @@ func (ssm *SQLiteStateMachine) Sync() error { } func (ssm *SQLiteStateMachine) PrepareSnapshot() (interface{}, error) { + log.Debug(). + Uint64("cluster", ssm.ClusterID). + Uint64("node", ssm.NodeID). + Bool("enabled", ssm.enableSnapshots). + Msg("Preparing snapshot...") + if !ssm.enableSnapshots { - return nil, nil + return stateSaveInfo{dbPath: "", appliedIndex: *ssm.applied}, nil } - log.Debug().Msg("PrepareSnapshot") bkFileDir, err := ssm.getSnapshotDir() if err != nil { return nil, err @@ -130,71 +140,116 @@ func (ssm *SQLiteStateMachine) PrepareSnapshot() (interface{}, error) { return nil, err } - return bkFilePath, nil + return stateSaveInfo{dbPath: bkFilePath, appliedIndex: *ssm.applied}, nil } -func (ssm *SQLiteStateMachine) SaveSnapshot(path interface{}, writer io.Writer, _ <-chan struct{}) error { - if !ssm.enableSnapshots { - return nil - } +func (ssm *SQLiteStateMachine) SaveSnapshot(st interface{}, writer io.Writer, _ <-chan struct{}) error { + log.Debug(). + Uint64("cluster", ssm.ClusterID). + Uint64("node", ssm.NodeID). + Bool("enabled", ssm.enableSnapshots). + Msg("Saving snapshot...") - ssm.snapshotLock.Lock() - defer ssm.snapshotLock.Unlock() - filepath, ok := path.(string) + stInfo, ok := st.(stateSaveInfo) if !ok { - return fmt.Errorf(fmt.Sprintf("invalid file path %v", path)) + return fmt.Errorf(fmt.Sprintf("invalid save state info %v", st)) } - fi, err := os.Open(filepath) + mBytes, err := cbor.Marshal(stInfo.appliedIndex) + err = writeUint32(writer, uint32(len(mBytes))) if err != nil { return err } - defer ssm.cleanup(fi, filepath) - _, err = io.Copy(writer, fi) + _, err = writer.Write(mBytes) if err != nil { return err } + // Write length of filepath as indicator for following up stream + err = writeUint32(writer, uint32(len(stInfo.dbPath))) + if err != nil { + return err + } + + if stInfo.dbPath != "" { + filepath := stInfo.dbPath + fi, err := os.Open(filepath) + if err != nil { + return err + } + defer ssm.cleanup(fi, filepath) + + _, err = io.Copy(writer, fi) + if err != nil { + return err + } + } + + ssm.snapshotLock.Lock() + defer ssm.snapshotLock.Unlock() ssm.snapshotState = snapshotSaved return nil } func (ssm *SQLiteStateMachine) RecoverFromSnapshot(reader io.Reader, _ <-chan struct{}) error { - if !ssm.enableSnapshots { - return nil - } - - log.Debug().Msg("RecoverFromSnapshot") - basePath, err := ssm.getSnapshotDir() + log.Debug(). + Uint64("cluster", ssm.ClusterID). + Uint64("node", ssm.NodeID). + Bool("enabled", ssm.enableSnapshots). + Msg("Recovering from snapshot...") + + appIndex := appliedIndexInfo{} + buffLen, err := readUint32(reader) if err != nil { return err } - filepath := path.Join(basePath, "restore.sqlite") - fo, err := os.OpenFile(filepath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644) + dec := cbor.NewDecoder(io.LimitReader(reader, int64(buffLen))) + err = dec.Decode(&appIndex) if err != nil { return err } - defer ssm.cleanup(fo, filepath) - _, err = io.Copy(fo, reader) + hasData, err := readUint32(reader) if err != nil { return err } - // Flush file contents before handing off - err = fo.Sync() - if err != nil { - return err - } + ssm.snapshotLock.Lock() + defer ssm.snapshotLock.Unlock() + if hasData != 0 { + basePath, err := ssm.getSnapshotDir() + if err != nil { + return err + } - err = ssm.importSnapshot(filepath) - if err != nil { - return err + filepath := path.Join(basePath, "restore.sqlite") + fo, err := os.OpenFile(filepath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + defer ssm.cleanup(fo, filepath) + + _, err = io.Copy(fo, reader) + if err != nil { + return err + } + + // Flush file contents before handing off + err = fo.Sync() + if err != nil { + return err + } + + err = ssm.importSnapshot(filepath) + if err != nil { + return err + } } - return nil + ssm.applied = &appIndex + return ssm.saveIndex() } func (ssm *SQLiteStateMachine) Lookup(_ interface{}) (interface{}, error) { @@ -224,9 +279,6 @@ func (ssm *SQLiteStateMachine) Close() error { } func (ssm *SQLiteStateMachine) importSnapshot(filepath string) error { - ssm.snapshotLock.Lock() - defer ssm.snapshotLock.Unlock() - log.Info().Str("path", filepath).Msg("Importing...") err := ssm.DB.RestoreFrom(filepath) if err != nil { @@ -277,7 +329,7 @@ func (ssm *SQLiteStateMachine) saveIndex() error { return err } - b, err := cbor.Marshal(ssm.indexState) + b, err := cbor.Marshal(ssm.applied) if err != nil { return err } @@ -292,6 +344,11 @@ func (ssm *SQLiteStateMachine) saveIndex() error { return err } + log.Debug(). + Uint64("node_id", ssm.NodeID). + Uint64("cluster_id", ssm.ClusterID). + Uint64("index", ssm.applied.Index). + Msg("Saved index") return nil } @@ -316,7 +373,7 @@ func (ssm *SQLiteStateMachine) readIndex() error { return err } - err = cbor.Unmarshal(b, ssm.indexState) + err = cbor.Unmarshal(b, ssm.applied) if err != nil { return err } diff --git a/lib/utils.go b/lib/utils.go new file mode 100644 index 0000000..320a3f8 --- /dev/null +++ b/lib/utils.go @@ -0,0 +1,28 @@ +package lib + +import ( + "encoding/binary" + "io" +) + +func writeUint32(writer io.Writer, val uint32) error { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, val) + _, err := writer.Write(buf) + if err != nil { + return err + } + + return nil +} + +func readUint32(reader io.Reader) (uint32, error) { + buff := make([]byte, 4, 4) + _, err := io.LimitReader(reader, 4).Read(buff) + if err != nil { + return 0, err + } + + val := binary.BigEndian.Uint32(buff) + return val, nil +}