diff --git a/repository/course.go b/repository/course.go index bd2009f..8476de3 100644 --- a/repository/course.go +++ b/repository/course.go @@ -3,11 +3,9 @@ package repository import ( "context" "errors" - "fmt" "gorm.io/gorm" - "jcourse_go/dal" "jcourse_go/model/po" ) @@ -101,13 +99,18 @@ func (c *CourseQuery) optionDB(ctx context.Context, opts ...DBOption) *gorm.DB { } func (c *CourseQuery) GetCourseCategories(ctx context.Context, courseIDs []int64) (map[int64][]string, error) { + courseCategoryMap := make(map[int64][]string) + for _, id := range courseIDs { + courseCategoryMap[id] = make([]string, 0) + } + db := c.db.WithContext(ctx).Model(po.CourseCategoryPO{}) courseCategoryPOs := make([]po.CourseCategoryPO, 0) result := db.Where("course_id in ?", courseIDs).Find(&courseCategoryPOs) if result.Error != nil { - return nil, result.Error + return courseCategoryMap, result.Error } - courseCategoryMap := make(map[int64][]string) + for _, courseCategoryPO := range courseCategoryPOs { categories, ok := courseCategoryMap[courseCategoryPO.CourseID] if !ok { @@ -122,9 +125,12 @@ func (c *CourseQuery) GetCourseCategories(ctx context.Context, courseIDs []int64 func (c *CourseQuery) GetCourse(ctx context.Context, opts ...DBOption) ([]po.CoursePO, error) { db := c.optionDB(ctx, opts...) coursePOs := make([]po.CoursePO, 0) - result := db.Find(&coursePOs) - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, nil + err := db.Find(&coursePOs).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return coursePOs, nil + } + if err != nil { + return coursePOs, err } return coursePOs, nil } @@ -139,8 +145,8 @@ func (c *CourseQuery) GetCourseCount(ctx context.Context, opts ...DBOption) (int return count, nil } -func NewCourseQuery() ICourseQuery { - return &CourseQuery{db: dal.GetDBClient()} +func NewCourseQuery(db *gorm.DB) ICourseQuery { + return &CourseQuery{db: db} } type IOfferedCourseQuery interface { @@ -154,17 +160,6 @@ type OfferedCourseQuery struct { db *gorm.DB } -func (o *OfferedCourseQuery) WithOrderBy(field string, ascending bool) DBOption { - return func(db *gorm.DB) *gorm.DB { - if ascending { - field = fmt.Sprintf("%s %s", field, "asc") - } else { - field = fmt.Sprintf("%s %s", field, "desc") - } - return db.Order(field) - } -} - func (o *OfferedCourseQuery) GetOfferedCourseTeacherGroup(ctx context.Context, offeredCourseIDs []int64) (map[int64][]po.TeacherPO, error) { db := o.db.WithContext(ctx).Model(&po.OfferedCourseTeacherPO{}) courseTeacherPOs := make([]po.OfferedCourseTeacherPO, 0) diff --git a/repository/course_test.go b/repository/course_test.go index 5cdcdc5..86c29f7 100644 --- a/repository/course_test.go +++ b/repository/course_test.go @@ -57,15 +57,101 @@ func TestBaseCourseQuery_GetBaseCourseCount(t *testing.T) { }) } -func TestCourseQuery_GetCourseByIDs(t *testing.T) { +func TestBaseCourseQuery_GetCourseByIDs(t *testing.T) { ctx := context.Background() db := dal.GetDBClient() query := NewBaseCourseQuery(db) - courseMap, err := query.GetBaseCoursesByIDs(ctx, []int64{1, 2}) + t.Run("all", func(t *testing.T) { + courseMap, err := query.GetBaseCoursesByIDs(ctx, []int64{1, 2}) + assert.Nil(t, err) + assert.Len(t, courseMap, 2) + assert.Equal(t, uint(1), courseMap[1].ID) + assert.Equal(t, "MARX1001", courseMap[1].Code) + assert.Equal(t, uint(2), courseMap[2].ID) + assert.Equal(t, "CS1500", courseMap[2].Code) + }) + t.Run("not found", func(t *testing.T) { + courseMap, err := query.GetBaseCoursesByIDs(ctx, []int64{10, 20}) + assert.Nil(t, err) + assert.Len(t, courseMap, 0) + }) +} + +func TestCourseQuery_GetCourse(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewCourseQuery(db) + t.Run("all", func(t *testing.T) { + courses, err := query.GetCourse(ctx) + assert.Nil(t, err) + assert.Len(t, courses, 4) + }) + t.Run("filter course", func(t *testing.T) { + courses, err := query.GetCourse(ctx, WithCode("CS1500")) + assert.Nil(t, err) + assert.Len(t, courses, 1) + assert.Equal(t, "CS1500", courses[0].Code) + }) + t.Run("none", func(t *testing.T) { + courses, err := query.GetCourse(ctx, WithCode("CS3500")) + assert.Nil(t, err) + assert.Len(t, courses, 0) + }) +} + +func TestCourseQuery_GetCourseCount(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewCourseQuery(db) + t.Run("all", func(t *testing.T) { + count, err := query.GetCourseCount(ctx) + assert.Nil(t, err) + assert.Equal(t, int64(4), count) + }) + t.Run("filter course", func(t *testing.T) { + count, err := query.GetCourseCount(ctx, WithCode("CS1500")) + assert.Nil(t, err) + assert.Equal(t, int64(1), count) + }) + t.Run("none", func(t *testing.T) { + count, err := query.GetCourseCount(ctx, WithCode("CS3500")) + assert.Nil(t, err) + assert.Equal(t, int64(0), count) + }) +} + +func TestCourseQuery_GetCourseByIDs(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewCourseQuery(db) + courseMap, err := query.GetCourseByIDs(ctx, []int64{1, 2}) assert.Nil(t, err) assert.Len(t, courseMap, 2) assert.Equal(t, uint(1), courseMap[1].ID) assert.Equal(t, "MARX1001", courseMap[1].Code) assert.Equal(t, uint(2), courseMap[2].ID) - assert.Equal(t, "CS1500", courseMap[2].Code) + assert.Equal(t, "MARX1001", courseMap[2].Code) +} + +func TestCourseQuery_GetCourseCategories(t *testing.T) { + ctx := context.Background() + db := dal.GetDBClient() + query := NewCourseQuery(db) + + t.Run("no category", func(t *testing.T) { + categories, err := query.GetCourseCategories(ctx, []int64{3}) + assert.Nil(t, err) + assert.Len(t, categories, 1) + assert.Len(t, categories[3], 0) + }) + + t.Run("has category", func(t *testing.T) { + categories, err := query.GetCourseCategories(ctx, []int64{2}) + assert.Nil(t, err) + assert.Len(t, categories, 1) + category := categories[2] + assert.Len(t, category, 2) + assert.Contains(t, category, "通识") + assert.Contains(t, category, "必修") + }) } diff --git a/repository/test_env.go b/repository/test_env.go index 4d32a03..460a56e 100644 --- a/repository/test_env.go +++ b/repository/test_env.go @@ -32,11 +32,127 @@ func createTestBaseCourse(db *gorm.DB) error { return err } +func createTestTeacher(db *gorm.DB) error { + teachers := []po.TeacherPO{ + { + Model: gorm.Model{ID: 1}, + Code: "10001", + Name: "高女士", + Email: "gaoxiaofeng@example.com", + Department: "SEIEE", + Pinyin: "gaoxiaofeng", + PinyinAbbr: "gxf", + Title: "教授", + }, + { + Model: gorm.Model{ID: 2}, + Code: "10002", + Name: "潘老师", + Email: "panli@example.com", + Department: "SEIEE", + Pinyin: "panli", + PinyinAbbr: "pl", + Title: "教授", + }, + { + Model: gorm.Model{ID: 3}, + Code: "10003", + Name: "梁女士", + Email: "liangqin@example.com", + Department: "PHYSICS", + Pinyin: "liangqin", + PinyinAbbr: "lq", + Title: "教授", + }, + { + Model: gorm.Model{ID: 4}, + Code: "10004", + Name: "赵先生", + Email: "zhaohao@example.com", + Pinyin: "zhaohao", + PinyinAbbr: "zh", + Title: "讲师", + }, + } + err := db.Create(&teachers).Error + return err +} + +func createTestCourse(db *gorm.DB) error { + courses := []po.CoursePO{ + { + Model: gorm.Model{ID: 1}, + Code: "MARX1001", + Name: "思想道德修养与法律基础", + Credit: 3, + MainTeacherID: 3, + MainTeacherName: "梁女士", + Department: "MARX", + }, + { + Model: gorm.Model{ID: 2}, + Code: "MARX1001", + Name: "思想道德修养与法律基础", + Credit: 3, + MainTeacherID: 4, + MainTeacherName: "赵先生", + Department: "MARX", + }, + { + Model: gorm.Model{ID: 3}, + Code: "CS1500", + Name: "计算机科学导论", + Credit: 3, + MainTeacherID: 1, + MainTeacherName: "高女士", + Department: "SEIEE", + }, + { + Model: gorm.Model{ID: 4}, + Code: "CS2500", + Name: "算法与复杂性", + Credit: 3, + MainTeacherID: 1, + MainTeacherName: "高女士", + Department: "SEIEE", + }, + } + err := db.Create(&courses).Error + return err +} + +func createCourseCategories(db *gorm.DB) error { + categories := []po.CourseCategoryPO{ + { + CourseID: 1, + Category: "通识", + }, + { + CourseID: 2, + Category: "通识", + }, + { + CourseID: 2, + Category: "必修", + }, + } + err := db.Create(&categories).Error + return err +} + func CreateTestEnv(ctx context.Context, db *gorm.DB) error { db = db.WithContext(ctx) - err := createTestBaseCourse(db) - if err != nil { - return err + createFunc := []func(db *gorm.DB) error{ + createTestBaseCourse, + createTestTeacher, + createTestCourse, + createCourseCategories, + } + for _, fn := range createFunc { + err := fn(db) + if err != nil { + return err + } } return nil } diff --git a/service/course.go b/service/course.go index 1426ab0..e4d108f 100644 --- a/service/course.go +++ b/service/course.go @@ -15,7 +15,7 @@ func GetCourseDetail(ctx context.Context, courseID int64) (*model.CourseDetail, if courseID == 0 { return nil, errors.New("course id is 0") } - courseQuery := repository.NewCourseQuery() + courseQuery := repository.NewCourseQuery(dal.GetDBClient()) coursePOs, err := courseQuery.GetCourse(ctx, repository.WithID(courseID)) if err != nil || len(coursePOs) == 0 { return nil, err @@ -78,7 +78,7 @@ func buildCourseDBOptionFromFilter(query repository.ICourseQuery, filter model.C } func GetCourseList(ctx context.Context, filter model.CourseListFilter) ([]model.CourseSummary, error) { - query := repository.NewCourseQuery() + query := repository.NewCourseQuery(dal.GetDBClient()) opts := buildCourseDBOptionFromFilter(query, filter) coursePOs, err := query.GetCourse(ctx, opts...) @@ -113,7 +113,7 @@ func GetCourseList(ctx context.Context, filter model.CourseListFilter) ([]model. } func GetCourseCount(ctx context.Context, filter model.CourseListFilter) (int64, error) { - query := repository.NewCourseQuery() + query := repository.NewCourseQuery(dal.GetDBClient()) filter.Page, filter.PageSize = 0, 0 opts := buildCourseDBOptionFromFilter(query, filter) return query.GetCourseCount(ctx, opts...) @@ -124,7 +124,7 @@ func GetCourseByIDs(ctx context.Context, courseIDs []int64) (map[int64]model.Cou if len(courseIDs) == 0 { return result, nil } - courseQuery := repository.NewCourseQuery() + courseQuery := repository.NewCourseQuery(dal.GetDBClient()) courseMap, err := courseQuery.GetCourseByIDs(ctx, courseIDs) if err != nil { return nil, err