Skip to content

Commit

Permalink
feat: bulk insert query for models (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
keroxp authored Dec 23, 2022
1 parent b68e95a commit e34fb2c
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 15 deletions.
65 changes: 52 additions & 13 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,54 @@ func WhereOr(list ...q.Condition) q.Condition {
return q.ConditionOr(list...)
}

type ModelMetadata struct {
TableName string
AutoIncrementField *reflect.Value
Values q.KeyIterator
}

func QueryForInsert(modelPtr Model) (q.Query, *reflect.Value, error) {
m, err := AggregateModelMetadata(modelPtr)
if err != nil {
return nil, nil, err
}
return &q.Insert{
Into: m.TableName,
Values: m.Values.Map(),
}, m.AutoIncrementField, nil
}

func QueryForBulkInsert[T Model](modelPtrs ...T) (q.Query, error) {
if len(modelPtrs) == 0 {
return nil, xerrors.New("empty list")
}
var values [][]any
var head *ModelMetadata
for _, v := range modelPtrs {
if data, err := AggregateModelMetadata(v); err != nil {
return nil, err
} else {
if head == nil {
head = data
}
values = append(values, data.Values.Values())
}
}
return &q.InsertMany{
Into: head.TableName,
Columns: head.Values.Keys(),
Values: values,
}, nil
}

func AggregateModelMetadata(modelPtr Model) (*ModelMetadata, error) {
if modelPtr == nil {
return nil, nil, xerrors.Errorf("pointer is nil")
return nil, xerrors.Errorf("pointer is nil")
}
objValue := reflect.ValueOf(modelPtr)
objType := objValue.Type()
if objType.Kind() != reflect.Ptr || objType.Elem().Kind() != reflect.Struct {
return nil, nil, xerrors.Errorf("object must be pointer of struct")
return nil, xerrors.Errorf("object must be pointer of struct")
}
data := map[string]any{}
// *User -> User
Expand All @@ -43,11 +83,11 @@ func QueryForInsert(modelPtr Model) (q.Query, *reflect.Value, error) {
if t, ok := f.Tag.Lookup("exql"); ok {
tags, err := ParseTags(t)
if err != nil {
return nil, nil, err
return nil, err
}
colName, ok := tags["column"]
if !ok || colName == "" {
return nil, nil, xerrors.Errorf("column tag is not set")
return nil, xerrors.Errorf("column tag is not set")
}
exqlTagCount++
if _, primary := tags["primary"]; primary {
Expand All @@ -64,23 +104,22 @@ func QueryForInsert(modelPtr Model) (q.Query, *reflect.Value, error) {
}
}
if exqlTagCount == 0 {
return nil, nil, xerrors.Errorf("obj doesn't have exql tags in any fields")
return nil, xerrors.Errorf("obj doesn't have exql tags in any fields")
}

if len(primaryKeyFields) == 0 {
return nil, nil, xerrors.Errorf("table has no primary key")
return nil, xerrors.Errorf("table has no primary key")
}

tableName := modelPtr.TableName()
if tableName == "" {
return nil, nil, xerrors.Errorf("empty table name")
return nil, xerrors.Errorf("empty table name")
}
return &q.Insert{
Into: tableName,
Values: data,
},
autoIncrementField,
nil
return &ModelMetadata{
TableName: tableName,
AutoIncrementField: autoIncrementField,
Values: q.NewKeyIterator(data),
}, nil
}

func QueryForUpdateModel(
Expand Down
9 changes: 9 additions & 0 deletions query/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type KeyIterator interface {
Keys() []string
Values() []any
Size() int
Map() map[string]any
}

func NewKeyIterator(data map[string]any) KeyIterator {
Expand Down Expand Up @@ -53,6 +54,14 @@ func (k *keyIterator) Values() []any {
return k.values
}

func (k *keyIterator) Map() map[string]any {
res := map[string]any{}
for i := 0; i < k.Size(); i++ {
res[k.keys[i]] = k.values[i]
}
return res
}

func Placeholders(repeat int) string {
res := make([]string, repeat)
for i := 0; i < repeat; i++ {
Expand Down
6 changes: 4 additions & 2 deletions query/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (
)

func TestKeyIteretor(t *testing.T) {
it := NewKeyIterator(map[string]any{
src := map[string]any{
"a": 1,
"b": 2,
"c": 3,
})
}
it := NewKeyIterator(src)
assert.Equal(t, it.Size(), 3)
assert.ElementsMatch(t, it.Keys(), []string{"a", "b", "c"})
assert.ElementsMatch(t, it.Values(), []any{1, 2, 3})
Expand All @@ -21,6 +22,7 @@ func TestKeyIteretor(t *testing.T) {
assert.Equal(t, it.Keys()[i], k)
assert.Equal(t, it.Values()[i], v)
}
assert.InDeltaMapValues(t, it.Map(), src, 0)
}

func TestSqlPraceholder(t *testing.T) {
Expand Down
37 changes: 37 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,42 @@ func TestQueryForInsert(t *testing.T) {
assert.Equal(t, exp, stmt)
assert.ElementsMatch(t, args, []any{user.FirstName, user.LastName})
})
}

func TestQueryForBulkInsert(t *testing.T) {
t.Run("basic", func(t *testing.T) {
q, err := exql.QueryForBulkInsert(
&model.Users{FirstName: null.StringFrom("user"), LastName: null.StringFrom("one")},
&model.Users{FirstName: null.StringFrom("user"), LastName: null.StringFrom("two")},
)
assert.NoError(t, err)
stmt, args, err := q.Query()
assert.NoError(t, err)
assert.Equal(t, "INSERT INTO `users` (`first_name`,`last_name`) VALUES (?,?),(?,?)", stmt)
assert.ElementsMatch(t, []any{
null.StringFrom("user"), null.StringFrom("one"),
null.StringFrom("user"), null.StringFrom("two"),
}, args)
})
t.Run("error if args empty", func(t *testing.T) {
q, err := exql.QueryForBulkInsert[model.Users]()
assert.Nil(t, q)
assert.EqualError(t, err, "empty list")
})
}

func TestAggregateModelMetadata(t *testing.T) {
t.Run("basic", func(t *testing.T) {
m, err := exql.AggregateModelMetadata(&model.Users{
FirstName: null.StringFrom("first"),
LastName: null.StringFrom("name"),
})
assert.NoError(t, err)
assert.Equal(t, "users", m.TableName)
assert.NotNil(t, m.AutoIncrementField)
assert.ElementsMatch(t, []string{"first_name", "last_name"}, m.Values.Keys())
assert.ElementsMatch(t, []any{null.StringFrom("first"), null.StringFrom("name")}, m.Values.Values())
})
assertInvalid := func(t *testing.T, m exql.Model, e string) {
s, f, err := exql.QueryForInsert(m)
assert.Nil(t, s)
Expand Down Expand Up @@ -89,6 +125,7 @@ func TestQueryForInsert(t *testing.T) {
assertInvalid(t, &sam, "table has no primary key")
})
}

func TestQueryForUpdateModel(t *testing.T) {
t.Run("basic", func(t *testing.T) {
user := &model.Users{}
Expand Down

0 comments on commit e34fb2c

Please sign in to comment.