Skip to content

Commit

Permalink
fix:
Browse files Browse the repository at this point in the history
  • Loading branch information
linyyyang committed May 26, 2023
1 parent 3dac1e1 commit c843642
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 53 deletions.
2 changes: 1 addition & 1 deletion constants/constants.go → gplus/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package constants
package gplus

const (
Comma = ","
Expand Down
3 changes: 1 addition & 2 deletions gplus/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"database/sql"
"reflect"

"github.com/goriller/gorm-plus/constants"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
Expand Down Expand Up @@ -341,7 +340,7 @@ func getPkColumnName[T any]() string {
}
}
if columnName == "" {
return constants.DefaultPrimaryName
return DefaultPrimaryName
}
return columnName
}
Expand Down
40 changes: 19 additions & 21 deletions gplus/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,84 +19,82 @@ package gplus

import (
"strings"

"github.com/goriller/gorm-plus/constants"
)

type Function struct {
funStr string
}

func (f *Function) As(asName any) string {
return f.funStr + " " + constants.As + " " + getColumnName(asName)
return f.funStr + " " + As + " " + getColumnName(asName)
}

func (f *Function) Eq(value int64) (string, int64) {
return buildFunStr(f.funStr, constants.Eq, value)
return buildFunStr(f.funStr, Eq, value)
}

func (f *Function) Ne(value int64) (string, int64) {
return buildFunStr(f.funStr, constants.Ne, value)
return buildFunStr(f.funStr, Ne, value)
}

func (f *Function) Gt(value int64) (string, int64) {
return buildFunStr(f.funStr, constants.Gt, value)
return buildFunStr(f.funStr, Gt, value)
}

func (f *Function) Ge(value int64) (string, int64) {
return buildFunStr(f.funStr, constants.Ge, value)
return buildFunStr(f.funStr, Ge, value)
}

func (f *Function) Lt(value int64) (string, int64) {
return buildFunStr(f.funStr, constants.Lt, value)
return buildFunStr(f.funStr, Lt, value)
}

func (f *Function) Le(value int64) (string, int64) {
return buildFunStr(f.funStr, constants.Le, value)
return buildFunStr(f.funStr, Le, value)
}

func (f *Function) In(values ...any) (string, []any) {
// 构建占位符
placeholder := buildPlaceholder(values)
return f.funStr + " " + constants.In + placeholder.String(), values
return f.funStr + " " + In + placeholder.String(), values
}

func (f *Function) NotIn(values ...any) (string, []any) {
// 构建占位符
placeholder := buildPlaceholder(values)
return f.funStr + " " + constants.Not + " " + constants.In + placeholder.String(), values
return f.funStr + " " + Not + " " + In + placeholder.String(), values
}

func (f *Function) Between(start int64, end int64) (string, int64, int64) {
return f.funStr + " " + constants.Between + " ? and ?", start, end
return f.funStr + " " + Between + " ? and ?", start, end
}

func (f *Function) NotBetween(start int64, end int64) (string, int64, int64) {
return f.funStr + " " + constants.Not + " " + constants.Between + " ? and ?", start, end
return f.funStr + " " + Not + " " + Between + " ? and ?", start, end
}

func Sum(columnName any) *Function {
return &Function{funStr: addBracket(constants.SUM, getColumnName(columnName))}
return &Function{funStr: addBracket(SUM, getColumnName(columnName))}
}

func Avg(columnName any) *Function {
return &Function{funStr: addBracket(constants.AVG, getColumnName(columnName))}
return &Function{funStr: addBracket(AVG, getColumnName(columnName))}
}

func Max(columnName any) *Function {
return &Function{funStr: addBracket(constants.MAX, getColumnName(columnName))}
return &Function{funStr: addBracket(MAX, getColumnName(columnName))}
}

func Min(columnName any) *Function {
return &Function{funStr: addBracket(constants.MIN, getColumnName(columnName))}
return &Function{funStr: addBracket(MIN, getColumnName(columnName))}
}

func Count(columnName any) *Function {
return &Function{funStr: addBracket(constants.COUNT, getColumnName(columnName))}
return &Function{funStr: addBracket(COUNT, getColumnName(columnName))}
}

func addBracket(function string, columnNameStr string) string {
return function + constants.LeftBracket + columnNameStr + constants.RightBracket
return function + LeftBracket + columnNameStr + RightBracket
}

