forked from kataras/iris
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add a new x/sqlx sub-package and example
Showing
8 changed files
with
595 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"encoding/json" | ||
"fmt" | ||
"time" | ||
|
||
"github.com/kataras/iris/v12" | ||
"github.com/kataras/iris/v12/x/errors" | ||
"github.com/kataras/iris/v12/x/sqlx" | ||
|
||
_ "github.com/lib/pq" | ||
) | ||
|
||
const ( | ||
host = "localhost" | ||
port = 5432 | ||
user = "postgres" | ||
password = "admin!123" | ||
dbname = "test" | ||
) | ||
|
||
func main() { | ||
app := iris.New() | ||
|
||
db := mustConnectDB() | ||
mustCreateExtensions(context.Background(), db) | ||
mustCreateTables(context.Background(), db) | ||
|
||
app.Post("/", insert(db)) | ||
app.Get("/", list(db)) | ||
app.Get("/{event_id:uuid}", getByID(db)) | ||
|
||
/* | ||
curl --location --request POST 'http://localhost:8080' \ | ||
--header 'Content-Type: application/json' \ | ||
--data-raw '{ | ||
"name": "second_test_event", | ||
"data": { | ||
"key": "value", | ||
"year": 2022 | ||
} | ||
}' | ||
curl --location --request GET 'http://localhost:8080' | ||
curl --location --request GET 'http://localhost:8080/4fc0363f-1d1f-4a43-8608-5ed266485645' | ||
*/ | ||
app.Listen(":8080") | ||
} | ||
|
||
func mustConnectDB() *sql.DB { | ||
connString := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", | ||
host, port, user, password, dbname) | ||
db, err := sql.Open("postgres", connString) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
err = db.Ping() | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
return db | ||
} | ||
|
||
func mustCreateExtensions(ctx context.Context, db *sql.DB) { | ||
query := `CREATE EXTENSION IF NOT EXISTS pgcrypto;` | ||
_, err := db.ExecContext(ctx, query) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
||
func mustCreateTables(ctx context.Context, db *sql.DB) { | ||
query := `CREATE TABLE IF NOT EXISTS "events" ( | ||
"id" uuid PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(), | ||
"created_at" timestamp(6) DEFAULT now(), | ||
"name" text COLLATE "pg_catalog"."default", | ||
"data" jsonb | ||
);` | ||
|
||
_, err := db.ExecContext(ctx, query) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
sqlx.Register("events", Event{}) | ||
} | ||
|
||
type Event struct { | ||
ID string `json:"id"` | ||
CreatedAt time.Time `json:"created_at"` | ||
Name string `json:"name"` | ||
Data json.RawMessage `json:"data"` | ||
|
||
Presenter string `db:"-" json:"-"` | ||
} | ||
|
||
func insertEvent(ctx context.Context, db *sql.DB, evt Event) (id string, err error) { | ||
query := `INSERT INTO events(name,data) VALUES($1,$2) RETURNING id;` | ||
err = db.QueryRowContext(ctx, query, evt.Name, evt.Data).Scan(&id) | ||
return | ||
} | ||
|
||
func listEvents(ctx context.Context, db *sql.DB) ([]Event, error) { | ||
list := make([]Event, 0) | ||
query := `SELECT * FROM events ORDER BY created_at;` | ||
rows, err := db.QueryContext(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
// Not required. See sqlx.DefaultSchema.AutoCloseRows field. | ||
// defer rows.Close() | ||
|
||
if err = sqlx.Bind(&list, rows); err != nil { | ||
return nil, err | ||
} | ||
|
||
return list, nil | ||
} | ||
|
||
func getEvent(ctx context.Context, db *sql.DB, id string) (Event, error) { | ||
query := `SELECT * FROM events WHERE id = $1 LIMIT 1;` | ||
rows, err := db.QueryContext(ctx, query, id) | ||
if err != nil { | ||
return Event{}, err | ||
} | ||
|
||
var evt Event | ||
err = sqlx.Bind(&evt, rows) | ||
|
||
return evt, err | ||
} | ||
|
||
func insert(db *sql.DB) iris.Handler { | ||
return func(ctx iris.Context) { | ||
var evt Event | ||
if err := ctx.ReadJSON(&evt); err != nil { | ||
errors.InvalidArgument.Details(ctx, "unable to read body", err.Error()) | ||
return | ||
} | ||
|
||
id, err := insertEvent(ctx, db, evt) | ||
if err != nil { | ||
errors.Internal.LogErr(ctx, err) | ||
return | ||
} | ||
|
||
ctx.JSON(iris.Map{"id": id}) | ||
} | ||
} | ||
|
||
func list(db *sql.DB) iris.Handler { | ||
return func(ctx iris.Context) { | ||
events, err := listEvents(ctx, db) | ||
if err != nil { | ||
errors.Internal.LogErr(ctx, err) | ||
return | ||
} | ||
|
||
ctx.JSON(events) | ||
} | ||
} | ||
|
||
func getByID(db *sql.DB) iris.Handler { | ||
return func(ctx iris.Context) { | ||
eventID := ctx.Params().Get("event_id") | ||
|
||
evt, err := getEvent(ctx, db, eventID) | ||
if err != nil { | ||
errors.Internal.LogErr(ctx, err) | ||
return | ||
} | ||
|
||
ctx.JSON(evt) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
package sqlx | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
"unsafe" | ||
|
||
"github.com/kataras/iris/v12/x/reflex" | ||
) | ||
|
||
type ( | ||
// Schema holds the row definitions. | ||
Schema struct { | ||
Name string | ||
Rows map[reflect.Type]*Row | ||
ColumnNameFunc ColumnNameFunc | ||
AutoCloseRows bool | ||
} | ||
|
||
// Row holds the column definitions and the struct type & name. | ||
Row struct { | ||
Schema string // e.g. public | ||
Name string // e.g. users. Must set to a custom one if the select query contains AS names. | ||
StructType reflect.Type | ||
Columns map[string]*Column // e.g. "id":{"id", 0, [0]} | ||
} | ||
|
||
// Column holds the database column name and other properties extracted by a struct's field. | ||
Column struct { | ||
Name string | ||
Index int | ||
FieldIndex []int | ||
} | ||
) | ||
|
||
// NewSchema returns a new Schema. Use its Register() method to cache | ||
// a structure value so Bind() can fill all struct's fields based on a query. | ||
func NewSchema() *Schema { | ||
return &Schema{ | ||
Name: "public", | ||
Rows: make(map[reflect.Type]*Row), | ||
ColumnNameFunc: snakeCase, | ||
AutoCloseRows: true, | ||
} | ||
} | ||
|
||
// DefaultSchema initializes a common Schema. | ||
var DefaultSchema = NewSchema() | ||
|
||
// Register caches a struct value to the default schema. | ||
func Register(tableName string, value interface{}) *Schema { | ||
return DefaultSchema.Register(tableName, value) | ||
} | ||
|
||
// Bind sets "dst" to the result of "src" and reports any errors. | ||
func Bind(dst interface{}, src *sql.Rows) error { | ||
return DefaultSchema.Bind(dst, src) | ||
} | ||
|
||
// Register caches a struct value to the schema. | ||
func (s *Schema) Register(tableName string, value interface{}) *Schema { | ||
typ := reflect.TypeOf(value) | ||
for typ.Kind() == reflect.Ptr { | ||
typ = typ.Elem() | ||
} | ||
|
||
if tableName == "" { | ||
// convert to a human name, e.g. sqlx.Food -> food. | ||
typeName := typ.String() | ||
if idx := strings.LastIndexByte(typeName, '.'); idx > 0 && len(typeName) > idx { | ||
typeName = typeName[idx+1:] | ||
} | ||
tableName = snakeCase(typeName) | ||
} | ||
|
||
columns, err := convertStructToColumns(typ, s.ColumnNameFunc) | ||
if err != nil { | ||
panic(fmt.Sprintf("sqlx: register: %q: %s", reflect.TypeOf(value).String(), err.Error())) | ||
} | ||
|
||
s.Rows[typ] = &Row{ | ||
Schema: s.Name, | ||
Name: tableName, | ||
StructType: typ, | ||
Columns: columns, | ||
} | ||
|
||
return s | ||
} | ||
|
||
// Bind sets "dst" to the result of "src" and reports any errors. | ||
func (s *Schema) Bind(dst interface{}, src *sql.Rows) error { | ||
typ := reflect.TypeOf(dst) | ||
if typ.Kind() != reflect.Ptr { | ||
return fmt.Errorf("sqlx: bind: destination not a pointer") | ||
} | ||
|
||
typ = typ.Elem() | ||
|
||
originalKind := typ.Kind() | ||
if typ.Kind() == reflect.Slice { | ||
typ = typ.Elem() | ||
} | ||
|
||
r, ok := s.Rows[typ] | ||
if !ok { | ||
return fmt.Errorf("sqlx: bind: unregistered type: %q", typ.String()) | ||
} | ||
|
||
columnTypes, err := src.ColumnTypes() | ||
if err != nil { | ||
return fmt.Errorf("sqlx: bind: table: %q: %w", r.Name, err) | ||
} | ||
|
||
if expected, got := len(r.Columns), len(columnTypes); expected != got { | ||
return fmt.Errorf("sqlx: bind: table: %q: unexpected number of result columns: %d: expected: %d", r.Name, got, expected) | ||
} | ||
|
||
val := reflex.IndirectValue(reflect.ValueOf(dst)) | ||
if s.AutoCloseRows { | ||
defer src.Close() | ||
} | ||
|
||
switch originalKind { | ||
case reflect.Struct: | ||
if src.Next() { | ||
if err = r.bindSingle(typ, val, columnTypes, src); err != nil { | ||
return err | ||
} | ||
} else { | ||
return sql.ErrNoRows | ||
} | ||
|
||
return src.Err() | ||
case reflect.Slice: | ||
for src.Next() { | ||
elem := reflect.New(typ).Elem() | ||
if err = r.bindSingle(typ, elem, columnTypes, src); err != nil { | ||
return err | ||
} | ||
|
||
val = reflect.Append(val, elem) | ||
} | ||
|
||
if err = src.Err(); err != nil { | ||
return err | ||
} | ||
|
||
reflect.ValueOf(dst).Elem().Set(val) | ||
return nil | ||
default: | ||
return fmt.Errorf("sqlx: bind: table: %q: unexpected destination kind: %q", r.Name, typ.Kind().String()) | ||
} | ||
} | ||
|
||
func (r *Row) bindSingle(typ reflect.Type, val reflect.Value, columnTypes []*sql.ColumnType, scanner interface{ Scan(...interface{}) error }) error { | ||
fieldPtrs, err := r.lookupStructFieldPtrs(typ, val, columnTypes) | ||
if err != nil { | ||
return fmt.Errorf("sqlx: bind: table: %q: %w", r.Name, err) | ||
} | ||
|
||
return scanner.Scan(fieldPtrs...) | ||
} | ||
|
||
func (r *Row) lookupStructFieldPtrs(typ reflect.Type, val reflect.Value, columnTypes []*sql.ColumnType) ([]interface{}, error) { | ||
fieldPtrs := make([]interface{}, 0, len(columnTypes)) | ||
|
||
for _, columnType := range columnTypes { | ||
columnName := columnType.Name() | ||
tableColumn, ok := r.Columns[columnName] | ||
if !ok { | ||
continue | ||
} | ||
|
||
tableColumnField, err := val.FieldByIndexErr(tableColumn.FieldIndex) | ||
if err != nil { | ||
return nil, fmt.Errorf("column: %q: %w", tableColumn.Name, err) | ||
} | ||
|
||
tableColumnFieldType := tableColumnField.Type() | ||
|
||
fieldPtr := reflect.NewAt(tableColumnFieldType, unsafe.Pointer(tableColumnField.UnsafeAddr())).Elem().Addr().Interface() | ||
fieldPtrs = append(fieldPtrs, fieldPtr) | ||
} | ||
|
||
return fieldPtrs, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package sqlx | ||
|
||
/* | ||
import ( | ||
"reflect" | ||
"testing" | ||
sqlmock "github.com/DATA-DOG/go-sqlmock" | ||
) | ||
type food struct { | ||
ID string | ||
Name string | ||
Presenter bool `db:"-"` | ||
} | ||
func TestTableBind(t *testing.T) { | ||
Register("foods", food{}) | ||
db, mock, err := sqlmock.New() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
mock.ExpectQuery("SELECT .* FROM foods WHERE id = ?"). | ||
WithArgs("42"). | ||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). | ||
AddRow("42", "banana"). | ||
AddRow("43", "broccoli")) | ||
rows, err := db.Query("SELECT .* FROM foods WHERE id = ? LIMIT 1", "42") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
var f food | ||
err = Bind(&f, rows) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
expectedSingle := food{"42", "banana", false} | ||
if !reflect.DeepEqual(f, expectedSingle) { | ||
t.Fatalf("expected value: %#+v but got: %#+v", expectedSingle, f) | ||
} | ||
mock.ExpectQuery("SELECT .* FROM foods"). | ||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). | ||
AddRow("42", "banana"). | ||
AddRow("43", "broccoli"). | ||
AddRow("44", "chicken")) | ||
rows, err = db.Query("SELECT .* FROM foods") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
var foods []food | ||
err = Bind(&foods, rows) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
expectedMany := []food{ | ||
{"42", "banana", false}, | ||
{"43", "broccoli", false}, | ||
{"44", "chicken", false}, | ||
} | ||
for i := range foods { | ||
if !reflect.DeepEqual(foods[i], expectedMany[i]) { | ||
t.Fatalf("[%d] expected: %#+v but got: %#+v", i, expectedMany[i], foods[i]) | ||
} | ||
} | ||
} | ||
*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package sqlx | ||
|
||
import ( | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
|
||
"github.com/kataras/iris/v12/x/reflex" | ||
) | ||
|
||
// DefaultTag is the default struct field tag. | ||
var DefaultTag = "db" | ||
|
||
type ColumnNameFunc = func(string) string | ||
|
||
func convertStructToColumns(typ reflect.Type, nameFunc ColumnNameFunc) (map[string]*Column, error) { | ||
if kind := typ.Kind(); kind != reflect.Struct { | ||
return nil, fmt.Errorf("convert struct: invalid type: expected a struct value but got: %q", kind.String()) | ||
} | ||
|
||
// Retrieve only fields valid for database. | ||
fields := reflex.LookupFields(typ, "") | ||
|
||
columns := make(map[string]*Column, len(fields)) | ||
for i, field := range fields { | ||
column, ok, err := convertStructFieldToColumn(field, DefaultTag, nameFunc) | ||
if !ok { | ||
continue | ||
} | ||
|
||
if err != nil { | ||
return nil, fmt.Errorf("convert struct: field name: %q: %w", field.Name, err) | ||
} | ||
|
||
column.Index = i | ||
columns[column.Name] = column | ||
} | ||
|
||
return columns, nil | ||
} | ||
|
||
func convertStructFieldToColumn(field reflect.StructField, optionalTag string, nameFunc ColumnNameFunc) (*Column, bool, error) { | ||
c := &Column{ | ||
Name: nameFunc(field.Name), | ||
FieldIndex: field.Index, | ||
} | ||
|
||
fieldTag, ok := field.Tag.Lookup(optionalTag) | ||
if ok { | ||
if fieldTag == "-" { | ||
return nil, false, nil | ||
} | ||
|
||
if err := parseOptions(fieldTag, c); err != nil { | ||
return nil, false, err | ||
} | ||
} | ||
|
||
return c, true, nil | ||
} | ||
|
||
func parseOptions(fieldTag string, c *Column) error { | ||
options := strings.Split(fieldTag, ",") | ||
for _, opt := range options { | ||
if opt == "" { | ||
continue // skip empty. | ||
} | ||
|
||
var key, value string | ||
|
||
kv := strings.Split(opt, "=") // When more options come to play. | ||
switch len(kv) { | ||
case 2: | ||
key = kv[0] | ||
value = kv[1] | ||
case 1: | ||
c.Name = kv[0] | ||
return nil | ||
default: | ||
return fmt.Errorf("option: %s: expected key value separated by '='", opt) | ||
} | ||
|
||
switch key { | ||
case "name": | ||
c.Name = value | ||
default: | ||
return fmt.Errorf("unexpected tag option: %s", key) | ||
} | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package sqlx | ||
|
||
import "strings" | ||
|
||
// snakeCase converts a given string to a friendly snake case, e.g. | ||
// - userId to user_id | ||
// - ID to id | ||
// - ProviderAPIKey to provider_api_key | ||
// - Option to option | ||
func snakeCase(camel string) string { | ||
var ( | ||
b strings.Builder | ||
prevWasUpper bool | ||
) | ||
|
||
for i, c := range camel { | ||
if isUppercase(c) { // it's upper. | ||
if b.Len() > 0 && !prevWasUpper { // it's not the first and the previous was not uppercased too (e.g "ID"). | ||
b.WriteRune('_') | ||
} else { // check for XxxAPIKey, it should be written as xxx_api_key. | ||
next := i + 1 | ||
if next > 1 && len(camel)-1 > next { | ||
if !isUppercase(rune(camel[next])) { | ||
b.WriteRune('_') | ||
} | ||
} | ||
} | ||
|
||
b.WriteRune(c - 'A' + 'a') // write its lowercase version. | ||
prevWasUpper = true | ||
} else { | ||
b.WriteRune(c) // write it as it is, it's already lowercased. | ||
prevWasUpper = false | ||
} | ||
} | ||
|
||
return b.String() | ||
} | ||
|
||
func isUppercase(c rune) bool { | ||
return 'A' <= c && c <= 'Z' | ||
} |