Skip to content

Commit

Permalink
Refactor DB setup Part 1 (#105)
Browse files Browse the repository at this point in the history
* Separate CreateStory DB operation to own function

Improves testability for unit testing.

* Add sample create DB test

* Separate DeleteStory DB operation to own function

* Separate CreateUser DB operation to own function

* Add TODO

* Separate DeleteUser DB operation to own function

* Fix pointer errors
  • Loading branch information
RichDom2185 authored Feb 12, 2024
1 parent f369819 commit a603c40
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 21 deletions.
32 changes: 21 additions & 11 deletions model/stories.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,20 @@ func GetStoryByID(db *gorm.DB, id int) (Story, error) {
return story, nil
}

func CreateStory(db *gorm.DB, story *Story) error {
err := db.
func (s *Story) create(tx *gorm.DB) *gorm.DB {
return tx.
Preload(clause.Associations).
Create(story).
Create(s).
// Get associated Author. See
// https://github.com/go-gorm/gen/issues/618 on why
// a separate .First() is needed.
First(story).
Error
First(s)
}

func CreateStory(db *gorm.DB, story *Story) error {
err := db.Transaction(func(tx *gorm.DB) error {
return story.create(tx).Error
})
if err != nil {
return database.HandleDBError(err, "story")
}
Expand Down Expand Up @@ -109,14 +114,19 @@ func UpdateStory(db *gorm.DB, storyID int, newStory *Story) error {
return nil
}

func DeleteStory(db *gorm.DB, storyID int) (Story, error) {
var story Story
err := db.
func (s *Story) delete(tx *gorm.DB, storyID uint) *gorm.DB {
return tx.
Preload(clause.Associations).
Where("id = ?", storyID).
First(&story). // store the value to be returned
Delete(&story).
Error
First(s). // store the value to be returned
Delete(s)
}

func DeleteStory(db *gorm.DB, storyID int) (Story, error) {
var story Story
err := db.Transaction(func(tx *gorm.DB) error {
return story.delete(tx, uint(storyID)).Error
})
if err != nil {
return story, database.HandleDBError(err, "story")
}
Expand Down
24 changes: 22 additions & 2 deletions model/stories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
userenums "github.com/source-academy/stories-backend/internal/enums/users"
"github.com/source-academy/stories-backend/internal/testutils"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)

// FIXME: Coupling with the other operations in the users database
Expand Down Expand Up @@ -40,6 +41,25 @@ var (
}
)

func TestCreate(t *testing.T) {
t.Run("", func(t *testing.T) {
db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath)
defer cleanUp(t)

// Any number is fine because the statement is not executed,
// thus removing the coupling with an actual author having to be
// created prior.
story := &Story{
AuthorID: 1,
Content: "The quick brown test content 5678.",
}
sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return story.create(tx)
})
assert.Contains(t, sql, "The quick brown test content 5678.", "Should contain the story content")
})
}

func TestCreateStory(t *testing.T) {
t.Run("should increase the total story count", func(t *testing.T) {
db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath)
Expand Down Expand Up @@ -221,8 +241,8 @@ func TestStoryDB(t *testing.T) {
}
_ = CreateGroup(db, &group)

err := db.Exec(`INSERT INTO "stories"
("created_at","updated_at","deleted_at","author_id","group_id","title","content","pin_order")
err := db.Exec(`INSERT INTO "stories"
("created_at","updated_at","deleted_at","author_id","group_id","title","content","pin_order")
VALUES ('2023-08-08 22:17:28.085','2023-08-08 22:17:28.085',NULL,NULL,NULL,'','# Hi, This is a test story.',NULL)`).
Error
var pgerr *pgconn.PgError
Expand Down
29 changes: 21 additions & 8 deletions model/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,24 @@ func GetUserByID(db *gorm.DB, id int) (User, error) {
return user, err
}

func CreateUser(db *gorm.DB, user *User) error {
func (u *User) create(tx *gorm.DB) *gorm.DB {
// TODO: If user already exists, but is soft-deleted, undelete the user
err := db.Create(user).Error
return tx.Create(u)
}

func CreateUser(db *gorm.DB, user *User) error {
err := db.Transaction(func(tx *gorm.DB) error {
return user.create(tx).Error
})
if err != nil {
return database.HandleDBError(err, "user")
}
return nil
}

func CreateUsers(db *gorm.DB, users *[]*User) (int64, error) {
// TODO: Use users.create() instead
// Blocked by `RowsAffected` not being accessible.
tx := db.Create(users)
rowCount := tx.RowsAffected
if err := tx.Error; err != nil {
Expand All @@ -49,14 +57,19 @@ func CreateUsers(db *gorm.DB, users *[]*User) (int64, error) {
return rowCount, nil
}

func (u *User) delete(tx *gorm.DB, userID uint) *gorm.DB {
return tx.
Model(u).
Where("id = ?", userID).
First(u). // store the value to be returned
Delete(u)
}

func DeleteUser(db *gorm.DB, userID int) (User, error) {
var user User
err := db.
Model(&user).
Where("id = ?", userID).
First(&user). // store the value to be returned
Delete(&user).
Error
err := db.Transaction(func(tx *gorm.DB) error {
return user.delete(tx, uint(userID)).Error
})
if err != nil {
return user, database.HandleDBError(err, "user")
}
Expand Down

0 comments on commit a603c40

Please sign in to comment.