From 98c777e418b9784293f80353467ec1a26779af47 Mon Sep 17 00:00:00 2001 From: Du Jiajun Date: Sat, 21 Sep 2024 20:33:35 +0800 Subject: [PATCH] test: add test for rating --- model/model/rating.go | 4 +- repository/rating.go | 15 ++++++- repository/rating_test.go | 95 +++++++++++++++++++++++++++++++++++++++ repository/review_test.go | 23 ++-------- repository/test_env.go | 6 +++ 5 files changed, 122 insertions(+), 21 deletions(-) create mode 100644 repository/rating_test.go diff --git a/model/model/rating.go b/model/model/rating.go index b4d3293..d8c5730 100644 --- a/model/model/rating.go +++ b/model/model/rating.go @@ -3,7 +3,9 @@ package model type RatingRelatedType = string const ( - RelatedTypeCourse RatingRelatedType = "course" + RelatedTypeCourse RatingRelatedType = "course" + RelatedTypeTeacher RatingRelatedType = "teacher" + RelatedTypeTrainingPlan RatingRelatedType = "training_plan" ) type RatingInfoDistItemByID struct { diff --git a/repository/rating.go b/repository/rating.go index 9850c16..b649ebe 100644 --- a/repository/rating.go +++ b/repository/rating.go @@ -13,12 +13,24 @@ type IRatingQuery interface { GetRatingInfo(ctx context.Context, relatedType model.RatingRelatedType, relatedID int64) (model.RatingInfo, error) GetRatingInfoByIDs(ctx context.Context, relatedType model.RatingRelatedType, relatedIDs []int64) (map[int64]model.RatingInfo, error) CreateRating(ctx context.Context, ratingPO po.RatingPO) error + UpdateRating(ctx context.Context, ratingPO po.RatingPO) error + DeleteRating(ctx context.Context, ratingPO po.RatingPO) error } type RatingQuery struct { db *gorm.DB } +func (r *RatingQuery) UpdateRating(ctx context.Context, ratingPO po.RatingPO) error { + db := r.optionDB(ctx) + return db.Where("user_id = ? and related_id = ? and related_type = ?", ratingPO.UserID, ratingPO.RelatedID, ratingPO.RelatedType).Updates(&ratingPO).Error +} + +func (r *RatingQuery) DeleteRating(ctx context.Context, ratingPO po.RatingPO) error { + db := r.optionDB(ctx) + return db.Where("user_id = ? and related_type = ? and related_id = ?", ratingPO.UserID, ratingPO.RelatedType, ratingPO.RelatedID).Delete(&po.RatingPO{}).Error +} + func (r *RatingQuery) CreateRating(ctx context.Context, ratingPO po.RatingPO) error { db := r.optionDB(ctx) err := db.Create(&ratingPO).Error @@ -32,7 +44,7 @@ func (r *RatingQuery) GetRatingInfoByIDs(ctx context.Context, relatedType model. res := make(map[int64]model.RatingInfo) distByIDs := make([]model.RatingInfoDistItemByID, 0) db := r.optionDB(ctx) - err := db.Select("rating, count(id)"). + err := db.Select("rating, count(id) as count, related_id"). Where("related_type = ? and related_id in ?", relatedType, relatedIDs). Group("rating").Group("related_id"). Find(&distByIDs).Error @@ -45,6 +57,7 @@ func (r *RatingQuery) GetRatingInfoByIDs(ctx context.Context, relatedType model. info = model.RatingInfo{RatingDist: make([]model.RatingInfoDistItem, 0)} } info.RatingDist = append(info.RatingDist, model.RatingInfoDistItem{Rating: dist.Rating, Count: dist.Count}) + info.Calc() res[dist.RelatedID] = info } return res, nil diff --git a/repository/rating_test.go b/repository/rating_test.go new file mode 100644 index 0000000..43ed8ba --- /dev/null +++ b/repository/rating_test.go @@ -0,0 +1,95 @@ +package repository + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "jcourse_go/dal" + "jcourse_go/model/model" + "jcourse_go/model/po" +) + +func TestRatingQuery_GetRatingInfo(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewRatingQuery(db) + info, err := query.GetRatingInfo(ctx, model.RelatedTypeCourse, 1) + assert.Nil(t, err) + assert.Len(t, info.RatingDist, 1) + assert.Equal(t, int64(1), info.Count) + assert.Equal(t, float64(5), info.Average) +} + +func TestRatingQuery_GetRatingInfoByIDs(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewRatingQuery(db) + infoMap, err := query.GetRatingInfoByIDs(ctx, model.RelatedTypeCourse, []int64{1}) + assert.Nil(t, err) + assert.Len(t, infoMap, 1) + info := infoMap[1] + assert.Len(t, info.RatingDist, 1) + assert.Equal(t, int64(1), info.Count) + assert.Equal(t, float64(5), info.Average) +} + +func TestRatingQuery_CreateRating(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewRatingQuery(db) + rating := po.RatingPO{ + UserID: 3, + RelatedType: model.RelatedTypeTeacher, + RelatedID: 1, + Rating: 5, + } + err := query.CreateRating(ctx, rating) + assert.Nil(t, err) + + info, err := query.GetRatingInfo(ctx, model.RelatedTypeTeacher, 1) + assert.Nil(t, err) + assert.Len(t, info.RatingDist, 1) + assert.Equal(t, int64(1), info.Count) + assert.Equal(t, float64(5), info.Average) +} + +func TestRatingQuery_UpdateRating(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewRatingQuery(db) + rating := po.RatingPO{ + UserID: 1, + RelatedType: model.RelatedTypeTeacher, + RelatedID: 2, + Rating: 3, + } + err := query.UpdateRating(ctx, rating) + assert.Nil(t, err) + + info, err := query.GetRatingInfo(ctx, model.RelatedTypeTeacher, 2) + assert.Nil(t, err) + assert.Len(t, info.RatingDist, 1) + assert.Equal(t, int64(1), info.Count) + assert.Equal(t, float64(3), info.Average) +} + +func TestRatingQuery_DeleteRating(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewRatingQuery(db) + rating := po.RatingPO{ + UserID: 1, + RelatedType: model.RelatedTypeTeacher, + RelatedID: 2, + Rating: 3, + } + err := query.DeleteRating(ctx, rating) + assert.Nil(t, err) + + info, err := query.GetRatingInfo(ctx, model.RelatedTypeTeacher, 2) + assert.Nil(t, err) + assert.Len(t, info.RatingDist, 0) + assert.Equal(t, int64(0), info.Count) +} diff --git a/repository/review_test.go b/repository/review_test.go index c3a3f76..0df96bd 100644 --- a/repository/review_test.go +++ b/repository/review_test.go @@ -33,7 +33,7 @@ func TestReviewQuery_GetReviewCount(t *testing.T) { assert.Equal(t, int64(2), count) } -func TestReviewQuery_CreateReview(t *testing.T) { +func TestReviewQuery_CreateReview_normal(t *testing.T) { ctx := context.Background() db := dal.GetDBClient() query := NewReviewQuery(db) @@ -57,23 +57,7 @@ func TestReviewQuery_CreateReview(t *testing.T) { assert.Equal(t, int64(2), rating.Count) assert.Equal(t, float64(5), rating.Average) }) - t.Run("duplicated review", func(t *testing.T) { - reviewPO := po.ReviewPO{ - CourseID: 1, - UserID: 1, - Comment: "", - Rating: 0, - Semester: "", - IsAnonymous: false, - } - id, err := query.CreateReview(ctx, reviewPO) - assert.NotNil(t, err) - assert.Zero(t, id) - rating, err := ratingQuery.GetRatingInfo(ctx, model.RelatedTypeCourse, 1) - assert.Nil(t, err) - assert.Equal(t, int64(1), rating.Count) - }) } func TestReviewQuery_UpdateReview(t *testing.T) { @@ -105,13 +89,14 @@ func TestReviewQuery_UpdateReview(t *testing.T) { if err != nil { return } + assert.Len(t, info.RatingDist, 1) assert.Equal(t, int64(5), info.RatingDist[0].Rating) }) t.Run("normal", func(t *testing.T) { reviewPO := po.ReviewPO{ Model: gorm.Model{ID: 1}, - CourseID: 1, + CourseID: 3, UserID: 1, Comment: "", Rating: 1, @@ -125,7 +110,7 @@ func TestReviewQuery_UpdateReview(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int64(1), reviews[0].Rating) - info, err := ratingQuery.GetRatingInfo(ctx, model.RelatedTypeCourse, 1) + info, err := ratingQuery.GetRatingInfo(ctx, model.RelatedTypeCourse, 3) if err != nil { return } diff --git a/repository/test_env.go b/repository/test_env.go index 7fe4c3b..ee445b5 100644 --- a/repository/test_env.go +++ b/repository/test_env.go @@ -234,6 +234,12 @@ func createTestReview(db *gorm.DB) error { UserID: 1, RelatedType: model.RelatedTypeCourse, }, + { + Rating: 5, + RelatedID: 2, + UserID: 1, + RelatedType: model.RelatedTypeTeacher, + }, } err = db.Create(&ratings).Error