Skip to content

Commit

Permalink
Major refactor in logic to reload connections appropriately after res…
Browse files Browse the repository at this point in the history
…tore DB has happened
  • Loading branch information
maxpert committed Oct 9, 2022
1 parent 3de891a commit 6009d55
Showing 1 changed file with 73 additions and 58 deletions.
131 changes: 73 additions & 58 deletions db/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,34 +62,21 @@ func GetAllDBTables(path string) ([]string, error) {
}

func OpenStreamDB(path string) (*SqliteStreamDB, error) {
connectionStr := fmt.Sprintf("%s?_journal_mode=wal", path)
conn, rawConn, err := OpenRaw(connectionStr)
if err != nil {
return nil, err
}

conn.SetConnMaxLifetime(0)
conn.SetConnMaxIdleTime(10 * time.Second)
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}

err = watcher.Add(path)
if err != nil {
return nil, err
}

ret := &SqliteStreamDB{
Database: goqu.New("sqlite3", conn),
rawConnection: rawConn,
watcher: watcher,
Database: nil,
rawConnection: nil,
watcher: nil,
dbPath: path,
prefix: MarmotPrefix,
publishLock: &sync.Mutex{},
watchTablesSchema: map[string][]*ColumnInfo{},
}

err := ret.cycleDBConnection()
if err != nil {
return nil, err
}

return ret, nil
}

Expand Down Expand Up @@ -261,7 +248,7 @@ func (conn *SqliteStreamDB) BackupTo(bkFilePath string) error {
}

func (conn *SqliteStreamDB) RestoreFrom(bkFilePath string) error {
dnsTpl := "%s?_journal_mode=wal&_foreign_keys=false&&_busy_timeout=30000&_txlock=%s"
dnsTpl := "%s?_journal_mode=wal&_foreign_keys=false&_busy_timeout=30000&_txlock=%s"
dns := fmt.Sprintf(dnsTpl, conn.dbPath, snapshotTransactionMode)
destDB, dest, err := OpenRaw(dns)
if err != nil {
Expand All @@ -284,50 +271,20 @@ func (conn *SqliteStreamDB) RestoreFrom(bkFilePath string) error {
// else is modifying or interacting with DB
err = sgSQL.WithTx(func(dtx *goqu.TxDatabase) error {
return dgSQL.WithTx(func(_ *goqu.TxDatabase) error {
err := copyFile(conn.dbPath, bkFilePath)
if err != nil {
return err
}

err = copyFile(conn.dbPath+"-wal", bkFilePath+"-wal")
if err != nil {
return err
}

err = copyFile(conn.dbPath+"-shm", bkFilePath+"-shm")
if err != nil {
return err
}

return nil
return copyFile(conn.dbPath, bkFilePath)
})
})

if err != nil {
return err
}

// Perform checkpoints to make sure database is full ready
row := dgSQL.QueryRow("PRAGMA wal_checkpoint(truncate);")
rBusy, rLog, rCheckpoint := int64(1), int64(0), int64(0)
for rBusy != 0 {
err = row.Scan(&rBusy, &rLog, &rCheckpoint)
if err != nil {
return err
}

if rBusy != 0 {
log.Debug().
Int64("busy", rBusy).
Int64("log", rLog).
Int64("checkpoint", rCheckpoint).
Msg("Waiting checkpoint...")

time.Sleep(500 * time.Millisecond)
}
err = performCheckpoint(dgSQL)
if err != nil {
return err
}

return nil
return conn.cycleDBConnection()
}

func (conn *SqliteStreamDB) GetRawConnection() *sqlite3.SQLiteConn {
Expand All @@ -338,6 +295,41 @@ func (conn *SqliteStreamDB) GetPath() string {
return conn.dbPath
}

func (conn *SqliteStreamDB) cycleDBConnection() error {
connectionStr := fmt.Sprintf("%s?_journal_mode=wal", conn.dbPath)
sqlC, rawConn, err := OpenRaw(connectionStr)
if err != nil {
return err
}

sqlC.SetConnMaxLifetime(0)
sqlC.SetConnMaxIdleTime(10 * time.Second)

if conn.rawConnection != nil {
err = conn.rawConnection.Close()
}

conn.Database = goqu.New("sql", sqlC)
conn.rawConnection = rawConn

watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}

err = watcher.Add(conn.dbPath)
if err != nil {
return err
}

if conn.watcher != nil {
conn.watcher.Close()
}
conn.watcher = watcher

return performCheckpoint(conn.Database)
}

func copyFile(toPath, fromPath string) error {
fi, err := os.OpenFile(fromPath, os.O_RDWR, 0)
if err != nil {
Expand All @@ -356,7 +348,7 @@ func copyFile(toPath, fromPath string) error {
Int64("bytes", bytesWritten).
Str("from", fromPath).
Str("to", toPath).
Msg("copyFile")
Msg("File copied...")
return err
}

Expand All @@ -373,3 +365,26 @@ func listDBTables(names *[]string, gSQL *goqu.TxDatabase) error {

return nil
}

func performCheckpoint(gSQL *goqu.Database) error {
rBusy, rLog, rCheckpoint := int64(1), int64(0), int64(0)
for rBusy != 0 {
row := gSQL.QueryRow("PRAGMA wal_checkpoint(truncate);")
err := row.Scan(&rBusy, &rLog, &rCheckpoint)
if err != nil {
return err
}

if rBusy != 0 {
log.Debug().
Int64("busy", rBusy).
Int64("log", rLog).
Int64("checkpoint", rCheckpoint).
Msg("Waiting checkpoint...")

time.Sleep(500 * time.Millisecond)
}
}

return nil
}

0 comments on commit 6009d55

Please sign in to comment.