Skip to content

Commit

Permalink
add OnDuplicate/OnDuplicateEx feature for package gdb
Browse files Browse the repository at this point in the history
  • Loading branch information
gqcn committed Jun 16, 2021
1 parent e4b0de0 commit d450de8
Show file tree
Hide file tree
Showing 8 changed files with 430 additions and 193 deletions.
2 changes: 1 addition & 1 deletion container/gset/gset_str_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type StrSet struct {
data map[string]struct{}
}

// New create and returns a new set, which contains un-repeated items.
// NewStrSet create and returns a new set, which contains un-repeated items.
// The parameter <safe> is used to specify whether using set in concurrent-safety,
// which is false in default.
func NewStrSet(safe ...bool) *StrSet {
Expand Down
10 changes: 9 additions & 1 deletion database/gdb/gdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type DB interface {
// ===========================================================================

DoGetAll(ctx context.Context, link Link, sql string, args ...interface{}) (result Result, err error) // See Core.DoGetAll.
DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch int) (result sql.Result, err error) // See Core.DoInsert.
DoInsert(ctx context.Context, link Link, table string, data List, option DoInsertOption) (result sql.Result, err error) // See Core.DoInsert.
DoUpdate(ctx context.Context, link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoUpdate.
DoDelete(ctx context.Context, link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoDelete.
DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) // See Core.DoQuery.
Expand Down Expand Up @@ -214,6 +214,14 @@ type Sql struct {
IsTransaction bool // IsTransaction marks whether this sql is executed in transaction.
}

// DoInsertOption is the input struct for function DoInsert.
type DoInsertOption struct {
OnDuplicateStr string
OnDuplicateMap map[string]interface{}
InsertOption int // Insert operation.
BatchCount int // Batch count for batch inserting.
}

// TableField is the struct for table field.
type TableField struct {
Index int // For ordering purpose as map is unordered.
Expand Down
164 changes: 64 additions & 100 deletions database/gdb/gdb_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,133 +367,51 @@ func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, e
// 1: replace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
// 2: save: if there's unique/primary key in the data, it updates it or else inserts a new one;
// 3: ignore: if there's unique/primary key in the data, it ignores the inserting;
func (c *Core) DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch int) (result sql.Result, err error) {
table = c.QuotePrefixTableName(table)
func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) {
var (
keys []string // Field names.
values []string // Value holder string array, like: (?,?,?)
params []interface{} // Values that will be committed to underlying database driver.
listMap List // The data list that passed from caller.
keys []string // Field names.
values []string // Value holder string array, like: (?,?,?)
params []interface{} // Values that will be committed to underlying database driver.
onDuplicateStr string // onDuplicateStr is used in "ON DUPLICATE KEY UPDATE" statement.
)
switch value := data.(type) {
case Result:
listMap = value.List()

case Record:
listMap = List{value.Map()}

case List:
listMap = value
for i, v := range listMap {
listMap[i] = ConvertDataForTableRecord(v)
}

case Map:
listMap = List{ConvertDataForTableRecord(value)}

default:
var (
rv = reflect.ValueOf(data)
kind = rv.Kind()
)
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
// If it's slice type, it then converts it to List type.
case reflect.Slice, reflect.Array:
listMap = make(List, rv.Len())
for i := 0; i < rv.Len(); i++ {
listMap[i] = ConvertDataForTableRecord(rv.Index(i).Interface())
}

case reflect.Map:
listMap = List{ConvertDataForTableRecord(value)}

case reflect.Struct:
if v, ok := value.(apiInterfaces); ok {
var (
array = v.Interfaces()
list = make(List, len(array))
)
for i := 0; i < len(array); i++ {
list[i] = ConvertDataForTableRecord(array[i])
}
listMap = list
} else {
listMap = List{ConvertDataForTableRecord(value)}
}

default:
return result, gerror.New(fmt.Sprint("unsupported list type:", kind))
}
}
if len(listMap) < 1 {
return result, gerror.New("data list cannot be empty")
}
if link == nil {
if link, err = c.MasterLink(); err != nil {
return
}
}
// Handle the field names and place holders.
for k, _ := range listMap[0] {
for k, _ := range list[0] {
keys = append(keys, k)
}
// Prepare the batch result pointer.
var (
charL, charR = c.db.GetChars()
batchResult = new(SqlResult)
keysStr = charL + strings.Join(keys, charR+","+charL) + charR
operation = GetInsertOperationByOption(option)
updateStr = ""
operation = GetInsertOperationByOption(option.InsertOption)
)
if option == insertOptionSave {
for _, k := range keys {
// If it's SAVE operation,
// do not automatically update the creating time.
if c.isSoftCreatedFiledName(k) {
continue
}
if len(updateStr) > 0 {
updateStr += ","
}
updateStr += fmt.Sprintf(
"%s%s%s=VALUES(%s%s%s)",
charL, k, charR,
charL, k, charR,
)
}
updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", updateStr)
}
if batch <= 0 {
batch = defaultBatchNumber
if option.InsertOption == insertOptionSave {
onDuplicateStr = c.formatOnDuplicate(keys, option)
}
var (
listMapLen = len(listMap)
listLength = len(list)
valueHolder = make([]string, 0)
)
for i := 0; i < listMapLen; i++ {
for i := 0; i < listLength; i++ {
values = values[:0]
// Note that the map type is unordered,
// so it should use slice+key to retrieve the value.
for _, k := range keys {
if s, ok := listMap[i][k].(Raw); ok {
if s, ok := list[i][k].(Raw); ok {
values = append(values, gconv.String(s))
} else {
values = append(values, "?")
params = append(params, listMap[i][k])
params = append(params, list[i][k])
}
}
valueHolder = append(valueHolder, "("+gstr.Join(values, ",")+")")
// Batch package checks: It meets the batch number or it is the last element.
if len(valueHolder) == batch || (i == listMapLen-1 && len(valueHolder) > 0) {
if len(valueHolder) == option.BatchCount || (i == listLength-1 && len(valueHolder) > 0) {
r, err := c.db.DoExec(ctx, link, fmt.Sprintf(
"%s INTO %s(%s) VALUES%s %s",
operation, table, keysStr,
operation, c.QuotePrefixTableName(table), keysStr,
gstr.Join(valueHolder, ","),
updateStr,
onDuplicateStr,
), params...)
if err != nil {
return r, err
Expand All @@ -511,6 +429,52 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, data inter
return batchResult, nil
}

func (c *Core) formatOnDuplicate(columns []string, option DoInsertOption) string {
var (
onDuplicateStr string
)
if option.OnDuplicateStr != "" {
onDuplicateStr = option.OnDuplicateStr
} else if len(option.OnDuplicateMap) > 0 {
for k, v := range option.OnDuplicateMap {
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
switch v.(type) {
case Raw, *Raw:
onDuplicateStr += fmt.Sprintf(
"%s=%s",
c.QuoteWord(k),
v,
)
default:
onDuplicateStr += fmt.Sprintf(
"%s=VALUES(%s)",
c.QuoteWord(k),
c.QuoteWord(gconv.String(v)),
)
}
}
} else {
for _, column := range columns {
// If it's SAVE operation,
// do not automatically update the creating time.
if c.isSoftCreatedFilledName(column) {
continue
}
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
onDuplicateStr += fmt.Sprintf(
"%s=VALUES(%s)",
c.QuoteWord(column),
c.QuoteWord(column),
)
}
}
return fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", onDuplicateStr)
}

// Update does "UPDATE ... " statement for the table.
//
// The parameter `data` can be type of string/map/gmap/struct/*struct, etc.
Expand Down Expand Up @@ -711,8 +675,8 @@ func (c *Core) HasTable(name string) (bool, error) {
return false, nil
}

// isSoftCreatedFiledName checks and returns whether given filed name is an automatic-filled created time.
func (c *Core) isSoftCreatedFiledName(fieldName string) bool {
// isSoftCreatedFilledName checks and returns whether given filed name is an automatic-filled created time.
func (c *Core) isSoftCreatedFilledName(fieldName string) bool {
if fieldName == "" {
return false
}
Expand Down
100 changes: 16 additions & 84 deletions database/gdb/gdb_driver_oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,95 +264,40 @@ func (d *DriverOracle) getTableUniqueIndex(table string) (fields map[string]map[
return
}

func (d *DriverOracle) DoInsert(ctx context.Context, link Link, table string, list interface{}, option int, batch int) (result sql.Result, err error) {
func (d *DriverOracle) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) {
var (
keys []string
values []string
params []interface{}
)
listMap := (List)(nil)
switch v := list.(type) {
case Result:
listMap = v.List()
case Record:
listMap = List{v.Map()}
case List:
listMap = v
case Map:
listMap = List{v}
default:
var (
rv = reflect.ValueOf(list)
kind = rv.Kind()
)
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
case reflect.Slice, reflect.Array:
listMap = make(List, rv.Len())
for i := 0; i < rv.Len(); i++ {
listMap[i] = ConvertDataForTableRecord(rv.Index(i).Interface())
}
case reflect.Map:
fallthrough
case reflect.Struct:
listMap = List{ConvertDataForTableRecord(list)}
default:
return result, gerror.New(fmt.Sprint("unsupported list type:", kind))
}
}
if len(listMap) < 1 {
return result, gerror.New("empty data list")
}
if link == nil {
if link, err = d.MasterLink(); err != nil {
return
}
}
// Retrieve the table fields and length.
holders := []string(nil)
for k, _ := range listMap[0] {
var (
listLength = len(list)
valueHolder = make([]string, 0)
)
for k, _ := range list[0] {
keys = append(keys, k)
holders = append(holders, "?")
valueHolder = append(valueHolder, "?")
}
var (
batchResult = new(SqlResult)
charL, charR = d.db.GetChars()
keyStr = charL + strings.Join(keys, charL+","+charR) + charR
valueHolderStr = strings.Join(holders, ",")
valueHolderStr = strings.Join(valueHolder, ",")
)
if option != insertOptionDefault {
for _, v := range listMap {
r, err := d.DoInsert(ctx, link, table, v, option, 1)
if err != nil {
return r, err
}

if n, err := r.RowsAffected(); err != nil {
return r, err
} else {
batchResult.result = r
batchResult.affected += n
}
}
return batchResult, nil
}

if batch <= 0 {
batch = defaultBatchNumber
}
// Format "INSERT...INTO..." statement.
intoStr := make([]string, 0)
for i := 0; i < len(listMap); i++ {
for i := 0; i < len(list); i++ {
for _, k := range keys {
params = append(params, listMap[i][k])
params = append(params, list[i][k])
}
values = append(values, valueHolderStr)
intoStr = append(intoStr, fmt.Sprintf(" INTO %s(%s) VALUES(%s) ", table, keyStr, valueHolderStr))
if len(intoStr) == batch {
r, err := d.DoExec(ctx, link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
intoStr = append(intoStr, fmt.Sprintf("INTO %s(%s) VALUES(%s)", table, keyStr, valueHolderStr))
if len(intoStr) == option.BatchCount || (i == listLength-1 && len(valueHolder) > 0) {
r, err := d.DoExec(ctx, link, fmt.Sprintf(
"INSERT ALL %s SELECT * FROM DUAL",
strings.Join(intoStr, " "),
), params...)
if err != nil {
return r, err
}
Expand All @@ -366,18 +311,5 @@ func (d *DriverOracle) DoInsert(ctx context.Context, link Link, table string, li
intoStr = intoStr[:0]
}
}
// The leftover data.
if len(intoStr) > 0 {
r, err := d.DoExec(ctx, link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
if err != nil {
return r, err
}
if n, err := r.RowsAffected(); err != nil {
return r, err
} else {
batchResult.result = r
batchResult.affected += n
}
}
return batchResult, nil
}
3 changes: 3 additions & 0 deletions database/gdb/gdb_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,15 @@ func ConvertDataForTableRecord(value interface{}) map[string]interface{} {
// Convert the value to JSON.
data[k], _ = json.Marshal(v)
}

case reflect.Struct:
switch v.(type) {
case time.Time, *time.Time, gtime.Time, *gtime.Time:
continue

case Counter, *Counter:
continue

default:
// Use string conversion in default.
if s, ok := v.(apiString); ok {
Expand Down
Loading

0 comments on commit d450de8

Please sign in to comment.