Skip to content

Commit 090ec65

Browse files
committed
reflect comments
1 parent 6ab7905 commit 090ec65

File tree

8 files changed

+138
-59
lines changed

8 files changed

+138
-59
lines changed

dialect/mysql/select_test.go

+18-4
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,30 @@ func TestSelect(t *testing.T) {
3030
sm.Where(mysql.Quote("id").In(mysql.Arg(100, 200, 300))),
3131
),
3232
},
33-
"select with case": {
34-
ExpectedSQL: "SELECT id, name, CASE WHEN (`id` = '1') THEN 'A' ELSE 'B' END FROM users",
33+
"case with else": {
34+
ExpectedSQL: "SELECT id, name, (CASE WHEN (`id` = '1') THEN 'A' ELSE 'B' END) AS `C` FROM users",
3535
Query: mysql.Select(
3636
sm.Columns(
3737
"id",
3838
"name",
3939
mysql.Case().
4040
When(mysql.Quote("id").EQ(mysql.S("1")), mysql.S("A")).
41-
Else(mysql.S("B")),
42-
// as
41+
Else(mysql.S("B")).
42+
As("C"),
43+
),
44+
sm.From("users"),
45+
),
46+
},
47+
"case without else": {
48+
ExpectedSQL: "SELECT id, name, (CASE WHEN (`id` = '1') THEN 'A' END) AS `C` FROM users",
49+
Query: mysql.Select(
50+
sm.Columns(
51+
"id",
52+
"name",
53+
mysql.Case().
54+
When(mysql.Quote("id").EQ(mysql.S("1")), mysql.S("A")).
55+
End().
56+
As("C"),
4357
),
4458
sm.From("users"),
4559
),

dialect/mysql/starters.go

+4-28
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package mysql
22

33
import (
4-
"context"
5-
"io"
6-
74
"github.com/stephenafamo/bob"
85
"github.com/stephenafamo/bob/dialect/mysql/dialect"
96
"github.com/stephenafamo/bob/expr"
@@ -104,29 +101,8 @@ func Cast(exp bob.Expression, typname string) Expression {
104101
return bmod.Cast(exp, typname)
105102
}
106103

107-
type CaseChain[T bob.Expression] func() expr.Case[T]
108-
109-
func (c CaseChain[T]) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
110-
return c().WriteSQL(ctx, w, d, start)
104+
// SQL: CASE WHEN a THEN b ELSE c END
105+
// Go: mysql.Case().When("a", "b").Else("c")
106+
func Case() expr.CaseChain[Expression, Expression] {
107+
return expr.NewCase[Expression, Expression]()
111108
}
112-
113-
func Case() CaseChain[Expression] {
114-
return CaseChain[Expression](func() expr.Case[Expression] { return expr.Case[Expression]{} })
115-
}
116-
117-
func (c CaseChain[T]) When(condition, then T) CaseChain[T] {
118-
cExpr := c()
119-
cExpr.Whens = append(cExpr.Whens, expr.When{Condition: condition, Then: then})
120-
return CaseChain[T](func() expr.Case[T] { return cExpr })
121-
}
122-
123-
func (c CaseChain[T]) Else(then T) Expression {
124-
cExpr := c()
125-
cExpr.Else = then
126-
var e dialect.Expression
127-
return e.New(cExpr)
128-
}
129-
130-
// func (c CaseChain[T]) As(alias string) T {
131-
// return bmod.X(c())
132-
// }

dialect/psql/select_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,34 @@ func TestSelect(t *testing.T) {
2929
sm.Where(psql.Quote("id").In(psql.Arg(100, 200, 300))),
3030
),
3131
},
32+
"case with else": {
33+
ExpectedSQL: `SELECT id, name, (CASE WHEN (id = '1') THEN 'A' ELSE 'B' END) AS "C" FROM users`,
34+
Query: psql.Select(
35+
sm.Columns(
36+
"id",
37+
"name",
38+
psql.Case().
39+
When(psql.Quote("id").EQ(psql.S("1")), psql.S("A")).
40+
Else(psql.S("B")).
41+
As("C"),
42+
),
43+
sm.From("users"),
44+
),
45+
},
46+
"case without else": {
47+
ExpectedSQL: `SELECT id, name, (CASE WHEN (id = '1') THEN 'A' END) AS "C" FROM users`,
48+
Query: psql.Select(
49+
sm.Columns(
50+
"id",
51+
"name",
52+
psql.Case().
53+
When(psql.Quote("id").EQ(psql.S("1")), psql.S("A")).
54+
End().
55+
As("C"),
56+
),
57+
sm.From("users"),
58+
),
59+
},
3260
"select distinct": {
3361
ExpectedSQL: "SELECT DISTINCT id, name FROM users WHERE (id IN ($1, $2, $3))",
3462
ExpectedArgs: []any{100, 200, 300},

dialect/psql/starters.go

+6
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,9 @@ func Raw(query string, args ...any) Expression {
100100
func Cast(exp bob.Expression, typname string) Expression {
101101
return bmod.Cast(exp, typname)
102102
}
103+
104+
// SQL: CASE WHEN a THEN b ELSE c END
105+
// Go: psql.Case().When("a", "b").Else("c")
106+
func Case() expr.CaseChain[Expression, Expression] {
107+
return expr.NewCase[Expression, Expression]()
108+
}

dialect/sqlite/select_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,34 @@ func TestSelect(t *testing.T) {
3030
sm.Where(sqlite.Quote("id").In(sqlite.Arg(100, 200, 300))),
3131
),
3232
},
33+
"case with else": {
34+
ExpectedSQL: `SELECT id, name, (CASE WHEN ("id" = '1') THEN 'A' ELSE 'B' END) AS "C" FROM users`,
35+
Query: sqlite.Select(
36+
sm.Columns(
37+
"id",
38+
"name",
39+
sqlite.Case().
40+
When(sqlite.Quote("id").EQ(sqlite.S("1")), sqlite.S("A")).
41+
Else(sqlite.S("B")).
42+
As("C"),
43+
),
44+
sm.From("users"),
45+
),
46+
},
47+
"case without else": {
48+
ExpectedSQL: `SELECT id, name, (CASE WHEN ("id" = '1') THEN 'A' END) AS "C" FROM users`,
49+
Query: sqlite.Select(
50+
sm.Columns(
51+
"id",
52+
"name",
53+
sqlite.Case().
54+
When(sqlite.Quote("id").EQ(sqlite.S("1")), sqlite.S("A")).
55+
End().
56+
As("C"),
57+
),
58+
sm.From("users"),
59+
),
60+
},
3361
"select distinct": {
3462
ExpectedSQL: `SELECT DISTINCT id, name FROM users WHERE ("id" IN (?1, ?2, ?3))`,
3563
ExpectedArgs: []any{100, 200, 300},

dialect/sqlite/starters.go

+6
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,9 @@ func Raw(query string, args ...any) Expression {
100100
func Cast(exp bob.Expression, typname string) Expression {
101101
return bmod.Cast(exp, typname)
102102
}
103+
104+
// SQL: CASE WHEN a THEN b ELSE c END
105+
// Go: sqlite.Case().When("a", "b").Else("c")
106+
func Case() expr.CaseChain[Expression, Expression] {
107+
return expr.NewCase[Expression, Expression]()
108+
}

expr/builder.go

-4
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,3 @@ func (e Builder[T, B]) Quote(aa ...string) T {
107107
func (e Builder[T, B]) Cast(exp bob.Expression, typname string) T {
108108
return X[T, B](Cast(exp, typname))
109109
}
110-
111-
func (e Builder[T, B]) Case(c Case[bob.Expression]) T {
112-
return X[T, B](c)
113-
}

expr/case.go

+48-23
Original file line numberDiff line numberDiff line change
@@ -8,51 +8,76 @@ import (
88
"github.com/stephenafamo/bob"
99
)
1010

11-
type Case[T bob.Expression] struct {
12-
Whens []When
13-
Else bob.Expression
14-
}
15-
16-
type When struct {
17-
Condition bob.Expression
18-
Then bob.Expression
19-
}
11+
type (
12+
caseExpr struct {
13+
whens []when
14+
elseExpr bob.Expression
15+
}
16+
when struct {
17+
condition bob.Expression
18+
then bob.Expression
19+
}
20+
)
2021

21-
func (c Case[T]) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
22+
func (c caseExpr) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
2223
var args []any
2324

24-
if c.Else == nil && len(c.Whens) == 0 {
25+
if len(c.whens) == 0 {
2526
return nil, errors.New("case must have at least one when expression")
2627
}
2728

28-
io.WriteString(w, "CASE")
29-
for _, when := range c.Whens {
30-
io.WriteString(w, " WHEN ")
31-
whenArgs, err := when.Condition.WriteSQL(ctx, w, d, start+len(args))
29+
w.Write([]byte("CASE"))
30+
for _, when := range c.whens {
31+
w.Write([]byte(" WHEN "))
32+
whenArgs, err := when.condition.WriteSQL(ctx, w, d, start+len(args))
3233
if err != nil {
3334
return nil, err
3435
}
3536
args = append(args, whenArgs...)
3637

37-
io.WriteString(w, " THEN ")
38-
thenArgs, err := when.Then.WriteSQL(ctx, w, d, start+len(args))
38+
w.Write([]byte(" THEN "))
39+
thenArgs, err := when.then.WriteSQL(ctx, w, d, start+len(args))
3940
if err != nil {
4041
return nil, err
4142
}
4243
args = append(args, thenArgs...)
4344
}
4445

45-
if c.Else != nil {
46-
io.WriteString(w, " ELSE ")
47-
elseArgs, err := c.Else.WriteSQL(ctx, w, d, start+len(args))
46+
if c.elseExpr != nil {
47+
w.Write([]byte(" ELSE "))
48+
elseArgs, err := c.elseExpr.WriteSQL(ctx, w, d, start+len(args))
4849
if err != nil {
4950
return nil, err
5051
}
5152
args = append(args, elseArgs...)
5253
}
53-
io.WriteString(w, " END")
54-
55-
// as
54+
w.Write([]byte(" END"))
5655

5756
return args, nil
5857
}
58+
59+
type CaseChain[T bob.Expression, B builder[T]] func() caseExpr
60+
61+
func NewCase[T bob.Expression, B builder[T]]() CaseChain[T, B] {
62+
return CaseChain[T, B](func() caseExpr { return caseExpr{} })
63+
}
64+
65+
func (cc CaseChain[T, B]) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
66+
return cc().WriteSQL(ctx, w, d, start)
67+
}
68+
69+
func (cc CaseChain[T, B]) When(condition, then bob.Expression) CaseChain[T, B] {
70+
c := cc()
71+
c.whens = append(c.whens, when{condition: condition, then: then})
72+
return CaseChain[T, B](func() caseExpr { return c })
73+
}
74+
75+
func (cc CaseChain[T, B]) Else(then bob.Expression) T {
76+
c := cc()
77+
c.elseExpr = then
78+
return X[T, B](c)
79+
}
80+
81+
func (cc CaseChain[T, B]) End() T {
82+
return X[T, B](cc())
83+
}

0 commit comments

Comments
 (0)