Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DB scripts #106

Merged
merged 18 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 21 additions & 75 deletions scripts/create_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,101 +4,47 @@ import (
"errors"
"fmt"

"github.com/source-academy/stories-backend/internal/config"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

func connectAnonDB(conf config.DatabaseConfig) (*gorm.DB, error) {
conf.DatabaseName = ""
dsn := conf.ToDataSourceName()
return connectDBHelper(dsn)
}

func connectDB(conf config.DatabaseConfig) (*gorm.DB, error) {
dsn := conf.ToDataSourceName()
return connectDBHelper(dsn)
}

func connectDBHelper(dsn string) (*gorm.DB, error) {
driver := postgres.Open(dsn)

db, err := gorm.Open(driver, &gorm.Config{})
if err != nil {
return nil, err
}

dbName, err := getConnectedDBName(db)
if err != nil {
panic(err)
}
fmt.Println(blueSandwich, "Connected to database", dbName+".")

return db, nil
}

func closeDBConnection(d *gorm.DB) {
db, err := d.DB()
if err != nil {
panic(err)
}

dbName, err := getConnectedDBName(d)
if err != nil {
panic(err)
}
fmt.Println(blueSandwich, "Closing connection with database", dbName+".")

if err := db.Close(); err != nil {
panic(err)
}
}

func createDB(db *gorm.DB, dbconf *config.DatabaseConfig) error {
if dbconf.DatabaseName == "" {
func createDB(db *gorm.DB, dbName string) error {
if dbName == "" {
return errors.New("Failed to create database: no database name provided.")
}

// check if db exists
fmt.Println(yellowChevron, "Checking if database", dbconf.DatabaseName, "exists.")
result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbconf.DatabaseName)
fmt.Println(yellowChevron, "Checking if database", dbName, "exists.")
result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbName)
if result.Error != nil {
return result.Error
}

// if not exists create it
rec := make(map[string]interface{})
if result.Find(rec); len(rec) == 0 {
fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "does not exist. Creating...")

create_command := fmt.Sprintf("CREATE DATABASE %s", dbconf.DatabaseName)
result := db.Exec(create_command)

if result.Error != nil {
return result.Error
}
if result.Find(rec); len(rec) != 0 {
fmt.Println(greenTick, "Database", dbName, "already exists.")
return nil
}

fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "exists.")
fmt.Println(yellowChevron, "Database", dbName, "does not exist. Creating...")
create_command := fmt.Sprintf("CREATE DATABASE %s", dbName)
err := db.Exec(create_command).Error
if err != nil {
return err
}

fmt.Println(greenTick, "Created database:", dbName)
return nil
}

func dropDB(db *gorm.DB, dbconf *config.DatabaseConfig) error {
drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbconf.DatabaseName)
result := db.Exec(drop_command)
if result.Error != nil {
return result.Error
}

return nil
func dropDB(db *gorm.DB, dbName string) error {
drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName)
err := db.Exec(drop_command).Error
return err
}

func getConnectedDBName(db *gorm.DB) (string, error) {
func getDBName(db *gorm.DB) (string, error) {
var dbName string
result := db.Raw("SELECT current_database();").Scan(&dbName)
if result.Error != nil {
return "", result.Error
}
return dbName, nil
err := db.Raw("SELECT current_database();").Scan(&dbName).Error
return dbName, err
}
69 changes: 48 additions & 21 deletions scripts/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@ import (
migrate "github.com/rubenv/sql-migrate"
"github.com/sirupsen/logrus"
"github.com/source-academy/stories-backend/internal/config"
"gorm.io/gorm"
"github.com/source-academy/stories-backend/internal/database"
)

const (
defaultMaxMigrateSteps = 0 // no limit
defaultMaxRollbackSteps = 1

dropCmd = "drop"
createCmd = "create"
migrateCmd = "migrate"
rollbackCmd = "rollback"
statusCmd = "status"
)

var (
Expand All @@ -28,67 +34,88 @@ var (
})
)

func main() {
func setupScript() (string, *config.DatabaseConfig) {
// Load configuration
conf, err := config.LoadFromEnvironment()
if err != nil {
logrus.Errorln(err)
panic(err)
}

var connector func(config.DatabaseConfig) (*gorm.DB, error)
targetDBName := conf.Database.DatabaseName

// Check for command line arguments
flag.Parse()
switch flag.Arg(0) {
case "drop", "create":
connector = connectAnonDB
case "migrate", "rollback", "status":
connector = connectDB
case dropCmd, createCmd:
// We need to connect anonymously in order
// to drop or create the database.
conf.Database.DatabaseName = ""
case migrateCmd, rollbackCmd, statusCmd:
// Do nothing
default:
logrus.Errorln("Invalid command")
return targetDBName, nil
}

return targetDBName, conf.Database
}

func main() {
targetDBName, dbConfig := setupScript()
if dbConfig == nil {
// Invalid configuration
return
}

// Connect to the database
d, err := connector(*conf.Database)
dbConn, err := database.Connect(dbConfig)
if err != nil {
logrus.Errorln(err)
panic(err)
}
defer closeDBConnection(d)
// Remember to close the connection
defer (func() {
fmt.Println(blueSandwich, "Closing connection...")
database.Close(dbConn)
})()

dbName, err := getDBName(dbConn)
if err != nil {
panic(err)
}
fmt.Println(blueSandwich, "Connected to database", dbName+".")

switch flag.Arg(0) {
case "drop":
err := dropDB(d, conf.Database)
case dropCmd:
err := dropDB(dbConn, targetDBName)
if err != nil {
logrus.Errorln(err)
panic(err)
}
fmt.Println(greenTick, "Dropped database:", conf.Database.DatabaseName)
case "create":
err := createDB(d, conf.Database)
fmt.Println(greenTick, "Dropped database:", targetDBName)
case createCmd:
err := createDB(dbConn, targetDBName)
if err != nil {
logrus.Errorln(err)
panic(err)
}
fmt.Println(greenTick, "Created database:", conf.Database.DatabaseName)
case "migrate":
db, err := d.DB()
case migrateCmd:
db, err := dbConn.DB()
if err != nil {
logrus.Errorln(err)
panic(err)
}
migrateDB(db)
case "rollback":
db, err := d.DB()
case rollbackCmd:
db, err := dbConn.DB()
if err != nil {
logrus.Errorln(err)
panic(err)
}
rollbackDB(db)
case "status":
db, err := d.DB()
case statusCmd:
db, err := dbConn.DB()
if err != nil {
logrus.Errorln(err)
panic(err)
Expand Down
Loading