Skip to content

Commit

Permalink
test: add test for course
Browse files Browse the repository at this point in the history
  • Loading branch information
dujiajun committed Sep 21, 2024
1 parent 834c87f commit a8263f8
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 30 deletions.
35 changes: 15 additions & 20 deletions repository/course.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ package repository
import (
"context"
"errors"
"fmt"

"gorm.io/gorm"

"jcourse_go/dal"
"jcourse_go/model/po"
)

Expand Down Expand Up @@ -101,13 +99,18 @@ func (c *CourseQuery) optionDB(ctx context.Context, opts ...DBOption) *gorm.DB {
}

func (c *CourseQuery) GetCourseCategories(ctx context.Context, courseIDs []int64) (map[int64][]string, error) {
courseCategoryMap := make(map[int64][]string)
for _, id := range courseIDs {
courseCategoryMap[id] = make([]string, 0)
}

db := c.db.WithContext(ctx).Model(po.CourseCategoryPO{})
courseCategoryPOs := make([]po.CourseCategoryPO, 0)
result := db.Where("course_id in ?", courseIDs).Find(&courseCategoryPOs)
if result.Error != nil {
return nil, result.Error
return courseCategoryMap, result.Error
}
courseCategoryMap := make(map[int64][]string)

for _, courseCategoryPO := range courseCategoryPOs {
categories, ok := courseCategoryMap[courseCategoryPO.CourseID]
if !ok {
Expand All @@ -122,9 +125,12 @@ func (c *CourseQuery) GetCourseCategories(ctx context.Context, courseIDs []int64
func (c *CourseQuery) GetCourse(ctx context.Context, opts ...DBOption) ([]po.CoursePO, error) {
db := c.optionDB(ctx, opts...)
coursePOs := make([]po.CoursePO, 0)
result := db.Find(&coursePOs)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
err := db.Find(&coursePOs).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return coursePOs, nil
}
if err != nil {
return coursePOs, err
}
return coursePOs, nil
}
Expand All @@ -139,8 +145,8 @@ func (c *CourseQuery) GetCourseCount(ctx context.Context, opts ...DBOption) (int
return count, nil
}

func NewCourseQuery() ICourseQuery {
return &CourseQuery{db: dal.GetDBClient()}
func NewCourseQuery(db *gorm.DB) ICourseQuery {
return &CourseQuery{db: db}
}

type IOfferedCourseQuery interface {
Expand All @@ -154,17 +160,6 @@ type OfferedCourseQuery struct {
db *gorm.DB
}

func (o *OfferedCourseQuery) WithOrderBy(field string, ascending bool) DBOption {
return func(db *gorm.DB) *gorm.DB {
if ascending {
field = fmt.Sprintf("%s %s", field, "asc")
} else {
field = fmt.Sprintf("%s %s", field, "desc")
}
return db.Order(field)
}
}

func (o *OfferedCourseQuery) GetOfferedCourseTeacherGroup(ctx context.Context, offeredCourseIDs []int64) (map[int64][]po.TeacherPO, error) {
db := o.db.WithContext(ctx).Model(&po.OfferedCourseTeacherPO{})
courseTeacherPOs := make([]po.OfferedCourseTeacherPO, 0)
Expand Down
92 changes: 89 additions & 3 deletions repository/course_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,101 @@ func TestBaseCourseQuery_GetBaseCourseCount(t *testing.T) {
})
}

func TestCourseQuery_GetCourseByIDs(t *testing.T) {
func TestBaseCourseQuery_GetCourseByIDs(t *testing.T) {
ctx := context.Background()
db := dal.GetDBClient()
query := NewBaseCourseQuery(db)
courseMap, err := query.GetBaseCoursesByIDs(ctx, []int64{1, 2})
t.Run("all", func(t *testing.T) {
courseMap, err := query.GetBaseCoursesByIDs(ctx, []int64{1, 2})
assert.Nil(t, err)
assert.Len(t, courseMap, 2)
assert.Equal(t, uint(1), courseMap[1].ID)
assert.Equal(t, "MARX1001", courseMap[1].Code)
assert.Equal(t, uint(2), courseMap[2].ID)
assert.Equal(t, "CS1500", courseMap[2].Code)
})
t.Run("not found", func(t *testing.T) {
courseMap, err := query.GetBaseCoursesByIDs(ctx, []int64{10, 20})
assert.Nil(t, err)
assert.Len(t, courseMap, 0)
})
}

func TestCourseQuery_GetCourse(t *testing.T) {
ctx := context.Background()
db := dal.GetDBClient()
query := NewCourseQuery(db)
t.Run("all", func(t *testing.T) {
courses, err := query.GetCourse(ctx)
assert.Nil(t, err)
assert.Len(t, courses, 4)
})
t.Run("filter course", func(t *testing.T) {
courses, err := query.GetCourse(ctx, WithCode("CS1500"))
assert.Nil(t, err)
assert.Len(t, courses, 1)
assert.Equal(t, "CS1500", courses[0].Code)
})
t.Run("none", func(t *testing.T) {
courses, err := query.GetCourse(ctx, WithCode("CS3500"))
assert.Nil(t, err)
assert.Len(t, courses, 0)
})
}

func TestCourseQuery_GetCourseCount(t *testing.T) {
ctx := context.Background()
db := dal.GetDBClient()
query := NewCourseQuery(db)
t.Run("all", func(t *testing.T) {
count, err := query.GetCourseCount(ctx)
assert.Nil(t, err)
assert.Equal(t, int64(4), count)
})
t.Run("filter course", func(t *testing.T) {
count, err := query.GetCourseCount(ctx, WithCode("CS1500"))
assert.Nil(t, err)
assert.Equal(t, int64(1), count)
})
t.Run("none", func(t *testing.T) {
count, err := query.GetCourseCount(ctx, WithCode("CS3500"))
assert.Nil(t, err)
assert.Equal(t, int64(0), count)
})
}

func TestCourseQuery_GetCourseByIDs(t *testing.T) {
ctx := context.Background()
db := dal.GetDBClient()
query := NewCourseQuery(db)
courseMap, err := query.GetCourseByIDs(ctx, []int64{1, 2})
assert.Nil(t, err)
assert.Len(t, courseMap, 2)
assert.Equal(t, uint(1), courseMap[1].ID)
assert.Equal(t, "MARX1001", courseMap[1].Code)
assert.Equal(t, uint(2), courseMap[2].ID)
assert.Equal(t, "CS1500", courseMap[2].Code)
assert.Equal(t, "MARX1001", courseMap[2].Code)
}

func TestCourseQuery_GetCourseCategories(t *testing.T) {
ctx := context.Background()
db := dal.GetDBClient()
query := NewCourseQuery(db)

t.Run("no category", func(t *testing.T) {
categories, err := query.GetCourseCategories(ctx, []int64{3})
assert.Nil(t, err)
assert.Len(t, categories, 1)
assert.Len(t, categories[3], 0)
})

t.Run("has category", func(t *testing.T) {
categories, err := query.GetCourseCategories(ctx, []int64{2})
assert.Nil(t, err)
assert.Len(t, categories, 1)
category := categories[2]
assert.Len(t, category, 2)
assert.Contains(t, category, "通识")
assert.Contains(t, category, "必修")
})
}
122 changes: 119 additions & 3 deletions repository/test_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,127 @@ func createTestBaseCourse(db *gorm.DB) error {
return err
}

func createTestTeacher(db *gorm.DB) error {
teachers := []po.TeacherPO{
{
Model: gorm.Model{ID: 1},
Code: "10001",
Name: "高女士",
Email: "[email protected]",
Department: "SEIEE",
Pinyin: "gaoxiaofeng",
PinyinAbbr: "gxf",
Title: "教授",
},
{
Model: gorm.Model{ID: 2},
Code: "10002",
Name: "潘老师",
Email: "[email protected]",
Department: "SEIEE",
Pinyin: "panli",
PinyinAbbr: "pl",
Title: "教授",
},
{
Model: gorm.Model{ID: 3},
Code: "10003",
Name: "梁女士",
Email: "[email protected]",
Department: "PHYSICS",
Pinyin: "liangqin",
PinyinAbbr: "lq",
Title: "教授",
},
{
Model: gorm.Model{ID: 4},
Code: "10004",
Name: "赵先生",
Email: "[email protected]",
Pinyin: "zhaohao",
PinyinAbbr: "zh",
Title: "讲师",
},
}
err := db.Create(&teachers).Error
return err
}

func createTestCourse(db *gorm.DB) error {
courses := []po.CoursePO{
{
Model: gorm.Model{ID: 1},
Code: "MARX1001",
Name: "思想道德修养与法律基础",
Credit: 3,
MainTeacherID: 3,
MainTeacherName: "梁女士",
Department: "MARX",
},
{
Model: gorm.Model{ID: 2},
Code: "MARX1001",
Name: "思想道德修养与法律基础",
Credit: 3,
MainTeacherID: 4,
MainTeacherName: "赵先生",
Department: "MARX",
},
{
Model: gorm.Model{ID: 3},
Code: "CS1500",
Name: "计算机科学导论",
Credit: 3,
MainTeacherID: 1,
MainTeacherName: "高女士",
Department: "SEIEE",
},
{
Model: gorm.Model{ID: 4},
Code: "CS2500",
Name: "算法与复杂性",
Credit: 3,
MainTeacherID: 1,
MainTeacherName: "高女士",
Department: "SEIEE",
},
}
err := db.Create(&courses).Error
return err
}

func createCourseCategories(db *gorm.DB) error {
categories := []po.CourseCategoryPO{
{
CourseID: 1,
Category: "通识",
},
{
CourseID: 2,
Category: "通识",
},
{
CourseID: 2,
Category: "必修",
},
}
err := db.Create(&categories).Error
return err
}

func CreateTestEnv(ctx context.Context, db *gorm.DB) error {
db = db.WithContext(ctx)
err := createTestBaseCourse(db)
if err != nil {
return err
createFunc := []func(db *gorm.DB) error{
createTestBaseCourse,
createTestTeacher,
createTestCourse,
createCourseCategories,
}
for _, fn := range createFunc {
err := fn(db)
if err != nil {
return err
}
}
return nil
}
8 changes: 4 additions & 4 deletions service/course.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func GetCourseDetail(ctx context.Context, courseID int64) (*model.CourseDetail,
if courseID == 0 {
return nil, errors.New("course id is 0")
}
courseQuery := repository.NewCourseQuery()
courseQuery := repository.NewCourseQuery(dal.GetDBClient())
coursePOs, err := courseQuery.GetCourse(ctx, repository.WithID(courseID))
if err != nil || len(coursePOs) == 0 {
return nil, err
Expand Down Expand Up @@ -78,7 +78,7 @@ func buildCourseDBOptionFromFilter(query repository.ICourseQuery, filter model.C
}

func GetCourseList(ctx context.Context, filter model.CourseListFilter) ([]model.CourseSummary, error) {
query := repository.NewCourseQuery()
query := repository.NewCourseQuery(dal.GetDBClient())
opts := buildCourseDBOptionFromFilter(query, filter)

coursePOs, err := query.GetCourse(ctx, opts...)
Expand Down Expand Up @@ -113,7 +113,7 @@ func GetCourseList(ctx context.Context, filter model.CourseListFilter) ([]model.
}

func GetCourseCount(ctx context.Context, filter model.CourseListFilter) (int64, error) {
query := repository.NewCourseQuery()
query := repository.NewCourseQuery(dal.GetDBClient())
filter.Page, filter.PageSize = 0, 0
opts := buildCourseDBOptionFromFilter(query, filter)
return query.GetCourseCount(ctx, opts...)
Expand All @@ -124,7 +124,7 @@ func GetCourseByIDs(ctx context.Context, courseIDs []int64) (map[int64]model.Cou
if len(courseIDs) == 0 {
return result, nil
}
courseQuery := repository.NewCourseQuery()
courseQuery := repository.NewCourseQuery(dal.GetDBClient())
courseMap, err := courseQuery.GetCourseByIDs(ctx, courseIDs)
if err != nil {
return nil, err
Expand Down

0 comments on commit a8263f8

Please sign in to comment.