Skip to content

Commit

Permalink
Refactor DB scripts (#106)
Browse files Browse the repository at this point in the history
* Access error directly via variable

* Remove unnecesary 'if' clause

The default value of a string variable is already an empty string, and
there was no additional error handling done.

* Move printing logic to main script

Allows for better abstraction and improves reusability.

* Refactor to use common DB "connect" helper

Reduces code duplication.

* Refactor to use common DB "close" helper

* Rename `d` to `dbConn`

Makes the name more descriptive.

* Add comment

* Replace string with constants

Prevents bugs.

* Rename `getConnectedDBName` to `getDBName`

* Change createDB signature

Hides implementation details of the config object as the function only
needs to be aware of the database name to create, not how it is derived
from.

* Change dropDB signature

Similar to createDB, this improves abstraction.

* Refactor script setup to separate function

Improves single responsibility.

* Simplify dropDB implementation

* Refactor createDB implementation

* Move happy path to outside if condition
* Simplify error handling
* Update logging behavior

* Update log colors

* Remove unnecessary code
  • Loading branch information
RichDom2185 authored Feb 12, 2024
1 parent dd2f349 commit f369819
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 96 deletions.
96 changes: 21 additions & 75 deletions scripts/create_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,101 +4,47 @@ import (
"errors"
"fmt"

"github.com/source-academy/stories-backend/internal/config"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

func connectAnonDB(conf config.DatabaseConfig) (*gorm.DB, error) {
conf.DatabaseName = ""
dsn := conf.ToDataSourceName()
return connectDBHelper(dsn)
}

func connectDB(conf config.DatabaseConfig) (*gorm.DB, error) {
dsn := conf.ToDataSourceName()
return connectDBHelper(dsn)
}

func connectDBHelper(dsn string) (*gorm.DB, error) {
driver := postgres.Open(dsn)

db, err := gorm.Open(driver, &gorm.Config{})
if err != nil {
return nil, err
}

dbName, err := getConnectedDBName(db)
if err != nil {
panic(err)
}
fmt.Println(blueSandwich, "Connected to database", dbName+".")

return db, nil
}

func closeDBConnection(d *gorm.DB) {
db, err := d.DB()
if err != nil {
panic(err)
}

dbName, err := getConnectedDBName(d)
if err != nil {
panic(err)
}
fmt.Println(blueSandwich, "Closing connection with database", dbName+".")

if err := db.Close(); err != nil {
panic(err)
}
}

func createDB(db *gorm.DB, dbconf *config.DatabaseConfig) error {
if dbconf.DatabaseName == "" {
func createDB(db *gorm.DB, dbName string) error {
if dbName == "" {
return errors.New("Failed to create database: no database name provided.")
}

// check if db exists
fmt.Println(yellowChevron, "Checking if database", dbconf.DatabaseName, "exists.")
result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbconf.DatabaseName)
fmt.Println(yellowChevron, "Checking if database", dbName, "exists.")
result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbName)
if result.Error != nil {
return result.Error
}

// if not exists create it
rec := make(map[string]interface{})
if result.Find(rec); len(rec) == 0 {
fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "does not exist. Creating...")

create_command := fmt.Sprintf("CREATE DATABASE %s", dbconf.DatabaseName)
result := db.Exec(create_command)

if result.Error != nil {
return result.Error
}
if result.Find(rec); len(rec) != 0 {
fmt.Println(greenTick, "Database", dbName, "already exists.")
return nil
}

fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "exists.")
fmt.Println(yellowChevron, "Database", dbName, "does not exist. Creating...")
create_command := fmt.Sprintf("CREATE DATABASE %s", dbName)
err := db.Exec(create_command).Error
if err != nil {
return err
}

fmt.Println(greenTick, "Created database:", dbName)
return nil
}

func dropDB(db *gorm.DB, dbconf *config.DatabaseConfig) error {
drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbconf.DatabaseName)
result := db.Exec(drop_command)
if result.Error != nil {
return result.Error
}

return nil
func dropDB(db *gorm.DB, dbName string) error {
drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName)
err := db.Exec(drop_command).Error
return err
}

