Skip to content

Commit

Permalink
refactor: repo DI
Browse files Browse the repository at this point in the history
  • Loading branch information
dujiajun committed Sep 10, 2024
1 parent e797407 commit f0e6e9c
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 55 deletions.
8 changes: 4 additions & 4 deletions repository/course.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func (b *BaseCourseQuery) GetBaseCourseList(ctx context.Context, opts ...DBOptio
return coursePOs, nil
}

func NewBaseCourseQuery() IBaseCourseQuery {
return &BaseCourseQuery{db: dal.GetDBClient()}
func NewBaseCourseQuery(db *gorm.DB) IBaseCourseQuery {
return &BaseCourseQuery{db: db}
}

type ICourseQuery interface {
Expand Down Expand Up @@ -238,6 +238,6 @@ func (o *OfferedCourseQuery) GetMainTeacherIDsWithOfferedCourseIDs(ctx context.C
return mainTeacherIDs, nil
}

func NewOfferedCourseQuery() IOfferedCourseQuery {
return &OfferedCourseQuery{db: dal.GetDBClient()}
func NewOfferedCourseQuery(db *gorm.DB) IOfferedCourseQuery {
return &OfferedCourseQuery{db: db}
}
5 changes: 2 additions & 3 deletions repository/review.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"gorm.io/gorm"

"jcourse_go/dal"
"jcourse_go/model/po"
)

Expand Down Expand Up @@ -95,6 +94,6 @@ func (c *ReviewQuery) DeleteReview(ctx context.Context, opts ...DBOption) (int64
return result.RowsAffected, result.Error
}

func NewReviewQuery() IReviewQuery {
return &ReviewQuery{db: dal.GetDBClient()}
func NewReviewQuery(db *gorm.DB) IReviewQuery {
return &ReviewQuery{db: db}
}
5 changes: 2 additions & 3 deletions repository/teacher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"gorm.io/gorm"

"jcourse_go/dal"
"jcourse_go/model/po"
)

Expand All @@ -21,9 +20,9 @@ type TeacherQuery struct {
db *gorm.DB
}

func NewTeacherQuery() ITeacherQuery {
func NewTeacherQuery(db *gorm.DB) ITeacherQuery {
return &TeacherQuery{
db: dal.GetDBClient(),
db: db,
}
}

Expand Down
9 changes: 4 additions & 5 deletions repository/trainingplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package repository
import (
"context"

"jcourse_go/dal"
"jcourse_go/model/po"

"gorm.io/gorm"
Expand All @@ -21,8 +20,8 @@ type ITrainingPlanQuery interface {
GetTrainingPlanCount(ctx context.Context, opts ...DBOption) int64
}

func NewTrainingPlanQuery() ITrainingPlanQuery {
return &TrainingPlanQuery{db: dal.GetDBClient()}
func NewTrainingPlanQuery(db *gorm.DB) ITrainingPlanQuery {
return &TrainingPlanQuery{db: db}
}

func (t *TrainingPlanQuery) GetTrainingPlanListIDs(ctx context.Context, opts ...DBOption) ([]int64, error) {
Expand Down Expand Up @@ -76,8 +75,8 @@ type TrainingPlanCourseQuery struct {
db *gorm.DB
}

func NewTrainingPlanCourseQuery() ITrainingPlanCourseQuery {
return &TrainingPlanCourseQuery{db: dal.GetDBClient()}
func NewTrainingPlanCourseQuery(db *gorm.DB) ITrainingPlanCourseQuery {
return &TrainingPlanCourseQuery{db: db}
}

func (t *TrainingPlanCourseQuery) optionDB(ctx context.Context, opts ...DBOption) *gorm.DB {
Expand Down
9 changes: 4 additions & 5 deletions repository/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"gorm.io/gorm"

"jcourse_go/constant"
"jcourse_go/dal"
"jcourse_go/model/po"
)

Expand Down Expand Up @@ -68,13 +67,13 @@ func (u *UserProfileQuery) optionDB(ctx context.Context, opts ...DBOption) *gorm
return db
}

func NewUserProfileQuery() IUserProfileQuery {
return &UserProfileQuery{db: dal.GetDBClient()}
func NewUserProfileQuery(db *gorm.DB) IUserProfileQuery {
return &UserProfileQuery{db: db}
}

func NewUserQuery() IUserQuery {
func NewUserQuery(db *gorm.DB) IUserQuery {
return &UserQuery{
db: dal.GetDBClient(),
db: db,
}
}

Expand Down
7 changes: 4 additions & 3 deletions service/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/SJTU-jCourse/password_hasher"

"jcourse_go/constant"
"jcourse_go/dal"
"jcourse_go/model/converter"
"jcourse_go/model/domain"
"jcourse_go/repository"
Expand All @@ -22,7 +23,7 @@ func Login(ctx context.Context, email string, password string) (*domain.User, er
if err != nil {
return nil, err
}
query := repository.NewUserQuery()
query := repository.NewUserQuery(dal.GetDBClient())
userPO, err := query.GetUserDetail(ctx, repository.WithEmail(email), repository.WithPassword(passwordStore))
if err != nil {
return nil, err
Expand All @@ -39,7 +40,7 @@ func Register(ctx context.Context, email string, password string, code string) (
if storedCode != code {
return nil, errors.New("verify code is wrong")
}
query := repository.NewUserQuery()
query := repository.NewUserQuery(dal.GetDBClient())
userPO, err := query.GetUserDetail(ctx, repository.WithEmail(email))
if err != nil {
return nil, err
Expand Down Expand Up @@ -68,7 +69,7 @@ func ResetPassword(ctx context.Context, email string, password string, code stri
if storedCode != code {
return errors.New("verify code is wrong")
}
query := repository.NewUserQuery()
query := repository.NewUserQuery(dal.GetDBClient())
user, err := query.GetUserDetail(ctx, repository.WithEmail(email))
if err != nil {
return err
Expand Down
9 changes: 5 additions & 4 deletions service/course.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"

"jcourse_go/dal"
"jcourse_go/model/converter"
"jcourse_go/model/domain"
"jcourse_go/repository"
Expand All @@ -25,19 +26,19 @@ func GetCourseDetail(ctx context.Context, courseID int64) (*domain.Course, error
return nil, err
}

teacherQuery := repository.NewTeacherQuery()
teacherQuery := repository.NewTeacherQuery(dal.GetDBClient())
teacherPO, err := teacherQuery.GetTeacher(ctx, repository.WithID(coursePO.MainTeacherID))
if err != nil {
return nil, err
}

offeredCourseQuery := repository.NewOfferedCourseQuery()
offeredCourseQuery := repository.NewOfferedCourseQuery(dal.GetDBClient())
offeredCoursePOs, err := offeredCourseQuery.GetOfferedCourseList(ctx, repository.WithCourseID(courseID), repository.WithOrderBy("semester", false))
if err != nil {
return nil, err
}

reviewQuery := repository.NewReviewQuery()
reviewQuery := repository.NewReviewQuery(dal.GetDBClient())
infos, err := reviewQuery.GetCourseReviewInfo(ctx, []int64{courseID})
if err != nil {
return nil, err
Expand Down Expand Up @@ -93,7 +94,7 @@ func GetCourseList(ctx context.Context, filter domain.CourseListFilter) ([]domai
return nil, err
}

reviewQuery := repository.NewReviewQuery()
reviewQuery := repository.NewReviewQuery(dal.GetDBClient())
infos, err := reviewQuery.GetCourseReviewInfo(ctx, courseIDs)
if err != nil {
return nil, err
Expand Down
13 changes: 7 additions & 6 deletions service/review.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"

"jcourse_go/dal"
"jcourse_go/model/converter"
"jcourse_go/model/domain"
"jcourse_go/model/dto"
Expand Down Expand Up @@ -38,7 +39,7 @@ func buildReviewDBOptionFromFilter(query repository.IReviewQuery, filter domain.
}

func GetReviewList(ctx context.Context, filter domain.ReviewFilter) ([]domain.Review, error) {
reviewQuery := repository.NewReviewQuery()
reviewQuery := repository.NewReviewQuery(dal.GetDBClient())
opts := buildReviewDBOptionFromFilter(reviewQuery, filter)
reviewPOs, err := reviewQuery.GetReviewList(ctx, opts...)
if err != nil {
Expand Down Expand Up @@ -82,7 +83,7 @@ func GetReviewList(ctx context.Context, filter domain.ReviewFilter) ([]domain.Re
}

func GetReviewCount(ctx context.Context, filter domain.ReviewFilter) (int64, error) {
query := repository.NewReviewQuery()
query := repository.NewReviewQuery(dal.GetDBClient())
filter.Page, filter.PageSize = 0, 0
opts := buildReviewDBOptionFromFilter(query, filter)
return query.GetReviewCount(ctx, opts...)
Expand All @@ -92,7 +93,7 @@ func CreateReview(ctx context.Context, review dto.UpdateReviewDTO, user *domain.
if !validateReview(ctx, review, user) {
return 0, errors.New("validate review error")
}
query := repository.NewReviewQuery()
query := repository.NewReviewQuery(dal.GetDBClient())
reviewPO := converter.ConvertUpdateReviewDTOToPO(review, user.ID)
reviewID, err := query.CreateReview(ctx, reviewPO)
if err != nil {
Expand All @@ -108,7 +109,7 @@ func UpdateReview(ctx context.Context, review dto.UpdateReviewDTO, user *domain.
if !validateReview(ctx, review, user) {
return errors.New("validate review error")
}
query := repository.NewReviewQuery()
query := repository.NewReviewQuery(dal.GetDBClient())
reviewPO := converter.ConvertUpdateReviewDTOToPO(review, user.ID)
_, err := query.UpdateReview(ctx, reviewPO)
if err != nil {
Expand All @@ -118,7 +119,7 @@ func UpdateReview(ctx context.Context, review dto.UpdateReviewDTO, user *domain.
}

func DeleteReview(ctx context.Context, reviewID int64) error {
query := repository.NewReviewQuery()
query := repository.NewReviewQuery(dal.GetDBClient())
_, err := query.DeleteReview(ctx, repository.WithID(reviewID))
if err != nil {
return err
Expand All @@ -128,7 +129,7 @@ func DeleteReview(ctx context.Context, reviewID int64) error {

func validateReview(ctx context.Context, review dto.UpdateReviewDTO, user *domain.User) bool {
// 1. validate course and semester exists
offeredCourseQuery := repository.NewOfferedCourseQuery()
offeredCourseQuery := repository.NewOfferedCourseQuery(dal.GetDBClient())
offeredCourse, err := offeredCourseQuery.GetOfferedCourse(ctx, repository.WithCourseID(review.CourseID), repository.WithSemester(review.Semester))
if err != nil || offeredCourse == nil {
return false
Expand Down
15 changes: 8 additions & 7 deletions service/teacher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"

"jcourse_go/dal"
"jcourse_go/model/converter"
"jcourse_go/model/domain"
"jcourse_go/repository"
Expand All @@ -13,15 +14,15 @@ func GetTeacherDetail(ctx context.Context, teacherID int64) (*domain.Teacher, er
if teacherID == 0 {
return nil, errors.New("training-plan id is 0")
}
teacherQuery := repository.NewTeacherQuery()
teacherQuery := repository.NewTeacherQuery(dal.GetDBClient())

teacherPO, err := teacherQuery.GetTeacher(ctx, repository.WithID(teacherID))
if err != nil {
return nil, err
}
teacher := converter.ConvertTeacherPOToDomain(teacherPO)

courseQuery := repository.NewOfferedCourseQuery()
courseQuery := repository.NewOfferedCourseQuery(dal.GetDBClient())
courses, err := courseQuery.GetOfferedCourseList(ctx, repository.WithMainTeacherID(teacherID))
if err != nil {
return nil, err
Expand Down Expand Up @@ -60,10 +61,10 @@ func buildTeacherDBOptionFromFilter(query repository.ITeacherQuery, filter domai
}

func SearchTeacherList(ctx context.Context, filter domain.TeacherListFilter) ([]domain.Teacher, error) {
teacherQuery := repository.NewTeacherQuery()
teacherQuery := repository.NewTeacherQuery(dal.GetDBClient())
t_opts := buildTeacherDBOptionFromFilter(teacherQuery, filter)

teacherCourseQuery := repository.NewOfferedCourseQuery()
teacherCourseQuery := repository.NewOfferedCourseQuery(dal.GetDBClient())
validTeacherIDs, err := teacherCourseQuery.GetMainTeacherIDsWithOfferedCourseIDs(ctx, filter.ContainCourseIDs)
if err != nil {
return nil, err
Expand All @@ -77,7 +78,7 @@ func SearchTeacherList(ctx context.Context, filter domain.TeacherListFilter) ([]

domainTeachers := make([]domain.Teacher, 0)
for _, t := range teachers {
q := repository.NewOfferedCourseQuery()
q := repository.NewOfferedCourseQuery(dal.GetDBClient())
offeredCoursePOs, err := q.GetOfferedCourseList(ctx, repository.WithMainTeacherID(int64(t.ID)))
if err != nil {
return nil, err
Expand All @@ -90,15 +91,15 @@ func SearchTeacherList(ctx context.Context, filter domain.TeacherListFilter) ([]
}

func GetTeacherCount(ctx context.Context, filter domain.TeacherListFilter) (int64, error) {
query := repository.NewTeacherQuery()
query := repository.NewTeacherQuery(dal.GetDBClient())
filter.Page, filter.PageSize = 0, 0
opts := buildTeacherDBOptionFromFilter(query, filter)
return query.GetTeacherCount(ctx, opts...)
}

func GetTeacherListByIDs(ctx context.Context, teacherIDs []int64) (map[int64]domain.Teacher, error) {

teacherQuery := repository.NewTeacherQuery()
teacherQuery := repository.NewTeacherQuery(dal.GetDBClient())
teachers, err := teacherQuery.GetTeacherList(ctx, repository.WithIDs(teacherIDs))
if err != nil {
return nil, err
Expand Down
13 changes: 7 additions & 6 deletions service/trainingplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"

"jcourse_go/dal"
"jcourse_go/model/converter"
"jcourse_go/model/domain"
"jcourse_go/repository"
Expand All @@ -13,17 +14,17 @@ func GetTrainingPlanDetail(ctx context.Context, trainingPlanID int64) (*domain.T
if trainingPlanID == 0 {
return nil, errors.New("training-plan id is 0")
}
trainingPlanQuery := repository.NewTrainingPlanQuery()
trainingPlanQuery := repository.NewTrainingPlanQuery(dal.GetDBClient())

trainingPlanPO, err := trainingPlanQuery.GetTrainingPlan(ctx, repository.WithID(trainingPlanID))
if err != nil {
return nil, err
}
trainingPlan := converter.ConvertTrainingPlanPOToDomain(*trainingPlanPO)

courseQuery := repository.NewTrainingPlanCourseQuery()
courseQuery := repository.NewTrainingPlanCourseQuery(dal.GetDBClient())
courses, err := courseQuery.GetCourseListOfTrainingPlan(ctx, trainingPlanID)
baseCourseQuery := repository.NewBaseCourseQuery()
baseCourseQuery := repository.NewBaseCourseQuery(dal.GetDBClient())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -65,15 +66,15 @@ func buildTrainingPlanCourseDBOptionFromFilter(query repository.ITrainingPlanCou
return opts
}
func GetTrainingPlanCount(ctx context.Context, filter domain.TrainingPlanFilter) int64 {
trainingPlanQuery := repository.NewTrainingPlanQuery()
trainingPlanQuery := repository.NewTrainingPlanQuery(dal.GetDBClient())
opts := buildTrainingPlanDBOptionFromFilter(trainingPlanQuery, filter)
return trainingPlanQuery.GetTrainingPlanCount(ctx, opts...)
}
func SearchTrainingPlanList(ctx context.Context, filter domain.TrainingPlanFilter) ([]domain.TrainingPlanDetail, error) {

trainingPlanQuery := repository.NewTrainingPlanQuery()
trainingPlanQuery := repository.NewTrainingPlanQuery(dal.GetDBClient())
tp_opts := buildTrainingPlanDBOptionFromFilter(trainingPlanQuery, filter)
trainingPlanCourseQuery := repository.NewTrainingPlanCourseQuery()
trainingPlanCourseQuery := repository.NewTrainingPlanCourseQuery(dal.GetDBClient())
if len(filter.ContainCourseIDs) != 0 {
tpc_opts := buildTrainingPlanCourseDBOptionFromFilter(trainingPlanCourseQuery, filter)
validTrainingPlanIDs, err := trainingPlanCourseQuery.GetTrainingPlanListIDs(ctx, tpc_opts...)
Expand Down
Loading

0 comments on commit f0e6e9c

Please sign in to comment.