From 6009d555bcd9175b51fbe8babd2ca924a39327e7 Mon Sep 17 00:00:00 2001 From: Zohaib Date: Sun, 9 Oct 2022 12:02:17 -0700 Subject: [PATCH] Major refactor in logic to reload connections appropriately after restore DB has happened --- db/sqlite.go | 131 ++++++++++++++++++++++++++++----------------------- 1 file changed, 73 insertions(+), 58 deletions(-) diff --git a/db/sqlite.go b/db/sqlite.go index 4bba6f0..5c184b6 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -62,34 +62,21 @@ func GetAllDBTables(path string) ([]string, error) { } func OpenStreamDB(path string) (*SqliteStreamDB, error) { - connectionStr := fmt.Sprintf("%s?_journal_mode=wal", path) - conn, rawConn, err := OpenRaw(connectionStr) - if err != nil { - return nil, err - } - - conn.SetConnMaxLifetime(0) - conn.SetConnMaxIdleTime(10 * time.Second) - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, err - } - - err = watcher.Add(path) - if err != nil { - return nil, err - } - ret := &SqliteStreamDB{ - Database: goqu.New("sqlite3", conn), - rawConnection: rawConn, - watcher: watcher, + Database: nil, + rawConnection: nil, + watcher: nil, dbPath: path, prefix: MarmotPrefix, publishLock: &sync.Mutex{}, watchTablesSchema: map[string][]*ColumnInfo{}, } + err := ret.cycleDBConnection() + if err != nil { + return nil, err + } + return ret, nil } @@ -261,7 +248,7 @@ func (conn *SqliteStreamDB) BackupTo(bkFilePath string) error { } func (conn *SqliteStreamDB) RestoreFrom(bkFilePath string) error { - dnsTpl := "%s?_journal_mode=wal&_foreign_keys=false&&_busy_timeout=30000&_txlock=%s" + dnsTpl := "%s?_journal_mode=wal&_foreign_keys=false&_busy_timeout=30000&_txlock=%s" dns := fmt.Sprintf(dnsTpl, conn.dbPath, snapshotTransactionMode) destDB, dest, err := OpenRaw(dns) if err != nil { @@ -284,22 +271,7 @@ func (conn *SqliteStreamDB) RestoreFrom(bkFilePath string) error { // else is modifying or interacting with DB err = sgSQL.WithTx(func(dtx *goqu.TxDatabase) error { return dgSQL.WithTx(func(_ *goqu.TxDatabase) error { - err := copyFile(conn.dbPath, bkFilePath) - if err != nil { - return err - } - - err = copyFile(conn.dbPath+"-wal", bkFilePath+"-wal") - if err != nil { - return err - } - - err = copyFile(conn.dbPath+"-shm", bkFilePath+"-shm") - if err != nil { - return err - } - - return nil + return copyFile(conn.dbPath, bkFilePath) }) }) @@ -307,27 +279,12 @@ func (conn *SqliteStreamDB) RestoreFrom(bkFilePath string) error { return err } - // Perform checkpoints to make sure database is full ready - row := dgSQL.QueryRow("PRAGMA wal_checkpoint(truncate);") - rBusy, rLog, rCheckpoint := int64(1), int64(0), int64(0) - for rBusy != 0 { - err = row.Scan(&rBusy, &rLog, &rCheckpoint) - if err != nil { - return err - } - - if rBusy != 0 { - log.Debug(). - Int64("busy", rBusy). - Int64("log", rLog). - Int64("checkpoint", rCheckpoint). - Msg("Waiting checkpoint...") - - time.Sleep(500 * time.Millisecond) - } + err = performCheckpoint(dgSQL) + if err != nil { + return err } - return nil + return conn.cycleDBConnection() } func (conn *SqliteStreamDB) GetRawConnection() *sqlite3.SQLiteConn { @@ -338,6 +295,41 @@ func (conn *SqliteStreamDB) GetPath() string { return conn.dbPath } +func (conn *SqliteStreamDB) cycleDBConnection() error { + connectionStr := fmt.Sprintf("%s?_journal_mode=wal", conn.dbPath) + sqlC, rawConn, err := OpenRaw(connectionStr) + if err != nil { + return err + } + + sqlC.SetConnMaxLifetime(0) + sqlC.SetConnMaxIdleTime(10 * time.Second) + + if conn.rawConnection != nil { + err = conn.rawConnection.Close() + } + + conn.Database = goqu.New("sql", sqlC) + conn.rawConnection = rawConn + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + + err = watcher.Add(conn.dbPath) + if err != nil { + return err + } + + if conn.watcher != nil { + conn.watcher.Close() + } + conn.watcher = watcher + + return performCheckpoint(conn.Database) +} + func copyFile(toPath, fromPath string) error { fi, err := os.OpenFile(fromPath, os.O_RDWR, 0) if err != nil { @@ -356,7 +348,7 @@ func copyFile(toPath, fromPath string) error { Int64("bytes", bytesWritten). Str("from", fromPath). Str("to", toPath). - Msg("copyFile") + Msg("File copied...") return err } @@ -373,3 +365,26 @@ func listDBTables(names *[]string, gSQL *goqu.TxDatabase) error { return nil } + +func performCheckpoint(gSQL *goqu.Database) error { + rBusy, rLog, rCheckpoint := int64(1), int64(0), int64(0) + for rBusy != 0 { + row := gSQL.QueryRow("PRAGMA wal_checkpoint(truncate);") + err := row.Scan(&rBusy, &rLog, &rCheckpoint) + if err != nil { + return err + } + + if rBusy != 0 { + log.Debug(). + Int64("busy", rBusy). + Int64("log", rLog). + Int64("checkpoint", rCheckpoint). + Msg("Waiting checkpoint...") + + time.Sleep(500 * time.Millisecond) + } + } + + return nil +}