Skip to content

Commit 1d6c532

Browse files
authored
Merge pull request #294 from k4n4ry/issue-293
Include CASE WHEN expression starter
2 parents e949725 + 416ae43 commit 1d6c532

File tree

8 files changed

+186
-0
lines changed

8 files changed

+186
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5656
This is how `AfterSeleect/Insert/Update/DeleteHooks` hooks are now implemented.
5757
- Added `Type() QueryType` method to `bob.Query` to get the type of query it is. Available constants are `Unknown, Select, Insert, Update, Delete`.
5858
- Postgres and SQLite Update/Delete queries now refresh the models after the query is executed. This is enabled by the `RETURNING` clause, so it is not available in MySQL.
59+
- Added the `Case()` starter to all dialects to build `CASE` expressions. (thanks @k4n4ry)
5960

6061
### Changed
6162

dialect/mysql/select_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,34 @@ func TestSelect(t *testing.T) {
3030
sm.Where(mysql.Quote("id").In(mysql.Arg(100, 200, 300))),
3131
),
3232
},
33+
"case with else": {
34+
ExpectedSQL: "SELECT id, name, (CASE WHEN (`id` = '1') THEN 'A' ELSE 'B' END) AS `C` FROM users",
35+
Query: mysql.Select(
36+
sm.Columns(
37+
"id",
38+
"name",
39+
mysql.Case().
40+
When(mysql.Quote("id").EQ(mysql.S("1")), mysql.S("A")).
41+
Else(mysql.S("B")).
42+
As("C"),
43+
),
44+
sm.From("users"),
45+
),
46+
},
47+
"case without else": {
48+
ExpectedSQL: "SELECT id, name, (CASE WHEN (`id` = '1') THEN 'A' END) AS `C` FROM users",
49+
Query: mysql.Select(
50+
sm.Columns(
51+
"id",
52+
"name",
53+
mysql.Case().
54+
When(mysql.Quote("id").EQ(mysql.S("1")), mysql.S("A")).
55+
End().
56+
As("C"),
57+
),
58+
sm.From("users"),
59+
),
60+
},
3361
"select distinct": {
3462
ExpectedSQL: "SELECT DISTINCT id, name FROM users WHERE (`id` IN (?, ?, ?))",
3563
ExpectedArgs: []any{100, 200, 300},

dialect/mysql/starters.go

+6
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,9 @@ func Raw(query string, args ...any) Expression {
100100
func Cast(exp bob.Expression, typname string) Expression {
101101
return bmod.Cast(exp, typname)
102102
}
103+
104+
// SQL: CASE WHEN a THEN b ELSE c END
105+
// Go: mysql.Case().When("a", "b").Else("c")
106+
func Case() expr.CaseChain[Expression, Expression] {
107+
return expr.NewCase[Expression, Expression]()
108+
}

dialect/psql/select_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,34 @@ func TestSelect(t *testing.T) {
2929
sm.Where(psql.Quote("id").In(psql.Arg(100, 200, 300))),
3030
),
3131
},
32+
"case with else": {
33+
ExpectedSQL: `SELECT id, name, (CASE WHEN (id = '1') THEN 'A' ELSE 'B' END) AS "C" FROM users`,
34+
Query: psql.Select(
35+
sm.Columns(
36+
"id",
37+
"name",
38+
psql.Case().
39+
When(psql.Quote("id").EQ(psql.S("1")), psql.S("A")).
40+
Else(psql.S("B")).
41+
As("C"),
42+
),
43+
sm.From("users"),
44+
),
45+
},
46+
"case without else": {
47+
ExpectedSQL: `SELECT id, name, (CASE WHEN (id = '1') THEN 'A' END) AS "C" FROM users`,
48+
Query: psql.Select(
49+
sm.Columns(
50+
"id",
51+
"name",
52+
psql.Case().
53+
When(psql.Quote("id").EQ(psql.S("1")), psql.S("A")).
54+
End().
55+
As("C"),
56+
),
57+
sm.From("users"),
58+
),
59+
},
3260
"select distinct": {
3361
ExpectedSQL: "SELECT DISTINCT id, name FROM users WHERE (id IN ($1, $2, $3))",
3462
ExpectedArgs: []any{100, 200, 300},

dialect/psql/starters.go

+6
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,9 @@ func Raw(query string, args ...any) Expression {
100100
func Cast(exp bob.Expression, typname string) Expression {
101101
return bmod.Cast(exp, typname)
102102
}
103+
104+
// SQL: CASE WHEN a THEN b ELSE c END
105+
// Go: psql.Case().When("a", "b").Else("c")
106+
func Case() expr.CaseChain[Expression, Expression] {
107+
return expr.NewCase[Expression, Expression]()
108+
}

dialect/sqlite/select_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,34 @@ func TestSelect(t *testing.T) {
3030
sm.Where(sqlite.Quote("id").In(sqlite.Arg(100, 200, 300))),
3131
),
3232
},
33+
"case with else": {
34+
ExpectedSQL: `SELECT id, name, (CASE WHEN ("id" = '1') THEN 'A' ELSE 'B' END) AS "C" FROM users`,
35+
Query: sqlite.Select(
36+
sm.Columns(
37+
"id",
38+
"name",
39+
sqlite.Case().
40+
When(sqlite.Quote("id").EQ(sqlite.S("1")), sqlite.S("A")).
41+
Else(sqlite.S("B")).
42+
As("C"),
43+
),
44+
sm.From("users"),
45+
),
46+
},
47+
"case without else": {
48+
ExpectedSQL: `SELECT id, name, (CASE WHEN ("id" = '1') THEN 'A' END) AS "C" FROM users`,
49+
Query: sqlite.Select(
50+
sm.Columns(
51+
"id",
52+
"name",
53+
sqlite.Case().
54+
When(sqlite.Quote("id").EQ(sqlite.S("1")), sqlite.S("A")).
55+
End().
56+
As("C"),
57+
),
58+
sm.From("users"),
59+
),
60+
},
3361
"select distinct": {
3462
ExpectedSQL: `SELECT DISTINCT id, name FROM users WHERE ("id" IN (?1, ?2, ?3))`,
3563
ExpectedArgs: []any{100, 200, 300},

dialect/sqlite/starters.go

+6
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,9 @@ func Raw(query string, args ...any) Expression {
100100
func Cast(exp bob.Expression, typname string) Expression {
101101
return bmod.Cast(exp, typname)
102102
}
103+
104+
// SQL: CASE WHEN a THEN b ELSE c END
105+
// Go: sqlite.Case().When("a", "b").Else("c")
106+
func Case() expr.CaseChain[Expression, Expression] {
107+
return expr.NewCase[Expression, Expression]()
108+
}

expr/case.go

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package expr
2+
3+
import (
4+
"context"
5+
"errors"
6+
"io"
7+
8+
"github.com/stephenafamo/bob"
9+
)
10+
11+
type (
12+
caseExpr struct {
13+
whens []when
14+
elseExpr bob.Expression
15+
}
16+
when struct {
17+
condition bob.Expression
18+
then bob.Expression
19+
}
20+
)
21+
22+
func (c caseExpr) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
23+
var args []any
24+
25+
if len(c.whens) == 0 {
26+
return nil, errors.New("case must have at least one when expression")
27+
}
28+
29+
w.Write([]byte("CASE"))
30+
for _, when := range c.whens {
31+
w.Write([]byte(" WHEN "))
32+
whenArgs, err := when.condition.WriteSQL(ctx, w, d, start+len(args))
33+
if err != nil {
34+
return nil, err
35+
}
36+
args = append(args, whenArgs...)
37+
38+
w.Write([]byte(" THEN "))
39+
thenArgs, err := when.then.WriteSQL(ctx, w, d, start+len(args))
40+
if err != nil {
41+
return nil, err
42+
}
43+
args = append(args, thenArgs...)
44+
}
45+
46+
if c.elseExpr != nil {
47+
w.Write([]byte(" ELSE "))
48+
elseArgs, err := c.elseExpr.WriteSQL(ctx, w, d, start+len(args))
49+
if err != nil {
50+
return nil, err
51+
}
52+
args = append(args, elseArgs...)
53+
}
54+
w.Write([]byte(" END"))
55+
56+
return args, nil
57+
}
58+
59+
type CaseChain[T bob.Expression, B builder[T]] func() caseExpr
60+
61+
func NewCase[T bob.Expression, B builder[T]]() CaseChain[T, B] {
62+
return CaseChain[T, B](func() caseExpr { return caseExpr{} })
63+
}
64+
65+
func (cc CaseChain[T, B]) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
66+
return cc().WriteSQL(ctx, w, d, start)
67+
}
68+
69+
func (cc CaseChain[T, B]) When(condition, then bob.Expression) CaseChain[T, B] {
70+
c := cc()
71+
c.whens = append(c.whens, when{condition: condition, then: then})
72+
return CaseChain[T, B](func() caseExpr { return c })
73+
}
74+
75+
func (cc CaseChain[T, B]) Else(then bob.Expression) T {
76+
c := cc()
77+
c.elseExpr = then
78+
return X[T, B](c)
79+
}
80+
81+
func (cc CaseChain[T, B]) End() T {
82+
return X[T, B](cc())
83+
}

0 commit comments

Comments
 (0)