func buildFunStr(funcStr string, typeStr string, value int64) (string, int64) {
Expand All @@ -105,11 +103,11 @@ func buildFunStr(funcStr string, typeStr string, value int64) (string, int64) {

func buildPlaceholder(values []any) strings.Builder {
var placeholder strings.Builder
placeholder.WriteString(constants.LeftBracket)
placeholder.WriteString(LeftBracket)
for i := 0; i < len(values); i++ {
if i == len(values)-1 {
placeholder.WriteString("?")
placeholder.WriteString(constants.RightBracket)
placeholder.WriteString(RightBracket)
break
}
placeholder.WriteString("?")
Expand Down
2 changes: 1 addition & 1 deletion constants/keyword.go → gplus/keyword.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package constants
package gplus

const (
And = "AND"
Expand Down
54 changes: 26 additions & 28 deletions gplus/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import (
"fmt"
"reflect"
"strings"

"github.com/goriller/gorm-plus/constants"
)

type QueryCond[T any] struct {
Expand Down Expand Up @@ -87,65 +85,65 @@ func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) {

// Eq 等于 =
func (q *QueryCond[T]) Eq(column any, val any) *QueryCond[T] {
q.addCond(column, val, constants.Eq)
q.addCond(column, val, Eq)
return q
}

// Ne 不等于 !=
func (q *QueryCond[T]) Ne(column any, val any) *QueryCond[T] {
q.addCond(column, val, constants.Ne)
q.addCond(column, val, Ne)
return q
}

// Gt 大于 >
func (q *QueryCond[T]) Gt(column any, val any) *QueryCond[T] {
q.addCond(column, val, constants.Gt)
q.addCond(column, val, Gt)
return q
}

// Ge 大于等于 >=
func (q *QueryCond[T]) Ge(column any, val any) *QueryCond[T] {
q.addCond(column, val, constants.Ge)
q.addCond(column, val, Ge)
return q
}

// Lt 小于 <
func (q *QueryCond[T]) Lt(column any, val any) *QueryCond[T] {
q.addCond(column, val, constants.Lt)
q.addCond(column, val, Lt)
return q
}

// Le 小于等于 <=
func (q *QueryCond[T]) Le(column any, val any) *QueryCond[T] {
q.addCond(column, val, constants.Le)
q.addCond(column, val, Le)
return q
}

// Like 模糊 LIKE '%值%'
func (q *QueryCond[T]) Like(column any, val any) *QueryCond[T] {
s := fmt.Sprintf("%v", val)
q.addCond(column, "%"+s+"%", constants.Like)
q.addCond(column, "%"+s+"%", Like)
return q
}

// NotLike 非模糊 NOT LIKE '%值%'
func (q *QueryCond[T]) NotLike(column any, val any) *QueryCond[T] {
s := fmt.Sprintf("%v", val)
q.addCond(column, "%"+s+"%", constants.Not+" "+constants.Like)
q.addCond(column, "%"+s+"%", Not+" "+Like)
return q
}

// LikeLeft 左模糊 LIKE '%值'
func (q *QueryCond[T]) LikeLeft(column any, val any) *QueryCond[T] {
s := fmt.Sprintf("%v", val)
q.addCond(column, "%"+s, constants.Like)
q.addCond(column, "%"+s, Like)
return q
}

// LikeRight 右模糊 LIKE '值%'
func (q *QueryCond[T]) LikeRight(column any, val any) *QueryCond[T] {
s := fmt.Sprintf("%v", val)
q.addCond(column, s+"%", constants.Like)
q.addCond(column, s+"%", Like)
return q
}

Expand All @@ -169,21 +167,21 @@ func (q *QueryCond[T]) IsNotNull(column any) *QueryCond[T] {

// In 字段 IN (值1, 值2, ...)
func (q *QueryCond[T]) In(column any, val any) *QueryCond[T] {
q.addCond(column, val, constants.In)
q.addCond(column, val, In)
return q
}

// NotIn 字段 NOT IN (值1, 值2, ...)
func (q *QueryCond[T]) NotIn(column any, val any) *QueryCond[T] {
q.addCond(column, val, constants.Not+" "+constants.In)
q.addCond(column, val, Not+" "+In)
return q
}

// Between BETWEEN 值1 AND 值2
func (q *QueryCond[T]) Between(column any, start, end any) *QueryCond[T] {
columnName := getColumnName(column)
q.buildAndIfNeed()
cond := fmt.Sprintf("%s %s ? and ? ", columnName, constants.Between)
cond := fmt.Sprintf("%s %s ? and ? ", columnName, Between)
q.queryBuilder.WriteString(cond)
q.queryArgs = append(q.queryArgs, start, end)
return q
Expand All @@ -193,7 +191,7 @@ func (q *QueryCond[T]) Between(column any, start, end any) *QueryCond[T] {
func (q *QueryCond[T]) NotBetween(column any, start, end any) *QueryCond[T] {
columnName := getColumnName(column)
q.buildAndIfNeed()
cond := fmt.Sprintf("%s %s %s ? and ? ", columnName, constants.Not, constants.Between)
cond := fmt.Sprintf("%s %s %s ? and ? ", columnName, Not, Between)
q.queryBuilder.WriteString(cond)
q.queryArgs = append(q.queryArgs, start, end)
return q
Expand All @@ -214,13 +212,13 @@ func (q *QueryCond[T]) And(fn ...func(q *QueryCond[T])) *QueryCond[T] {
if len(fn) > 0 {
nestQuery := &QueryCond[T]{}
fn[0](nestQuery)
q.andNestBuilder.WriteString(constants.And + " " + constants.LeftBracket + nestQuery.queryBuilder.String() + constants.RightBracket + " ")
q.andNestBuilder.WriteString(And + " " + LeftBracket + nestQuery.queryBuilder.String() + RightBracket + " ")
q.andNestArgs = append(q.andNestArgs, nestQuery.queryArgs...)
return q
}
q.queryBuilder.WriteString(constants.And)
q.queryBuilder.WriteString(And)
q.queryBuilder.WriteString(" ")
q.lastCond = constants.And
q.lastCond = And
return q
}

Expand All @@ -229,13 +227,13 @@ func (q *QueryCond[T]) Or(fn ...func(q *QueryCond[T])) *QueryCond[T] {
if len(fn) > 0 {
nestQuery := &QueryCond[T]{}
fn[0](nestQuery)
q.orNestBuilder.WriteString(constants.Or + " " + constants.LeftBracket + nestQuery.queryBuilder.String() + constants.RightBracket + " ")
q.orNestBuilder.WriteString(Or + " " + LeftBracket + nestQuery.queryBuilder.String() + RightBracket + " ")
q.orNestArgs = append(q.orNestArgs, nestQuery.queryArgs...)
return q
}
q.queryBuilder.WriteString(constants.Or)
q.queryBuilder.WriteString(Or)
q.queryBuilder.WriteString(" ")
q.lastCond = constants.Or
q.lastCond = Or
return q
}

Expand All @@ -255,7 +253,7 @@ func (q *QueryCond[T]) OrderByDesc(columns ...any) *QueryCond[T] {
columnName := getColumnName(v)
columnNames = append(columnNames, columnName)
}
q.buildOrder(constants.Desc, columnNames...)
q.buildOrder(Desc, columnNames...)
return q
}

Expand All @@ -266,7 +264,7 @@ func (q *QueryCond[T]) OrderByAsc(columns ...any) *QueryCond[T] {
columnName := getColumnName(v)
columnNames = append(columnNames, columnName)
}
q.buildOrder(constants.Asc, columnNames...)
q.buildOrder(Asc, columnNames...)
return q
}

Expand All @@ -275,7 +273,7 @@ func (q *QueryCond[T]) Group(columns ...any) *QueryCond[T] {
for _, v := range columns {
columnName := getColumnName(v)
if q.groupBuilder.Len() > 0 {
q.groupBuilder.WriteString(constants.Comma)
q.groupBuilder.WriteString(Comma)
}
q.groupBuilder.WriteString(columnName)
}
Expand Down Expand Up @@ -329,16 +327,16 @@ func (q *QueryCond[T]) addCond(column any, val any, condType string) {
}

func (q *QueryCond[T]) buildAndIfNeed() {
if q.lastCond != constants.And && q.lastCond != constants.Or && q.queryBuilder.Len() > 0 {
q.queryBuilder.WriteString(constants.And)
if q.lastCond != And && q.lastCond != Or && q.queryBuilder.Len() > 0 {
q.queryBuilder.WriteString(And)
q.queryBuilder.WriteString(" ")
}
}

func (q *QueryCond[T]) buildOrder(orderType string, columns ...string) {
for _, v := range columns {
if q.orderBuilder.Len() > 0 {
q.orderBuilder.WriteString(constants.Comma)
q.orderBuilder.WriteString(Comma)
}
q.orderBuilder.WriteString(v)
q.orderBuilder.WriteString(" ")
Expand Down

0 comments on commit c843642

Please sign in to comment.