diff --git a/model/po/review.go b/model/po/review.go index 9a1e130..9ed0330 100644 --- a/model/po/review.go +++ b/model/po/review.go @@ -8,7 +8,7 @@ type ReviewPO struct { UserID int64 `gorm:"index;index:uniq_course_review,unique"` Comment string Rating int64 `gorm:"index"` - Semester string `gorm:"index;index:uniq_course_review,unique"` + Semester string `gorm:"index"` IsAnonymous bool SearchIndex SearchIndex `gorm:"->:false;<-"` } diff --git a/repository/rating.go b/repository/rating.go index 6d9a40d..9850c16 100644 --- a/repository/rating.go +++ b/repository/rating.go @@ -62,7 +62,7 @@ func (r *RatingQuery) GetRatingInfo(ctx context.Context, relatedType model.Ratin res := model.RatingInfo{} dists := make([]model.RatingInfoDistItem, 0) db := r.optionDB(ctx) - err := db.Select("rating, count(id)"). + err := db.Debug().Select("rating, count(*) as count"). Where("related_type = ? and related_id = ?", relatedType, relatedID). Group("rating"). Find(&dists).Error diff --git a/repository/review.go b/repository/review.go index e657785..1da9143 100644 --- a/repository/review.go +++ b/repository/review.go @@ -4,6 +4,7 @@ import ( "context" "gorm.io/gorm" + "gorm.io/gorm/clause" "jcourse_go/model/converter" "jcourse_go/model/model" @@ -12,8 +13,7 @@ import ( type IReviewQuery interface { GetReviewCount(ctx context.Context, opts ...DBOption) (int64, error) - GetReviewDetail(ctx context.Context, opts ...DBOption) (*po.ReviewPO, error) - GetReviewList(ctx context.Context, opts ...DBOption) ([]po.ReviewPO, error) + GetReview(ctx context.Context, opts ...DBOption) ([]po.ReviewPO, error) CreateReview(ctx context.Context, review po.ReviewPO) (int64, error) UpdateReview(ctx context.Context, review po.ReviewPO) error DeleteReview(ctx context.Context, opts ...DBOption) error @@ -59,17 +59,7 @@ func (c *ReviewQuery) optionDB(ctx context.Context, opts ...DBOption) *gorm.DB { return db } -func (c *ReviewQuery) GetReviewDetail(ctx context.Context, opts ...DBOption) (*po.ReviewPO, error) { - db := c.optionDB(ctx, opts...) - review := po.ReviewPO{} - result := db.WithContext(ctx).First(&review) - if result.Error != nil { - return nil, result.Error - } - return &review, nil -} - -func (c *ReviewQuery) GetReviewList(ctx context.Context, opts ...DBOption) ([]po.ReviewPO, error) { +func (c *ReviewQuery) GetReview(ctx context.Context, opts ...DBOption) ([]po.ReviewPO, error) { db := c.optionDB(ctx, opts...) reviews := make([]po.ReviewPO, 0) result := db.WithContext(ctx).Find(&reviews) @@ -98,10 +88,12 @@ func (c *ReviewQuery) UpdateReview(ctx context.Context, review po.ReviewPO) erro db := c.db.WithContext(ctx) ratingPO := converter.BuildRatingFromReview(review) err := db.Transaction(func(tx *gorm.DB) error { - if err := tx.Model(&po.ReviewPO{}).Updates(&review).Error; err != nil { + if err := tx.Model(&po.ReviewPO{}).Where("id = ?", review.ID).Updates(&review).Error; err != nil { return err } - if err := tx.Model(&po.RatingPO{}).Updates(&ratingPO).Error; err != nil { + if err := tx.Model(&po.RatingPO{}). + Where("related_type = ? and related_id = ?", model.RelatedTypeCourse, review.ID). + Updates(&ratingPO).Error; err != nil { return err } return nil @@ -110,21 +102,22 @@ func (c *ReviewQuery) UpdateReview(ctx context.Context, review po.ReviewPO) erro } func (c *ReviewQuery) DeleteReview(ctx context.Context, opts ...DBOption) error { - db := c.optionDB(ctx, opts...) - reviews := make([]po.ReviewPO, 0) - err := db.Find(&reviews).Error - if err != nil { - return err - } - ids := make([]int64, 0) - for _, review := range reviews { - ids = append(ids, int64(review.ID)) - } - err = db.Transaction(func(tx *gorm.DB) error { - if err := tx.Delete(&po.ReviewPO{}).Error; err != nil { + err := c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + tx1 := tx.Session(&gorm.Session{}) + for _, opt := range opts { + tx1 = opt(tx1) + } + reviews := make([]po.ReviewPO, 0) + if err := tx1.Model(&po.ReviewPO{}).Clauses(clause.Returning{}).Delete(&reviews).Error; err != nil { return err } - if err := tx.Delete(&po.RatingPO{}, "related_id in ? and related_type = ?", ids, model.RelatedTypeCourse).Error; err != nil { + + ids := make([]int64, 0) + for _, review := range reviews { + ids = append(ids, int64(review.ID)) + } + tx2 := tx.Session(&gorm.Session{}) + if err := tx2.Model(&po.RatingPO{}).Delete(&po.RatingPO{}, "related_id in ? and related_type = ?", ids, model.RelatedTypeCourse).Error; err != nil { return err } return nil diff --git a/repository/review_test.go b/repository/review_test.go new file mode 100644 index 0000000..c3a3f76 --- /dev/null +++ b/repository/review_test.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + + "jcourse_go/dal" + "jcourse_go/model/model" + "jcourse_go/model/po" +) + +func TestReviewQuery_GetReview(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewReviewQuery(db) + + reviews, err := query.GetReview(ctx, WithID(1)) + assert.Nil(t, err) + assert.Len(t, reviews, 1) + assert.Equal(t, int64(5), reviews[0].Rating) +} + +func TestReviewQuery_GetReviewCount(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewReviewQuery(db) + + count, err := query.GetReviewCount(ctx) + assert.Nil(t, err) + assert.Equal(t, int64(2), count) +} + +func TestReviewQuery_CreateReview(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewReviewQuery(db) + ratingQuery := NewRatingQuery(db) + + t.Run("normal", func(t *testing.T) { + reviewPO := po.ReviewPO{ + CourseID: 2, + UserID: 2, + Comment: "", + Rating: 5, + Semester: "", + IsAnonymous: false, + } + id, err := query.CreateReview(ctx, reviewPO) + assert.Nil(t, err) + assert.NotZero(t, id) + + rating, err := ratingQuery.GetRatingInfo(ctx, model.RelatedTypeCourse, 2) + assert.Nil(t, err) + assert.Equal(t, int64(2), rating.Count) + assert.Equal(t, float64(5), rating.Average) + }) + t.Run("duplicated review", func(t *testing.T) { + reviewPO := po.ReviewPO{ + CourseID: 1, + UserID: 1, + Comment: "", + Rating: 0, + Semester: "", + IsAnonymous: false, + } + id, err := query.CreateReview(ctx, reviewPO) + assert.NotNil(t, err) + assert.Zero(t, id) + + rating, err := ratingQuery.GetRatingInfo(ctx, model.RelatedTypeCourse, 1) + assert.Nil(t, err) + assert.Equal(t, int64(1), rating.Count) + }) +} + +func TestReviewQuery_UpdateReview(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewReviewQuery(db) + ratingQuery := NewRatingQuery(db) + + t.Run("duplicate", func(t *testing.T) { + reviewPO := po.ReviewPO{ + Model: gorm.Model{ID: 1}, + CourseID: 2, + UserID: 1, + Comment: "", + Rating: 1, + Semester: "", + IsAnonymous: false, + } + + err := query.UpdateReview(ctx, reviewPO) + assert.NotNil(t, err) + + // no change + reviews, err := query.GetReview(ctx, WithID(1)) + assert.Nil(t, err) + assert.Equal(t, int64(5), reviews[0].Rating) + + info, err := ratingQuery.GetRatingInfo(ctx, model.RelatedTypeCourse, 1) + if err != nil { + return + } + assert.Equal(t, int64(5), info.RatingDist[0].Rating) + }) + + t.Run("normal", func(t *testing.T) { + reviewPO := po.ReviewPO{ + Model: gorm.Model{ID: 1}, + CourseID: 1, + UserID: 1, + Comment: "", + Rating: 1, + Semester: "", + IsAnonymous: false, + } + + err := query.UpdateReview(ctx, reviewPO) + assert.Nil(t, err) + reviews, err := query.GetReview(ctx, WithID(1)) + assert.Nil(t, err) + assert.Equal(t, int64(1), reviews[0].Rating) + + info, err := ratingQuery.GetRatingInfo(ctx, model.RelatedTypeCourse, 1) + if err != nil { + return + } + assert.Equal(t, int64(1), info.RatingDist[0].Rating) + }) + +} + +func TestReviewQuery_DeleteReview(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewReviewQuery(db) + ratingQuery := NewRatingQuery(db) + + t.Run("normal", func(t *testing.T) { + err := query.DeleteReview(ctx, WithID(1)) + assert.Nil(t, err) + reviews, err := query.GetReview(ctx, WithID(1)) + assert.Nil(t, err) + assert.Len(t, reviews, 0) + + info, err := ratingQuery.GetRatingInfo(ctx, model.RelatedTypeCourse, 1) + if err != nil { + return + } + assert.Len(t, info.RatingDist, 0) + }) +} diff --git a/repository/test_env.go b/repository/test_env.go index cd55b9d..7fe4c3b 100644 --- a/repository/test_env.go +++ b/repository/test_env.go @@ -5,6 +5,7 @@ import ( "gorm.io/gorm" + "jcourse_go/model/model" "jcourse_go/model/po" ) @@ -177,6 +178,68 @@ func createTrainingPlan(db *gorm.DB) error { return nil } +func createTestUser(db *gorm.DB) error { + users := []po.UserPO{ + { + Model: gorm.Model{ID: 1}, + Username: "test1", + Email: "test1@example.com", + }, + { + Model: gorm.Model{ID: 2}, + Username: "test2", + Email: "test2@example.com", + }, + { + Model: gorm.Model{ID: 3}, + Username: "test3", + Email: "test3@example.com", + }, + } + err := db.Create(&users).Error + return err +} + +func createTestReview(db *gorm.DB) error { + reviews := []po.ReviewPO{ + { + Model: gorm.Model{ID: 1}, + CourseID: 1, + UserID: 1, + Comment: "test review", + Rating: 5, + }, + { + Model: gorm.Model{ID: 2}, + CourseID: 2, + UserID: 1, + Comment: "test review", + Rating: 4, + }, + } + err := db.Create(&reviews).Error + if err != nil { + return err + } + ratings := []po.RatingPO{ + { + Rating: 5, + RelatedID: 1, + UserID: 1, + RelatedType: model.RelatedTypeCourse, + }, + { + Rating: 5, + RelatedID: 2, + UserID: 1, + RelatedType: model.RelatedTypeCourse, + }, + } + + err = db.Create(&ratings).Error + return err +} + func CreateTestEnv(ctx context.Context, db *gorm.DB) error { db = db.WithContext(ctx) createFunc := []func(db *gorm.DB) error{ @@ -185,11 +248,13 @@ func CreateTestEnv(ctx context.Context, db *gorm.DB) error { createTestCourse, createCourseCategories, createTrainingPlan, + createTestUser, + createTestReview, } for _, fn := range createFunc { err := fn(db) if err != nil { - return err + panic(err) } } return nil diff --git a/service/review.go b/service/review.go index d0b3c7e..701bccf 100644 --- a/service/review.go +++ b/service/review.go @@ -41,7 +41,7 @@ func buildReviewDBOptionFromFilter(query repository.IReviewQuery, filter model.R func GetReviewList(ctx context.Context, filter model.ReviewFilter) ([]model.Review, error) { reviewQuery := repository.NewReviewQuery(dal.GetDBClient()) opts := buildReviewDBOptionFromFilter(reviewQuery, filter) - reviewPOs, err := reviewQuery.GetReviewList(ctx, opts...) + reviewPOs, err := reviewQuery.GetReview(ctx, opts...) if err != nil { return nil, err }