forked from gohouse/gorose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
session.go
485 lines (434 loc) · 12.1 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
package gorose
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"time"
"github.com/gohouse/t"
)
// Session ...
type Session struct {
IEngin
IBinder
master *sql.DB
tx *sql.Tx
slave *sql.DB
lastInsertId int64
sqlLogs []string
lastSql string
union interface{}
transaction bool
err error
}
var _ ISession = (*Session)(nil)
// NewSession : 初始化 Session
func NewSession(e IEngin) *Session {
var s = new(Session)
s.IEngin = e
// 初始化 IBinder
s.SetIBinder(NewBinder())
s.master = e.GetExecuteDB()
s.slave = e.GetQueryDB()
return s
}
func (s *Session) Close() {
s.master.Close()
s.slave.Close()
}
// GetIEngin 获取engin
func (s *Session) GetIEngin() IEngin {
return s.IEngin
}
// GetDriver 获取驱动
func (s *Session) SetIEngin(ie IEngin) {
s.IEngin = ie
}
// Bind : 传入绑定结果的对象, 参数一为对象, 可以是 struct, gorose.MapRow 或对应的切片
// 如果是做非query操作,第一个参数也可以仅仅指定为字符串表名
func (s *Session) Bind(tab interface{}) ISession {
//fmt.Println(tab, NewBinder(tab))
//s.SetIBinder(NewBinder(tab))
s.GetIBinder().SetBindOrigin(tab)
s.err = s.IBinder.BindParse(s.GetIEngin().GetPrefix())
return s
}
// GetBinder 获取绑定对象
func (s *Session) GetErr() error {
return s.err
}
// GetBinder 获取绑定对象
func (s *Session) SetIBinder(ib IBinder) {
s.IBinder = ib
}
// GetBinder 获取绑定对象
func (s *Session) GetIBinder() IBinder {
return s.IBinder
}
// GetBinder 获取绑定对象
func (s *Session) ResetBinderResult() {
_ = s.IBinder.BindParse(s.GetIEngin().GetPrefix())
}
// GetTableName 获取解析后的名字, 提供给orm使用
// 为什么要在这里重复添加该方法, 而不是直接继承 IBinder 的方法呢?
// 是因为, 这里涉及到表前缀的问题, 只能通过session来传递, 所以IOrm就可以选择直接继承
func (s *Session) GetTableName() (string, error) {
//err := s.IBinder.BindParse(s.GetIEngin().GetPrefix())
//fmt.Println(s.GetIBinder())
return s.GetIBinder().GetBindName(), s.err
}
// Begin ...
func (s *Session) Begin() (err error) {
s.tx, err = s.master.Begin()
s.SetTransaction(true)
return
}
// Rollback ...
func (s *Session) Rollback() (err error) {
err = s.tx.Rollback()
s.tx = nil
s.SetTransaction(false)
return
}
// Commit ...
func (s *Session) Commit() (err error) {
err = s.tx.Commit()
s.tx = nil
s.SetTransaction(false)
return
}
// Transaction ...
func (s *Session) Transaction(closers ...func(ses ISession) error) (err error) {
err = s.Begin()
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return err
}
for _, closer := range closers {
err = closer(s)
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
_ = s.Rollback()
return
}
}
return s.Commit()
}
// Query ...
func (s *Session) Query(sqlstring string, args ...interface{}) (result []Data, err error) {
// 记录开始时间
start := time.Now()
//withRunTimeContext(func() {
if s.err != nil {
err = s.err
s.GetIEngin().GetLogger().Error(err.Error())
}
// 记录sqlLog
s.lastSql = fmt.Sprint(sqlstring, ", ", args)
//if s.IfEnableSqlLog() {
// s.sqlLogs = append(s.sqlLogs, s.lastSql)
//}
var stmt *sql.Stmt
// 如果是事务, 则从主库中读写
if s.tx == nil {
stmt, err = s.slave.Prepare(sqlstring)
} else {
stmt, err = s.tx.Prepare(sqlstring)
}
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return
}
defer stmt.Close()
rows, err := stmt.Query(args...)
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return
}
// make sure we always close rows
defer rows.Close()
err = s.scan(rows)
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return
}
//}, func(duration time.Duration) {
// //if duration.Seconds() > 1 {
// // s.GetIEngin().GetLogger().Slow(s.LastSql(), duration)
// //} else {
// // s.GetIEngin().GetLogger().Sql(s.LastSql(), duration)
// //}
//})
timeduration := time.Since(start)
//if timeduration.Seconds() > 1 {
s.GetIEngin().GetLogger().Slow(s.LastSql(), timeduration)
//} else {
s.GetIEngin().GetLogger().Sql(s.LastSql(), timeduration)
//}
result = s.GetIBinder().GetBindAll()
return
}
// Execute ...
func (s *Session) Execute(sqlstring string, args ...interface{}) (rowsAffected int64, err error) {
// 记录开始时间
start := time.Now()
//withRunTimeContext(func() {
// err = s.GetIBinder().BindParse(s.GetIEngin().GetPrefix())
if s.err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return
}
s.lastSql = fmt.Sprint(sqlstring, ", ", args)
//// 记录sqlLog
//if s.IfEnableSqlLog() {
// s.sqlLogs = append(s.sqlLogs, s.lastSql)
//}
var operType = strings.ToLower(sqlstring[0:6])
if operType == "select" {
s.GetIEngin().GetLogger().Error(err.Error())
err = errors.New("Execute does not allow select operations, please use Query")
return
}
var stmt *sql.Stmt
if s.tx == nil {
stmt, err = s.master.Prepare(sqlstring)
} else {
stmt, err = s.tx.Prepare(sqlstring)
}
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return
}
//var err error
defer stmt.Close()
result, err := stmt.Exec(args...)
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return
}
if operType == "insert" {
// get last insert id
lastInsertId, err := result.LastInsertId()
if err == nil {
s.lastInsertId = lastInsertId
} else {
s.GetIEngin().GetLogger().Error(err.Error())
}
}
// get rows affected
rowsAffected, err = result.RowsAffected()
timeduration := time.Since(start)
//}, func(duration time.Duration) {
if timeduration.Seconds() > 1 {
s.GetIEngin().GetLogger().Slow(s.LastSql(), timeduration)
} else {
s.GetIEngin().GetLogger().Sql(s.LastSql(), timeduration)
}
//})
return
}
// LastInsertId ...
func (s *Session) LastInsertId() int64 {
return s.lastInsertId
}
// LastSql ...
func (s *Session) LastSql() string {
return s.lastSql
}
func (s *Session) scan(rows *sql.Rows) (err error) {
// 如果不需要绑定, 则需要初始化一下binder
if s.GetIBinder() == nil {
s.SetIBinder(NewBinder())
}
// 检查实多维数组还是一维数组
switch s.GetBindType() {
case OBJECT_STRING:
err = s.scanAll(rows)
case OBJECT_STRUCT, OBJECT_STRUCT_SLICE:
err = s.scanStructAll(rows)
//case OBJECT_MAP, OBJECT_MAP_T:
// err = s.scanMap(rows, s.GetBindResult())
case OBJECT_MAP, OBJECT_MAP_T, OBJECT_MAP_SLICE, OBJECT_MAP_SLICE_T:
err = s.scanMapAll(rows)
case OBJECT_NIL:
err = s.scanAll(rows)
default:
err = errors.New("Bind value error")
}
return
}
//func (s *Session) scanMap(rows *sql.Rows, dst interface{}) (err error) {
// return s.scanMapAll(rows, dst)
//}
func (s *Session) scanMapAll(rows *sql.Rows) (err error) {
var columns []string
// 获取查询的所有字段
if columns, err = rows.Columns(); err != nil {
return
}
count := len(columns)
for rows.Next() {
// 定义要绑定的结果集
values := make([]interface{}, count)
scanArgs := make([]interface{}, count)
for i := 0; i < count; i++ {
scanArgs[i] = &values[i]
}
// 获取结果
_ = rows.Scan(scanArgs...)
// 定义预设的绑定对象
//fmt.Println(reflect.TypeOf(s.GetBindResult()).Kind())
var bindResultTmp = reflect.MakeMap(reflect.Indirect(reflect.ValueOf(s.GetBindResult())).Type())
//// 定义union操作的map返回
//var unionTmp = map[string]interface{}{}
for i, col := range columns {
var v interface{}
val := values[i]
if b, ok := val.([]byte); ok {
v = string(b)
} else {
v = val
}
// 如果是union操作就不需要绑定数据直接返回, 否则就绑定数据
//TODO 这里可能有点问题, 比如在group时, 返回的结果不止一条, 这里直接返回的就是第一条
// 默认其实只是取了第一条, 满足常规的 union 操作(count,sum,max,min,avg)而已
// 后边需要再行完善, 以便group时使用
// 具体完善方法: 就是这里断点去掉, 不直接绑定union, 新增一个map,将结果放在map中,在方法最后统一返回
if s.GetUnion() != nil {
s.union = v
return
// 以下上通用解决方法
//unionTmp[col] = v
//s.union = unionTmp
} else {
br := reflect.Indirect(reflect.ValueOf(s.GetBindResult()))
switch s.GetBindType() {
case OBJECT_MAP_T, OBJECT_MAP_SLICE_T: // t.T类型
// 绑定到一条数据结果对象上,方便其他地方的调用,永远存储最新一条
br.SetMapIndex(reflect.ValueOf(col), reflect.ValueOf(t.New(v)))
// 跟上一行干的事是一样的, 只不过防止上一行的数据被后续的数据改变, 而无法提供给下边多条数据报错的需要
if s.GetBindType() == OBJECT_MAP_SLICE || s.GetBindType() == OBJECT_MAP_SLICE_T {
bindResultTmp.SetMapIndex(reflect.ValueOf(col), reflect.ValueOf(t.New(v)))
}
default: // 普通类型map[string]interface{}, 具体代码注释参照 上一个 case
br.SetMapIndex(reflect.ValueOf(col), reflect.ValueOf(v))
if s.GetBindType() == OBJECT_MAP_SLICE || s.GetBindType() == OBJECT_MAP_SLICE_T {
bindResultTmp.SetMapIndex(reflect.ValueOf(col), reflect.ValueOf(v))
}
}
}
}
// 如果是union操作就不需要绑定数据直接返回, 否则就绑定数据
if s.GetUnion() == nil {
// 如果是多条数据集, 就插入到对应的结果集slice上
if s.GetBindType() == OBJECT_MAP_SLICE || s.GetBindType() == OBJECT_MAP_SLICE_T {
s.GetBindResultSlice().Set(reflect.Append(s.GetBindResultSlice(), bindResultTmp))
}
}
}
return
}
// ScanAll scans all sql result rows into a slice of structs.
// It reads all rows and closes rows when finished.
// dst should be a pointer to a slice of the appropriate type.
// The new results will be appended to any existing data in dst.
func (s *Session) scanStructAll(rows *sql.Rows) error {
// check if there is data waiting
//if !rows.Next() {
// if err := rows.Err(); err != nil {
// s.GetIEngin().GetLogger().Error(err.Error())
// return err
// }
// return sql.ErrNoRows
//}
var sfs = structForScan(s.GetBindResult())
for rows.Next() {
if s.GetUnion() != nil {
var union interface{}
err := rows.Scan(&union)
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return err
}
s.union = union
return err
}
// scan it
//fmt.Printf("%#v \n",structForScan(s.GetBindResult()))
err := rows.Scan(sfs...)
if err != nil {
s.GetIEngin().GetLogger().Error(err.Error())
return err
}
// 如果是union操作就不需要绑定数据直接返回, 否则就绑定数据
if s.GetUnion() == nil {
// 如果是多条数据集, 就插入到对应的结果集slice上
if s.GetBindType() == OBJECT_STRUCT_SLICE {
// add to the result slice
s.GetBindResultSlice().Set(reflect.Append(s.GetBindResultSlice(),
reflect.Indirect(reflect.ValueOf(s.GetBindResult()))))
}
}
}
return rows.Err()
}
func (s *Session) scanAll(rows *sql.Rows) (err error) {
var columns []string
// 获取查询的所有字段
if columns, err = rows.Columns(); err != nil {
return
}
count := len(columns)
var result = []Data{}
for rows.Next() {
// 定义要绑定的结果集
values := make([]interface{}, count)
scanArgs := make([]interface{}, count)
for i := 0; i < count; i++ {
scanArgs[i] = &values[i]
}
// 获取结果
_ = rows.Scan(scanArgs...)
// 定义预设的绑定对象
var resultTmp = Data{}
//// 定义union操作的map返回
//var unionTmp = map[string]interface{}{}
for i, col := range columns {
var v interface{}
val := values[i]
if b, ok := val.([]byte); ok {
v = string(b)
} else {
v = val
}
if s.GetUnion() != nil {
s.union = v
return
// 以下上通用解决方法
//unionTmp[col] = v
//s.union = unionTmp
}
resultTmp[col] = v
}
result = append(result, resultTmp)
}
s.IBinder.SetBindAll(result)
return
}
// SetUnion ...
func (s *Session) SetUnion(u interface{}) {
s.union = u
}
// GetUnion ...
func (s *Session) GetUnion() interface{} {
return s.union
}
// SetTransaction ...
func (s *Session) SetTransaction(b bool) {
s.transaction = b
}
// GetTransaction 提供给 orm 使用的, 方便reset操作
func (s *Session) GetTransaction() bool {
return s.transaction
}