diff --git a/backend/src/transactions/club.go b/backend/src/transactions/club.go index b239e6fdd..a670e7221 100644 --- a/backend/src/transactions/club.go +++ b/backend/src/transactions/club.go @@ -61,7 +61,7 @@ func CreateClub(db *gorm.DB, userId uuid.UUID, club models.Club) (*models.Club, return &club, nil } -func GetClub(db *gorm.DB, id uuid.UUID, preloads ...OptionalPreload) (*models.Club, *errors.Error) { +func GetClub(db *gorm.DB, id uuid.UUID, preloads ...OptionalQuery) (*models.Club, *errors.Error) { var club models.Club query := db diff --git a/backend/src/transactions/club_tag.go b/backend/src/transactions/club_tag.go index a818a5257..36ff31951 100644 --- a/backend/src/transactions/club_tag.go +++ b/backend/src/transactions/club_tag.go @@ -9,12 +9,12 @@ import ( ) func CreateClubTags(db *gorm.DB, id uuid.UUID, tags []models.Tag) ([]models.Tag, *errors.Error) { - user, err := GetClub(db, id) + user, err := GetClub(db, id, PreloadTag()) if err != nil { return nil, err } - if err := db.Model(&user).Association("Tag").Replace(tags); err != nil { + if err := db.Model(&user).Association("Tag").Append(tags); err != nil { return nil, &errors.FailedToUpdateUser } diff --git a/backend/src/transactions/preloaders.go b/backend/src/transactions/preloaders.go index 0ed46f12d..39f3c88c2 100644 --- a/backend/src/transactions/preloaders.go +++ b/backend/src/transactions/preloaders.go @@ -2,16 +2,22 @@ package transactions import "gorm.io/gorm" -type OptionalPreload func(*gorm.DB) *gorm.DB +type OptionalQuery func(*gorm.DB) *gorm.DB -func PreloadFollwer() OptionalPreload { +func PreloadFollwer() OptionalQuery { return func(db *gorm.DB) *gorm.DB { return db.Preload("Follower") } } -func PreloadMember() OptionalPreload { +func PreloadMember() OptionalQuery { return func(db *gorm.DB) *gorm.DB { return db.Preload("Member") } } + +func PreloadTag() OptionalQuery { + return func(db *gorm.DB) *gorm.DB { + return db.Preload("Tag") + } +} diff --git a/backend/src/transactions/user.go b/backend/src/transactions/user.go index 23c910fe5..dae4ecd26 100644 --- a/backend/src/transactions/user.go +++ b/backend/src/transactions/user.go @@ -42,7 +42,7 @@ func GetUsers(db *gorm.DB, limit int, offset int) ([]models.User, *errors.Error) return users, nil } -func GetUser(db *gorm.DB, id uuid.UUID, preloads ...OptionalPreload) (*models.User, *errors.Error) { +func GetUser(db *gorm.DB, id uuid.UUID, preloads ...OptionalQuery) (*models.User, *errors.Error) { var user models.User query := db diff --git a/backend/src/transactions/user_follower.go b/backend/src/transactions/user_follower.go index e16fc7541..77dde3c69 100644 --- a/backend/src/transactions/user_follower.go +++ b/backend/src/transactions/user_follower.go @@ -20,9 +20,7 @@ func CreateFollowing(db *gorm.DB, userId uuid.UUID, clubId uuid.UUID) *errors.Er return err } - user.Follower = append(user.Follower, *club) - - if err := db.Model(&user).Association("Follower").Replace(user.Follower); err != nil { + if err := db.Model(&user).Association("Follower").Append(club); err != nil { return &errors.FailedToUpdateUser } diff --git a/backend/src/transactions/user_member.go b/backend/src/transactions/user_member.go index 47687ce1c..6b079862c 100644 --- a/backend/src/transactions/user_member.go +++ b/backend/src/transactions/user_member.go @@ -20,9 +20,7 @@ func CreateMember(db *gorm.DB, userId uuid.UUID, clubId uuid.UUID) *errors.Error return err } - user.Member = append(user.Member, *club) - - if err := db.Model(&user).Association("Member").Replace(user.Member); err != nil { + if err := db.Model(&user).Association("Member").Append(club); err != nil { return &errors.FailedToUpdateUser } diff --git a/backend/src/transactions/user_tag.go b/backend/src/transactions/user_tag.go index ac746d15d..3df46a0b8 100644 --- a/backend/src/transactions/user_tag.go +++ b/backend/src/transactions/user_tag.go @@ -22,12 +22,12 @@ func GetUserTags(db *gorm.DB, id uuid.UUID) ([]models.Tag, *errors.Error) { } func CreateUserTags(db *gorm.DB, id uuid.UUID, tags []models.Tag) ([]models.Tag, *errors.Error) { - user, err := GetUser(db, id) + user, err := GetUser(db, id, PreloadTag()) if err != nil { return nil, err } - if err := db.Model(&user).Association("Tag").Replace(tags); err != nil { + if err := db.Model(&user).Association("Tag").Append(tags); err != nil { return nil, &errors.FailedToUpdateUser } diff --git a/backend/tests/api/user_member_test.go b/backend/tests/api/user_member_test.go index 028c5f43b..9f41b9e52 100644 --- a/backend/tests/api/user_member_test.go +++ b/backend/tests/api/user_member_test.go @@ -34,9 +34,9 @@ func TestCreateMembershipWorks(t *testing.T) { eaa.Assert.NilError(err) - eaa.Assert.Equal(1, len(user.Member)) + eaa.Assert.Equal(2, len(user.Member)) // SAC Super Club and the one just added - eaa.Assert.Equal(clubUUID, user.Member[0].ID) + eaa.Assert.Equal(clubUUID, user.Member[1].ID) // second club AKA the one just added var club models.Club @@ -128,7 +128,7 @@ func TestDeleteMembershipWorks(t *testing.T) { eaa.Assert.NilError(err) - eaa.Assert.Equal(0, len(user.Member)) + eaa.Assert.Equal(1, len(user.Member)) // SAC Super Club var club models.Club @@ -259,7 +259,7 @@ func TestGetMembershipWorks(t *testing.T) { eaa.Assert.NilError(err) - eaa.Assert.Equal(1, len(clubs)) + eaa.Assert.Equal(2, len(clubs)) // SAC Super Club and the one just added var dbClubs []models.Club @@ -267,7 +267,7 @@ func TestGetMembershipWorks(t *testing.T) { eaa.Assert.NilError(err) - eaa.Assert.Equal(1, len(clubs)) + eaa.Assert.Equal(2, len(clubs)) // SAC Super Club and the one just added }, }, ).Close()