func getConnectedDBName(db *gorm.DB) (string, error) {
func getDBName(db *gorm.DB) (string, error) {
var dbName string
result := db.Raw("SELECT current_database();").Scan(&dbName)
if result.Error != nil {
return "", result.Error
}
return dbName, nil
err := db.Raw("SELECT current_database();").Scan(&dbName).Error
return dbName, err
}
69 changes: 48 additions & 21 deletions scripts/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@ import (
migrate "github.com/rubenv/sql-migrate"
"github.com/sirupsen/logrus"
"github.com/source-academy/stories-backend/internal/config"
"gorm.io/gorm"
"github.com/source-academy/stories-backend/internal/database"
)

const (
defaultMaxMigrateSteps = 0 // no limit
defaultMaxRollbackSteps = 1

dropCmd = "drop"
createCmd = "create"
migrateCmd = "migrate"
rollbackCmd = "rollback"
statusCmd = "status"
)

var (
Expand All @@ -28,67 +34,88 @@ var (
})
)

func main() {
func setupScript() (string, *config.DatabaseConfig) {
// Load configuration
conf, err := config.LoadFromEnvironment()
if err != nil {
logrus.Errorln(err)
panic(err)
}

var connector func(config.DatabaseConfig) (*gorm.DB, error)
targetDBName := conf.Database.DatabaseName

// Check for command line arguments
flag.Parse()
switch flag.Arg(0) {
case "drop", "create":
connector = connectAnonDB
case "migrate", "rollback", "status":
connector = connectDB
case dropCmd, createCmd:
// We need to connect anonymously in order
// to drop or create the database.
conf.Database.DatabaseName = ""
case migrateCmd, rollbackCmd, statusCmd:
// Do nothing
default:
logrus.Errorln("Invalid command")
return targetDBName, nil
}

return targetDBName, conf.Database
}

func main() {
targetDBName, dbConfig := setupScript()
if dbConfig == nil {
// Invalid configuration
return
}

// Connect to the database
d, err := connector(*conf.Database)
dbConn, err := database.Connect(dbConfig)
if err != nil {
logrus.Errorln(err)
panic(err)
}
defer closeDBConnection(d)
// Remember to close the connection
defer (func() {
fmt.Println(blueSandwich, "Closing connection...")
database.Close(dbConn)
})()

dbName, err := getDBName(dbConn)
if err != nil {
panic(err)
}
fmt.Println(blueSandwich, "Connected to database", dbName+".")

switch flag.Arg(0) {
case "drop":
err := dropDB(d, conf.Database)
case dropCmd:
err := dropDB(dbConn, targetDBName)
if err != nil {
logrus.Errorln(err)
panic(err)
}
fmt.Println(greenTick, "Dropped database:", conf.Database.DatabaseName)
case "create":
err := createDB(d, conf.Database)
fmt.Println(greenTick, "Dropped database:", targetDBName)
case createCmd:
err := createDB(dbConn, targetDBName)
if err != nil {
logrus.Errorln(err)
panic(err)
}
fmt.Println(greenTick, "Created database:", conf.Database.DatabaseName)
case "migrate":
db, err := d.DB()
case migrateCmd:
db, err := dbConn.DB()
if err != nil {
logrus.Errorln(err)
panic(err)
}
migrateDB(db)
case "rollback":
db, err := d.DB()
case rollbackCmd:
db, err := dbConn.DB()
if err != nil {
logrus.Errorln(err)
panic(err)
}
rollbackDB(db)
case "status":
db, err := d.DB()
case statusCmd:
db, err := dbConn.DB()
if err != nil {
logrus.Errorln(err)
panic(err)
Expand Down

0 comments on commit f369819

Please sign in to comment.