diff --git a/cmd/cmd.go b/cmd/cmd.go index 1abb787..b3b6ccc 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -38,6 +38,7 @@ import ( "github.com/cectc/dbpack/pkg/executor" "github.com/cectc/dbpack/pkg/filter" _ "github.com/cectc/dbpack/pkg/filter/audit_log" + _ "github.com/cectc/dbpack/pkg/filter/crypto" _ "github.com/cectc/dbpack/pkg/filter/dt" _ "github.com/cectc/dbpack/pkg/filter/metrics" dbpackHttp "github.com/cectc/dbpack/pkg/http" diff --git a/docker/conf/config_rws.yaml b/docker/conf/config_rws.yaml index a591fc6..89bac2d 100644 --- a/docker/conf/config_rws.yaml +++ b/docker/conf/config_rws.yaml @@ -19,6 +19,8 @@ executors: weight: r0w10 - name: employees-slave weight: r10w0 + filters: + - cryptoFilter data_source_cluster: - name: employees-master @@ -46,6 +48,13 @@ filters: appid: svc lock_retry_interval: 50ms lock_retry_times: 30 + - name: cryptoFilter + kind: CryptoFilter + conf: + column_crypto_list: + - table: departments + columns: ["dept_name"] + aeskey: 123456789abcdefg distributed_transaction: appid: svc diff --git a/docker/conf/config_sdb.yaml b/docker/conf/config_sdb.yaml index 4115dfc..718f57a 100644 --- a/docker/conf/config_sdb.yaml +++ b/docker/conf/config_sdb.yaml @@ -14,6 +14,8 @@ executors: mode: sdb config: data_source_ref: employees + filters: + - cryptoFilter data_source_cluster: - name: employees @@ -50,6 +52,13 @@ filters: # determines if the rotated log files should be compressed using gzip compress: true record_before: true + - name: cryptoFilter + kind: CryptoFilter + conf: + column_crypto_list: + - table: departments + columns: ["dept_name"] + aeskey: 123456789abcdefg distributed_transaction: appid: svc diff --git a/docker/scripts/init.sql b/docker/scripts/init.sql index 54dcbce..4a820c8 100644 --- a/docker/scripts/init.sql +++ b/docker/scripts/init.sql @@ -52,7 +52,7 @@ CREATE TABLE employees ( CREATE TABLE departments ( `id` bigint NOT NULL AUTO_INCREMENT, dept_no CHAR(4) NOT NULL, - dept_name VARCHAR(40) NOT NULL, + dept_name VARCHAR(100) NOT NULL, PRIMARY KEY (`id`), UNIQUE KEY (dept_name) ); diff --git a/go.mod b/go.mod index 66008f7..dc22632 100644 --- a/go.mod +++ b/go.mod @@ -74,7 +74,7 @@ require ( github.com/opentracing/opentracing-go v1.1.0 // indirect github.com/pingcap/failpoint v0.0.0-20210316064728-7acb0f0a3dfd // indirect github.com/pingcap/kvproto v0.0.0-20210806074406-317f69fb54b4 // indirect - github.com/pingcap/parser v0.0.0-20210831085004-b5390aa83f65 // indirect + github.com/pingcap/parser v0.0.0-20210831085004-b5390aa83f65 github.com/pingcap/tipb v0.0.0-20210708040514-0f154bb0dc0f // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect diff --git a/pkg/executor/misc.go b/pkg/executor/misc.go index 7b76121..d209aea 100644 --- a/pkg/executor/misc.go +++ b/pkg/executor/misc.go @@ -17,8 +17,11 @@ package executor import ( + "io" "strings" + "github.com/cectc/dbpack/pkg/mysql" + "github.com/cectc/dbpack/pkg/proto" "github.com/cectc/dbpack/third_party/parser/ast" driver "github.com/cectc/dbpack/third_party/types/parser_driver" ) @@ -43,3 +46,69 @@ func shouldStartTransaction(stmt *ast.SetStmt) (shouldStartTransaction bool) { } return } + +func decodeTextResult(result proto.Result) (proto.Result, error) { + if result != nil { + if mysqlResult, ok := result.(*mysql.Result); ok { + if mysqlResult.Rows != nil { + var rows []proto.Row + for { + row, err := mysqlResult.Rows.Next() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + textRow := &mysql.TextRow{Row: row} + _, err = textRow.Decode() + if err != nil { + return nil, err + } + rows = append(rows, textRow) + } + decodedRow := &mysql.DecodedResult{ + Fields: mysqlResult.Fields, + AffectedRows: mysqlResult.AffectedRows, + InsertId: mysqlResult.InsertId, + Rows: rows, + } + return decodedRow, nil + } + } + } + return result, nil +} + +func decodeBinaryResult(result proto.Result) (proto.Result, error) { + if result != nil { + if mysqlResult, ok := result.(*mysql.Result); ok { + if mysqlResult.Rows != nil { + var rows []proto.Row + for { + row, err := mysqlResult.Rows.Next() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + binaryRow := &mysql.BinaryRow{Row: row} + _, err = binaryRow.Decode() + if err != nil { + return nil, err + } + rows = append(rows, binaryRow) + } + decodedRow := &mysql.DecodedResult{ + Fields: mysqlResult.Fields, + AffectedRows: mysqlResult.AffectedRows, + InsertId: mysqlResult.InsertId, + Rows: rows, + } + return decodedRow, nil + } + } + } + return result, nil +} diff --git a/pkg/executor/read_write_splitting.go b/pkg/executor/read_write_splitting.go index 297fa80..2f71721 100644 --- a/pkg/executor/read_write_splitting.go +++ b/pkg/executor/read_write_splitting.go @@ -28,16 +28,13 @@ import ( "github.com/cectc/dbpack/pkg/filter" "github.com/cectc/dbpack/pkg/lb" "github.com/cectc/dbpack/pkg/log" + "github.com/cectc/dbpack/pkg/misc" "github.com/cectc/dbpack/pkg/mysql" "github.com/cectc/dbpack/pkg/proto" "github.com/cectc/dbpack/pkg/resource" "github.com/cectc/dbpack/pkg/tracing" "github.com/cectc/dbpack/third_party/parser/ast" - "github.com/cectc/dbpack/third_party/parser/model" -) - -const ( - hintUseDB = "UseDB" + "github.com/cectc/dbpack/third_party/parser/format" ) type ReadWriteSplittingExecutor struct { @@ -152,19 +149,44 @@ func (executor *ReadWriteSplittingExecutor) ExecuteFieldList(ctx context.Context return nil, errors.New("unimplemented COM_FIELD_LIST in read write splitting mode") } -func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(ctx context.Context, sql string) (proto.Result, uint16, error) { - var ( - db *DataSourceBrief - tx proto.Tx - result proto.Result - err error - ) - +func (executor *ReadWriteSplittingExecutor) ExecutorComQuery( + ctx context.Context, _ string) (result proto.Result, warns uint16, err error) { spanCtx, span := tracing.GetTraceSpan(ctx, tracing.RWSComQuery) defer span.End() + if err = executor.doPreFilter(spanCtx); err != nil { + return nil, 0, err + } + defer func() { + if err == nil { + result, err = decodeTextResult(result) + if err != nil { + span.RecordError(err) + return + } + err = executor.doPostFilter(spanCtx, result) + } else { + span.RecordError(err) + } + }() + + var ( + db *DataSourceBrief + tx proto.Tx + sb strings.Builder + ) + connectionID := proto.ConnectionID(spanCtx) queryStmt := proto.QueryStmt(spanCtx) + if err := queryStmt.Restore(format.NewRestoreCtx(format.RestoreStringSingleQuotes| + format.RestoreKeyWordUppercase| + format.RestoreStringWithoutDefaultCharset, &sb)); err != nil { + return nil, 0, err + } + sql := sb.String() + spanCtx = proto.WithSqlText(spanCtx, sql) + + log.Debugf("connectionID: %d, query: %s", connectionID, sql) switch stmt := queryStmt.(type) { case *ast.SetStmt: if shouldStartTransaction(stmt) { @@ -246,7 +268,7 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(ctx context.Context return tx.Query(spanCtx, sql) } withSlaveCtx := proto.WithSlave(spanCtx) - if has, dsName := hasUseDBHint(stmt.TableHints); has { + if has, dsName := misc.HasUseDBHint(stmt.TableHints); has { protoDB := resource.GetDBManager().GetDB(dsName) if protoDB == nil { log.Debugf("data source %d not found", dsName) @@ -271,11 +293,29 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(ctx context.Context } } -func (executor *ReadWriteSplittingExecutor) ExecutorComStmtExecute(ctx context.Context, stmt *proto.Stmt) (proto.Result, uint16, error) { +func (executor *ReadWriteSplittingExecutor) ExecutorComStmtExecute( + ctx context.Context, stmt *proto.Stmt) (result proto.Result, warns uint16, err error) { spanCtx, span := tracing.GetTraceSpan(ctx, tracing.RWSComStmtExecute) defer span.End() + if err = executor.doPreFilter(spanCtx); err != nil { + return nil, 0, err + } + defer func() { + if err == nil { + result, err = decodeBinaryResult(result) + if err != nil { + span.RecordError(err) + return + } + err = executor.doPostFilter(spanCtx, result) + } else { + span.RecordError(err) + } + }() + connectionID := proto.ConnectionID(spanCtx) + log.Debugf("connectionID: %d, prepare: %s", connectionID, stmt.SqlText) txi, ok := executor.localTransactionMap.Load(connectionID) if ok { // in local transaction @@ -288,7 +328,7 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComStmtExecute(ctx context.C return db.DB.ExecuteStmt(proto.WithMaster(spanCtx), stmt) case *ast.SelectStmt: var db *DataSourceBrief - if has, dsName := hasUseDBHint(st.TableHints); has { + if has, dsName := misc.HasUseDBHint(st.TableHints); has { protoDB := resource.GetDBManager().GetDB(dsName) if protoDB == nil { log.Debugf("data source %d not found", dsName) @@ -338,14 +378,3 @@ func (executor *ReadWriteSplittingExecutor) doPostFilter(ctx context.Context, re } return nil } - -func hasUseDBHint(hints []*ast.TableOptimizerHint) (bool, string) { - for _, hint := range hints { - if strings.EqualFold(hint.HintName.String(), hintUseDB) { - hintData := hint.HintData.(model.CIStr) - ds := hintData.String() - return true, ds - } - } - return false, "" -} diff --git a/pkg/executor/sharding.go b/pkg/executor/sharding.go index 7003edf..dd9df50 100644 --- a/pkg/executor/sharding.go +++ b/pkg/executor/sharding.go @@ -240,11 +240,20 @@ func (executor *ShardingExecutor) ExecutorComQuery(ctx context.Context, sql stri return plan.Execute(spanCtx) } -func (executor *ShardingExecutor) ExecutorComStmtExecute(ctx context.Context, stmt *proto.Stmt) (proto.Result, uint16, error) { +func (executor *ShardingExecutor) ExecutorComStmtExecute( + ctx context.Context, stmt *proto.Stmt) (result proto.Result, warns uint16, err error) { + if err = executor.doPreFilter(ctx); err != nil { + return nil, 0, err + } + defer func() { + if err == nil { + err = executor.doPostFilter(ctx, result) + } + }() + var ( args []interface{} plan proto.Plan - err error ) spanCtx, span := tracing.GetTraceSpan(ctx, tracing.SHDComStmtExecute) diff --git a/pkg/executor/single_db.go b/pkg/executor/single_db.go index 53d149c..12c9533 100644 --- a/pkg/executor/single_db.go +++ b/pkg/executor/single_db.go @@ -19,6 +19,7 @@ package executor import ( "context" "encoding/json" + "strings" "sync" "github.com/pkg/errors" @@ -30,6 +31,7 @@ import ( "github.com/cectc/dbpack/pkg/resource" "github.com/cectc/dbpack/pkg/tracing" "github.com/cectc/dbpack/third_party/parser/ast" + "github.com/cectc/dbpack/third_party/parser/format" ) type SingleDBExecutor struct { @@ -121,21 +123,46 @@ func (executor *SingleDBExecutor) ExecuteFieldList(ctx context.Context, table, w return db.ExecuteFieldList(ctx, table, wildcard) } -func (executor *SingleDBExecutor) ExecutorComQuery(ctx context.Context, sql string) (proto.Result, uint16, error) { - var ( - db proto.DB - tx proto.Tx - result proto.Result - err error - ) +func (executor *SingleDBExecutor) ExecutorComQuery( + ctx context.Context, _ string) (result proto.Result, warns uint16, err error) { spanCtx, span := tracing.GetTraceSpan(ctx, tracing.SDBComQuery) defer span.End() + if err = executor.doPreFilter(spanCtx); err != nil { + return nil, 0, err + } + defer func() { + if err == nil { + result, err = decodeTextResult(result) + if err != nil { + span.RecordError(err) + return + } + err = executor.doPostFilter(spanCtx, result) + } else { + span.RecordError(err) + } + }() + + var ( + db proto.DB + tx proto.Tx + sb strings.Builder + ) + connectionID := proto.ConnectionID(spanCtx) queryStmt := proto.QueryStmt(spanCtx) if queryStmt == nil { return nil, 0, errors.New("query stmt should not be nil") } + if err := queryStmt.Restore(format.NewRestoreCtx(format.RestoreStringSingleQuotes| + format.RestoreKeyWordUppercase| + format.RestoreStringWithoutDefaultCharset, &sb)); err != nil { + return nil, 0, err + } + sql := sb.String() + spanCtx = proto.WithSqlText(spanCtx, sql) + log.Debugf("connectionID: %d, query: %s", connectionID, sql) db = resource.GetDBManager().GetDB(executor.dataSource) switch stmt := queryStmt.(type) { @@ -198,11 +225,28 @@ func (executor *SingleDBExecutor) ExecutorComQuery(ctx context.Context, sql stri } } -func (executor *SingleDBExecutor) ExecutorComStmtExecute(ctx context.Context, stmt *proto.Stmt) (proto.Result, uint16, error) { +func (executor *SingleDBExecutor) ExecutorComStmtExecute( + ctx context.Context, stmt *proto.Stmt) (result proto.Result, warns uint16, err error) { spanCtx, span := tracing.GetTraceSpan(ctx, tracing.SDBComStmtExecute) defer span.End() - connectionID := proto.ConnectionID(spanCtx) + if err = executor.doPreFilter(spanCtx); err != nil { + return nil, 0, err + } + defer func() { + if err == nil { + result, err = decodeBinaryResult(result) + if err != nil { + span.RecordError(err) + return + } + err = executor.doPostFilter(spanCtx, result) + } else { + span.RecordError(err) + } + }() + + connectionID := proto.ConnectionID(ctx) log.Debugf("connectionID: %d, prepare: %s", connectionID, stmt.SqlText) txi, ok := executor.localTransactionMap.Load(connectionID) if ok { diff --git a/pkg/filter/crypto/filter.go b/pkg/filter/crypto/filter.go new file mode 100644 index 0000000..fb2ea1a --- /dev/null +++ b/pkg/filter/crypto/filter.go @@ -0,0 +1,407 @@ +/* + * Copyright 2022 CECTC, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package crypto + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + "strings" + + "github.com/pingcap/errors" + + "github.com/cectc/dbpack/pkg/constant" + "github.com/cectc/dbpack/pkg/filter" + "github.com/cectc/dbpack/pkg/log" + "github.com/cectc/dbpack/pkg/misc" + "github.com/cectc/dbpack/pkg/mysql" + "github.com/cectc/dbpack/pkg/proto" + "github.com/cectc/dbpack/third_party/parser/ast" + "github.com/cectc/dbpack/third_party/parser/format" + driver "github.com/cectc/dbpack/third_party/types/parser_driver" +) + +const ( + cryptoFilter = "CryptoFilter" + aesIV = "awesome789dbpack" +) + +type _factory struct{} + +func (factory *_factory) NewFilter(config map[string]interface{}) (proto.Filter, error) { + var ( + err error + content []byte + ) + if content, err = json.Marshal(config); err != nil { + return nil, errors.Wrap(err, "marshal crypto filter config failed.") + } + v := &struct { + ColumnCryptoList []*ColumnCrypto `yaml:"column_crypto_list" json:"column_crypto_list"` + }{} + if err = json.Unmarshal(content, &v); err != nil { + log.Errorf("unmarshal crypto filter failed, %s", err) + return nil, err + } + + return &_filter{ColumnConfigs: v.ColumnCryptoList}, nil +} + +type _filter struct { + ColumnConfigs []*ColumnCrypto +} + +type ColumnCrypto struct { + Table string + Columns []string + AesKey string +} + +type columnIndex struct { + Column string + Index int +} + +func (f *_filter) GetKind() string { + return cryptoFilter +} + +func (f *_filter) PreHandle(ctx context.Context) error { + commandType := proto.CommandType(ctx) + switch commandType { + case constant.ComQuery: + stmt := proto.QueryStmt(ctx) + switch stmtNode := stmt.(type) { + case *ast.InsertStmt: + config, err := f.checkInsertTable(stmtNode) + if err != nil { + return err + } + if config != nil { + columns, err := retrieveNeedEncryptionInsertColumns(stmtNode, config) + if err != nil { + return err + } + if len(columns) != 0 { + return encryptInsertValues(columns, config, stmtNode.Lists) + } + } + case *ast.UpdateStmt: + config, err := f.checkUpdateTable(stmtNode) + if err != nil { + return err + } + if config != nil { + return encryptUpdateValues(stmtNode, config) + } + default: + return nil + } + case constant.ComStmtExecute: + stmt := proto.PrepareStmt(ctx) + if stmt == nil { + return errors.New("prepare stmt should not be nil") + } + switch stmtNode := stmt.StmtNode.(type) { + case *ast.InsertStmt: + config, err := f.checkInsertTable(stmtNode) + if err != nil { + return err + } + if config != nil { + columns, err := retrieveNeedEncryptionInsertColumns(stmtNode, config) + if err != nil { + return err + } + if len(columns) != 0 { + return encryptBindVars(columns, config, &stmt.BindVars) + } + } + case *ast.UpdateStmt: + config, err := f.checkUpdateTable(stmtNode) + if err != nil { + return err + } + if config != nil { + columns, err := retrieveNeedEncryptionUpdateColumns(stmtNode, config) + if err != nil { + return err + } + if len(columns) != 0 { + return encryptBindVars(columns, config, &stmt.BindVars) + } + } + default: + return nil + } + } + return nil +} + +func (f *_filter) PostHandle(ctx context.Context, result proto.Result) error { + commandType := proto.CommandType(ctx) + switch commandType { + case constant.ComQuery: + stmt := proto.QueryStmt(ctx) + if stmtNode, ok := stmt.(*ast.SelectStmt); ok { + if decodedResult, is := result.(*mysql.DecodedResult); is && len(decodedResult.Rows) > 0 { + config, err := f.checkSelectTable(stmtNode) + if err != nil { + log.Error(err) + return nil + } + if config != nil { + columns, err := retrieveNeedDecryptionSelectColumns(decodedResult, config) + if err != nil { + log.Error(err) + return nil + } + if len(columns) != 0 { + decryptDecodedResult(decodedResult, config, columns) + } + } + } + } + case constant.ComStmtExecute: + stmt := proto.PrepareStmt(ctx) + if stmt == nil { + return errors.New("prepare stmt should not be nil") + } + if stmtNode, ok := stmt.StmtNode.(*ast.SelectStmt); ok { + if decodedResult, is := result.(*mysql.DecodedResult); is && len(decodedResult.Rows) > 0 { + config, err := f.checkSelectTable(stmtNode) + if err != nil { + log.Error(err) + return nil + } + if config != nil { + columns, err := retrieveNeedDecryptionSelectColumns(decodedResult, config) + if err != nil { + log.Error(err) + return nil + } + if len(columns) != 0 { + decryptDecodedResult(decodedResult, config, columns) + } + } + } + } + } + return nil +} + +func (f _filter) checkInsertTable(insertStmt *ast.InsertStmt) (*ColumnCrypto, error) { + var sb strings.Builder + if err := insertStmt.Table.TableRefs.Left.Restore( + format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreKeyWordUppercase, &sb)); err != nil { + return nil, err + } + tableName := sb.String() + for _, config := range f.ColumnConfigs { + if strings.EqualFold(config.Table, tableName) { + return config, nil + } + } + return nil, nil +} + +func (f _filter) checkUpdateTable(updateStmt *ast.UpdateStmt) (*ColumnCrypto, error) { + var sb strings.Builder + if err := updateStmt.TableRefs.TableRefs.Left.Restore( + format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreKeyWordUppercase, &sb)); err != nil { + return nil, err + } + tableName := sb.String() + for _, config := range f.ColumnConfigs { + if strings.EqualFold(config.Table, tableName) { + return config, nil + } + } + return nil, nil +} + +func (f _filter) checkSelectTable(selectStmt *ast.SelectStmt) (*ColumnCrypto, error) { + var sb strings.Builder + if err := selectStmt.From.TableRefs.Left.Restore( + format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreKeyWordUppercase, &sb)); err != nil { + return nil, err + } + tableName := sb.String() + for _, config := range f.ColumnConfigs { + if strings.EqualFold(config.Table, tableName) { + return config, nil + } + } + return nil, nil +} + +func retrieveNeedEncryptionInsertColumns(insertStmt *ast.InsertStmt, config *ColumnCrypto) ([]*columnIndex, error) { + if insertStmt.Columns == nil { + return nil, errors.New("The column to be inserted must be specified") + } + var result []*columnIndex + for i, column := range insertStmt.Columns { + if contains(config.Columns, column.Name.O) { + result = append(result, &columnIndex{ + Column: column.Name.O, + Index: i, + }) + } + } + return result, nil +} + +func retrieveNeedEncryptionUpdateColumns(updateStmt *ast.UpdateStmt, config *ColumnCrypto) ([]*columnIndex, error) { + var result []*columnIndex + for i, column := range updateStmt.List { + columnName := column.Column.Name.O + if contains(config.Columns, columnName) { + result = append(result, &columnIndex{ + Column: columnName, + Index: i, + }) + } + } + return result, nil +} + +func retrieveNeedDecryptionSelectColumns(decodedResult *mysql.DecodedResult, config *ColumnCrypto) ([]*columnIndex, error) { + var result []*columnIndex + for i, column := range decodedResult.Fields { + if column.Name != "" && contains(config.Columns, column.Name) { + result = append(result, &columnIndex{ + Column: column.Name, + Index: i, + }) + } + } + return result, nil +} + +// encryptInsertValues for com_query +func encryptInsertValues(columns []*columnIndex, config *ColumnCrypto, valueList [][]ast.ExprNode) error { + for _, values := range valueList { + for _, column := range columns { + arg := values[column.Index] + if param, ok := arg.(*driver.ValueExpr); ok { + value := param.GetBytes() + if len(value) != 0 { + encoded, err := misc.AesEncryptCBC(value, []byte(config.AesKey), []byte(aesIV)) + if err != nil { + return errors.Wrapf(err, "Encryption of %s failed", column.Column) + } + val := hex.EncodeToString(encoded) + param.SetBytes([]byte(val)) + } + } + } + } + return nil +} + +// encryptUpdateValues for com_query +func encryptUpdateValues(updateStmt *ast.UpdateStmt, config *ColumnCrypto) error { + for _, column := range updateStmt.List { + columnName := column.Column.Name.O + if contains(config.Columns, columnName) { + arg := column.Expr + if param, ok := arg.(*driver.ValueExpr); ok { + value := param.GetBytes() + if len(value) != 0 { + encoded, err := misc.AesEncryptCBC(value, []byte(config.AesKey), []byte(aesIV)) + if err != nil { + return errors.Wrapf(err, "Encryption of %s failed", column.Column) + } + val := hex.EncodeToString(encoded) + param.SetBytes([]byte(val)) + } + } + } + } + return nil +} + +// encryptBindVars for com_stmt_execute +func encryptBindVars(columns []*columnIndex, config *ColumnCrypto, args *map[string]interface{}) error { + for _, column := range columns { + parameterID := fmt.Sprintf("v%d", column.Index+1) + param := (*args)[parameterID] + if arg, ok := param.(string); ok { + encoded, err := misc.AesEncryptCBC([]byte(arg), []byte(config.AesKey), []byte(aesIV)) + if err != nil { + return errors.Errorf("Encryption of %s failed: %v", column.Column, err) + } + val := hex.EncodeToString(encoded) + (*args)[parameterID] = val + } else if arg, ok := param.([]byte); ok { + encoded, err := misc.AesEncryptCBC(arg, []byte(config.AesKey), []byte(aesIV)) + if err != nil { + return errors.Errorf("Encryption of %s failed: %v", column.Column, err) + } + val := hex.EncodeToString(encoded) + (*args)[parameterID] = []byte(val) + } + } + return nil +} + +func decryptDecodedResult(decodedResult *mysql.DecodedResult, config *ColumnCrypto, columns []*columnIndex) { + for _, row := range decodedResult.Rows { + switch r := row.(type) { + case *mysql.TextRow: + for _, column := range columns { + protoValue := r.Values[column.Index] + if protoValue != nil { + if originalVal, ok := protoValue.Val.([]byte); ok { + if n, err := hex.Decode(originalVal, originalVal); err == nil { + if decodedVal, err := misc.AesDecryptCBC(originalVal[:n], []byte(config.AesKey), []byte(aesIV)); err == nil { + r.Values[column.Index].Val = decodedVal + } + } + } + } + } + case *mysql.BinaryRow: + for _, column := range columns { + protoValue := r.Values[column.Index] + if protoValue != nil { + if originalVal, ok := protoValue.Val.([]byte); ok { + if n, err := hex.Decode(originalVal, originalVal); err == nil { + if decodedVal, err := misc.AesDecryptCBC(originalVal[:n], []byte(config.AesKey), []byte(aesIV)); err == nil { + r.Values[column.Index].Val = decodedVal + } + } + } + } + } + } + } +} + +func contains(s []string, str string) bool { + for _, v := range s { + if strings.EqualFold(v, str) { + return true + } + } + return false +} + +func init() { + filter.RegistryFilterFactory(cryptoFilter, &_factory{}) +} diff --git a/pkg/filter/dt/transaction.go b/pkg/filter/dt/filter_http_transaction.go similarity index 100% rename from pkg/filter/dt/transaction.go rename to pkg/filter/dt/filter_http_transaction.go diff --git a/pkg/listener/mysql.go b/pkg/listener/mysql.go index 75f19c7..7379c2a 100644 --- a/pkg/listener/mysql.go +++ b/pkg/listener/mysql.go @@ -597,7 +597,7 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data return err } } - if rlt, ok := result.(*mysql.MergeResult); ok { + if rlt, ok := result.(*mysql.DecodedResult); ok { if len(rlt.Fields) == 0 { // A successful callback with no fields means that this was a // DML or other write-only operation. @@ -772,7 +772,7 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data return err } } - if rlt, ok := result.(*mysql.MergeResult); ok { + if rlt, ok := result.(*mysql.DecodedResult); ok { if len(rlt.Fields) == 0 { // A successful callback with no fields means that this was a // DML or other write-only operation. diff --git a/pkg/misc/crypto.go b/pkg/misc/crypto.go new file mode 100644 index 0000000..7708a4c --- /dev/null +++ b/pkg/misc/crypto.go @@ -0,0 +1,153 @@ +/* + * Copyright 2022 CECTC, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package misc + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "io" + + "github.com/pkg/errors" +) + +func AesEncryptCBC(origData []byte, key []byte, iv []byte) (encrypted []byte, err error) { + var ( + block cipher.Block + blockMode cipher.BlockMode + ) + block, err = aes.NewCipher(key) + if err != nil { + return nil, err + } + origData = pkcs7Padding(origData) + blockMode = cipher.NewCBCEncrypter(block, iv) + encrypted = make([]byte, len(origData)) + blockMode.CryptBlocks(encrypted, origData) + return encrypted, err +} + +func AesDecryptCBC(encrypted []byte, key []byte, iv []byte) (decrypted []byte, err error) { + var ( + block cipher.Block + blockMode cipher.BlockMode + ) + block, err = aes.NewCipher(key) + if err != nil { + return nil, err + } + blockMode = cipher.NewCBCDecrypter(block, iv) + decrypted = make([]byte, len(encrypted)) + blockMode.CryptBlocks(decrypted, encrypted) + decrypted = pkcs7UnPadding(decrypted) + return decrypted, err +} + +func pkcs7Padding(ciphertext []byte) []byte { + padding := aes.BlockSize - len(ciphertext)%aes.BlockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padText...) +} + +func pkcs7UnPadding(plantText []byte) []byte { + length := len(plantText) + unPadding := int(plantText[length-1]) + return plantText[:(length - unPadding)] +} + +func AesEncryptECB(origData []byte, key []byte) (encrypted []byte, err error) { + var block cipher.Block + block, err = aes.NewCipher(generateKey(key)) + if err != nil { + return nil, err + } + length := (len(origData) + aes.BlockSize) / aes.BlockSize + plain := make([]byte, length*aes.BlockSize) + copy(plain, origData) + pad := byte(len(plain) - len(origData)) + for i := len(origData); i < len(plain); i++ { + plain[i] = pad + } + encrypted = make([]byte, len(plain)) + for bs, be := 0, block.BlockSize(); bs <= len(origData); bs, be = bs+block.BlockSize(), be+block.BlockSize() { + block.Encrypt(encrypted[bs:be], plain[bs:be]) + } + return encrypted, err +} + +func AesDecryptECB(encrypted []byte, key []byte) (decrypted []byte, err error) { + var block cipher.Block + block, err = aes.NewCipher(generateKey(key)) + if err != nil { + return nil, err + } + decrypted = make([]byte, len(encrypted)) + for bs, be := 0, block.BlockSize(); bs < len(encrypted); bs, be = bs+block.BlockSize(), be+block.BlockSize() { + block.Decrypt(decrypted[bs:be], encrypted[bs:be]) + } + + trim := 0 + if len(decrypted) > 0 { + trim = len(decrypted) - int(decrypted[len(decrypted)-1]) + } + return decrypted[:trim], err +} + +func generateKey(key []byte) (genKey []byte) { + genKey = make([]byte, 16) + copy(genKey, key) + for i := 16; i < len(key); { + for j := 0; j < 16 && i < len(key); j, i = j+1, i+1 { + genKey[j] ^= key[i] + } + } + return genKey +} + +func AesEncryptCFB(origData []byte, key []byte) (encrypted []byte, err error) { + var block cipher.Block + block, err = aes.NewCipher(key) + if err != nil { + return nil, err + } + encrypted = make([]byte, aes.BlockSize+len(origData)) + iv := encrypted[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(encrypted[aes.BlockSize:], origData) + return encrypted, err +} + +func AesDecryptCFB(encrypted []byte, key []byte) (decrypted []byte, err error) { + var block cipher.Block + block, err = aes.NewCipher(key) + if err != nil { + return nil, err + } + if len(encrypted) < aes.BlockSize { + return nil, errors.New("ciphertext too short") + } + iv := encrypted[:aes.BlockSize] + encrypted = encrypted[aes.BlockSize:] + + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(encrypted, encrypted) + return encrypted, err +} diff --git a/pkg/misc/hint.go b/pkg/misc/hint.go index dd96a60..78afc54 100644 --- a/pkg/misc/hint.go +++ b/pkg/misc/hint.go @@ -26,6 +26,7 @@ import ( const ( XIDHint = "XID" GlobalLockHint = "GlobalLock" + UseDBHint = "UseDB" TraceParentHint = "TraceParent" ) @@ -49,6 +50,17 @@ func HasGlobalLockHint(hints []*ast.TableOptimizerHint) bool { return false } +func HasUseDBHint(hints []*ast.TableOptimizerHint) (bool, string) { + for _, hint := range hints { + if strings.EqualFold(hint.HintName.String(), UseDBHint) { + hintData := hint.HintData.(model.CIStr) + ds := hintData.String() + return true, ds + } + } + return false, "" +} + func HasTraceParentHint(hints []*ast.TableOptimizerHint) (bool, string) { for _, hint := range hints { if strings.EqualFold(hint.HintName.String(), TraceParentHint) { diff --git a/pkg/mysql/conn.go b/pkg/mysql/conn.go index 8d74249..d7ef61d 100644 --- a/pkg/mysql/conn.go +++ b/pkg/mysql/conn.go @@ -615,8 +615,14 @@ func (c *Conn) writeTextRow(row []*proto.Value) error { if val == nil || val.Val == nil { length++ } else { - l := len(val.Raw) - length += misc.LenEncIntSize(uint64(l)) + l + value, ok := val.Val.([]byte) + if ok { + l := len(value) + length += misc.LenEncIntSize(uint64(l)) + l + } else { + l := len(val.Raw) + length += misc.LenEncIntSize(uint64(l)) + l + } } } @@ -626,9 +632,16 @@ func (c *Conn) writeTextRow(row []*proto.Value) error { if val == nil || val.Val == nil { pos = misc.WriteByte(data, pos, constant.NullValue) } else { - l := len(val.Raw) - pos = misc.WriteLenEncInt(data, pos, uint64(l)) - pos += copy(data[pos:], val.Raw) + value, ok := val.Val.([]byte) + if ok { + l := len(value) + pos = misc.WriteLenEncInt(data, pos, uint64(l)) + pos += copy(data[pos:], value) + } else { + l := len(val.Raw) + pos = misc.WriteLenEncInt(data, pos, uint64(l)) + pos += copy(data[pos:], val.Raw) + } } } @@ -812,7 +825,7 @@ func (c *Conn) WriteBinaryRows(result *Result) error { return nil } -func (c *Conn) WriteRows(result *MergeResult) error { +func (c *Conn) WriteRows(result *DecodedResult) error { for _, row := range result.Rows { switch r := row.(type) { case *TextRow: diff --git a/pkg/mysql/result.go b/pkg/mysql/result.go index da29547..0d08d51 100644 --- a/pkg/mysql/result.go +++ b/pkg/mysql/result.go @@ -33,17 +33,17 @@ func (res *Result) RowsAffected() (uint64, error) { return res.AffectedRows, nil } -type MergeResult struct { +type DecodedResult struct { Fields []*Field AffectedRows uint64 InsertId uint64 Rows []proto.Row } -func (res *MergeResult) LastInsertId() (uint64, error) { +func (res *DecodedResult) LastInsertId() (uint64, error) { return res.InsertId, nil } -func (res *MergeResult) RowsAffected() (uint64, error) { +func (res *DecodedResult) RowsAffected() (uint64, error) { return res.AffectedRows, nil } diff --git a/pkg/packet/mysql.go b/pkg/packet/mysql.go index c478716..ea2e5f7 100644 --- a/pkg/packet/mysql.go +++ b/pkg/packet/mysql.go @@ -651,11 +651,12 @@ func TextVal2MySQL(v *proto.Value) ([]byte, error) { constant.FieldTypeMediumBLOB, constant.FieldTypeLongBLOB, constant.FieldTypeBLOB, constant.FieldTypeVarString, constant.FieldTypeString, constant.FieldTypeGeometry, constant.FieldTypeJSON, constant.FieldTypeBit, constant.FieldTypeEnum, constant.FieldTypeSet: - l := len(v.Raw) + val := v.Val.([]byte) + l := len(val) length := misc.LenEncIntSize(uint64(l)) + l out = make([]byte, length) pos = misc.WriteLenEncInt(out, pos, uint64(l)) - copy(out[pos:], v.Raw) + copy(out[pos:], val) default: out = make([]byte, len(v.Raw)) copy(out, v.Raw) @@ -848,7 +849,8 @@ func TextVal2MySQLLen(v *proto.Value) (int, error) { constant.FieldTypeMediumBLOB, constant.FieldTypeLongBLOB, constant.FieldTypeBLOB, constant.FieldTypeVarString, constant.FieldTypeString, constant.FieldTypeGeometry, constant.FieldTypeJSON, constant.FieldTypeBit, constant.FieldTypeEnum, constant.FieldTypeSet: - l := len(v.Raw) + val := v.Val.([]byte) + l := len(val) length = misc.LenEncIntSize(uint64(l)) + l default: length = len(v.Raw) diff --git a/pkg/plan/result.go b/pkg/plan/result.go index 4a4523b..32fa4d6 100644 --- a/pkg/plan/result.go +++ b/pkg/plan/result.go @@ -172,7 +172,7 @@ func (c OrderByCells) Swap(i, j int) { func mergeResult(ctx context.Context, results []*ResultWithErr, orderBy *ast.OrderByClause, - limit *Limit) (*mysql.MergeResult, uint16) { + limit *Limit) (*mysql.DecodedResult, uint16) { if orderBy == nil && limit == nil { return mergeResultWithoutOrderByAndLimit(ctx, results) } @@ -190,7 +190,7 @@ func mergeResult(ctx context.Context, // mergeResultWithOutOrderByAndLimit e.g. select * from t where id between ? and ? func mergeResultWithoutOrderByAndLimit(ctx context.Context, - results []*ResultWithErr) (*mysql.MergeResult, uint16) { + results []*ResultWithErr) (*mysql.DecodedResult, uint16) { var ( fields []*mysql.Field warning uint16 = 0 @@ -234,7 +234,7 @@ func mergeResultWithoutOrderByAndLimit(ctx context.Context, } } fields = results[0].Result.(*mysql.Result).Fields - result := &mysql.MergeResult{ + result := &mysql.DecodedResult{ Fields: fields, AffectedRows: 0, InsertId: 0, @@ -246,7 +246,7 @@ func mergeResultWithoutOrderByAndLimit(ctx context.Context, // mergeResultWithLimit e.g. select * from t where id between ? and ? limit ?,? func mergeResultWithLimit(ctx context.Context, results []*ResultWithErr, - limit *Limit) (*mysql.MergeResult, uint16) { + limit *Limit) (*mysql.DecodedResult, uint16) { var ( fields []*mysql.Field warning uint16 = 0 @@ -303,7 +303,7 @@ func mergeResultWithLimit(ctx context.Context, } } fields = results[0].Result.(*mysql.Result).Fields - result := &mysql.MergeResult{ + result := &mysql.DecodedResult{ Fields: fields, AffectedRows: 0, InsertId: 0, @@ -316,7 +316,7 @@ func mergeResultWithLimit(ctx context.Context, func mergeResultWithOrderByAndLimit(ctx context.Context, results []*ResultWithErr, orderBy *ast.OrderByClause, - limit *Limit) (*mysql.MergeResult, uint16) { + limit *Limit) (*mysql.DecodedResult, uint16) { var ( fields []*mysql.Field orderByFields []*OrderField @@ -407,7 +407,7 @@ func mergeResultWithOrderByAndLimit(ctx context.Context, for _, rlt := range results { warning += rlt.Warning } - result := &mysql.MergeResult{ + result := &mysql.DecodedResult{ Fields: fields, AffectedRows: 0, InsertId: 0, @@ -419,7 +419,7 @@ func mergeResultWithOrderByAndLimit(ctx context.Context, // mergeResultWithOrderBy e.g. select * from t where id between ? and ? order by id desc func mergeResultWithOrderBy(ctx context.Context, results []*ResultWithErr, - orderBy *ast.OrderByClause) (*mysql.MergeResult, uint16) { + orderBy *ast.OrderByClause) (*mysql.DecodedResult, uint16) { var ( fields []*mysql.Field orderByFields []*OrderField @@ -499,7 +499,7 @@ func mergeResultWithOrderBy(ctx context.Context, for _, rlt := range results { warning += rlt.Warning } - result := &mysql.MergeResult{ + result := &mysql.DecodedResult{ Fields: fields, AffectedRows: 0, InsertId: 0, @@ -508,7 +508,7 @@ func mergeResultWithOrderBy(ctx context.Context, return result, warning } -func aggregateResult(ctx context.Context, result *mysql.MergeResult) { +func aggregateResult(ctx context.Context, result *mysql.DecodedResult) { sqlText := proto.SqlText(ctx) funcColumns := proto.Variable(ctx, FuncColumns) if funcColumns == nil { diff --git a/pkg/plan/result_test.go b/pkg/plan/result_test.go index 7effbfd..d5a09d7 100644 --- a/pkg/plan/result_test.go +++ b/pkg/plan/result_test.go @@ -65,7 +65,7 @@ func TestMergeResultWithOutOrderByAndLimit(t *testing.T) { defer patch2.Reset() patch3 := buildBinaryRowDecodePatch() defer patch3.Reset() - merge := func(commandType int) (*mysql.MergeResult, uint16) { + merge := func(commandType int) (*mysql.DecodedResult, uint16) { return mergeResultWithoutOrderByAndLimit(proto.WithCommandType(context.Background(), byte(commandType)), []*ResultWithErr{ { @@ -106,7 +106,7 @@ func TestMergeResultWithLimit(t *testing.T) { defer patch2.Reset() patch3 := buildBinaryRowDecodePatch() defer patch3.Reset() - merge := func(commandType int) (*mysql.MergeResult, uint16) { + merge := func(commandType int) (*mysql.DecodedResult, uint16) { return mergeResultWithLimit(proto.WithCommandType(context.Background(), byte(commandType)), []*ResultWithErr{ { @@ -150,7 +150,7 @@ func TestMergeResultWithOrderByAndLimit(t *testing.T) { defer patch2.Reset() patch3 := buildBinaryRowDecodePatch() defer patch3.Reset() - merge := func(commandType int) (*mysql.MergeResult, uint16) { + merge := func(commandType int) (*mysql.DecodedResult, uint16) { return mergeResultWithOrderByAndLimit(proto.WithCommandType(context.Background(), byte(commandType)), []*ResultWithErr{ { @@ -218,7 +218,7 @@ func TestMergeResultWithOrderBy(t *testing.T) { defer patch2.Reset() patch3 := buildBinaryRowDecodePatch() defer patch3.Reset() - merge := func(commandType int) (*mysql.MergeResult, uint16) { + merge := func(commandType int) (*mysql.DecodedResult, uint16) { return mergeResultWithOrderBy(proto.WithCommandType(context.Background(), byte(commandType)), []*ResultWithErr{ { diff --git a/test/rws/read_write_splitting_test.go b/test/rws/read_write_splitting_test.go index 2907c53..d3103e8 100644 --- a/test/rws/read_write_splitting_test.go +++ b/test/rws/read_write_splitting_test.go @@ -34,8 +34,12 @@ const ( insertEmployee = `INSERT INTO employees ( emp_no, birth_date, first_name, last_name, gender, hire_date ) VALUES (?, ?, ?, ?, ?, ?)` selectEmployee1 = `SELECT emp_no, birth_date, first_name, last_name, gender, hire_date FROM employees WHERE emp_no = ?` selectEmployee2 = `SELECT /*+ UseDB('employees-master') */ emp_no, birth_date, first_name, last_name, gender, hire_date FROM employees WHERE emp_no = ?` - updateEmployee = `UPDATE employees set last_name = ? where emp_no = ?` + updateEmployee = `UPDATE employees SET last_name = ? WHERE emp_no = ?` deleteEmployee = `DELETE FROM employees WHERE emp_no = ?` + + insertDepartment = `INSERT INTO departments( id, dept_no, dept_name ) values (?, ?, ?)` + updateDepartment = `UPDATE departments SET dept_name = ? WHERE id = ?` + selectDepartment = `SELECT /*+ UseDB('employees-master') */ id, dept_name FROM departments WHERE id = ?` ) type _ReadWriteSplittingSuite struct { @@ -139,6 +143,29 @@ func (suite *_ReadWriteSplittingSuite) TestSelect1() { } } +func (suite *_ReadWriteSplittingSuite) TestInsertEncryption() { + result, err := suite.db.Exec(insertDepartment, 1, "1001", "sunset") + if suite.NoErrorf(err, "insert row error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "insert row error: %v", err) { + suite.Equal(int64(1), affected) + } + } + + rows, err := suite.db.Query(selectDepartment, 1) + if suite.NoErrorf(err, "select row error: %v", err) { + var ( + id int64 + deptName string + ) + for rows.Next() { + err := rows.Scan(&id, &deptName) + suite.NoError(err) + suite.T().Logf("id: %d, dept name: %s", id, deptName) + } + } +} + func (suite *_ReadWriteSplittingSuite) TestSelect2() { rows, err := suite.db.Query(selectEmployee2, 100001) if suite.NoErrorf(err, "select row error: %v", err) { @@ -180,5 +207,28 @@ func (suite *_ReadWriteSplittingSuite) TestUpdate() { } } +func (suite *_ReadWriteSplittingSuite) TestUpdateEncryption() { + result, err := suite.db.Exec(updateDepartment, "moonlight", 1) + if suite.NoErrorf(err, "update department error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "update department error: %v", err) { + suite.Equal(int64(1), affected) + } + } + + rows, err := suite.db.Query(selectDepartment, 1) + if suite.NoErrorf(err, "select row error: %v", err) { + var ( + id int64 + deptName string + ) + for rows.Next() { + err := rows.Scan(&id, &deptName) + suite.NoError(err) + suite.T().Logf("id: %d, dept name: %s", id, deptName) + } + } +} + func (suite *_ReadWriteSplittingSuite) TearDownSuite() { } diff --git a/test/sdb/crud_test.go b/test/sdb/crud_test.go index e197456..fa47aa8 100644 --- a/test/sdb/crud_test.go +++ b/test/sdb/crud_test.go @@ -29,11 +29,15 @@ const ( driverName = "mysql" // user:password@tcp(127.0.0.1:3306)/dbName? - dataSourceName = "dksl:123456@tcp(127.0.0.1:13306)/employees?timeout=10s&readTimeout=10s&writeTimeout=10s&parseTime=true&loc=Local&charset=utf8mb4,utf8" + dataSourceName = "dksl:123456@tcp(127.0.0.1:13306)/employees?interpolateParams=true&timeout=10s&readTimeout=10s&writeTimeout=10s&parseTime=true&loc=Local&charset=utf8mb4,utf8" insertEmployee = `INSERT INTO employees ( emp_no, birth_date, first_name, last_name, gender, hire_date ) VALUES (?, ?, ?, ?, ?, ?)` selectEmployee = `SELECT emp_no, birth_date, first_name, last_name, gender, hire_date FROM employees WHERE emp_no = ?` - updateEmployee = `UPDATE employees set last_name = ? where emp_no = ?` + updateEmployee = `UPDATE employees SET last_name = ? WHERE emp_no = ?` deleteEmployee = `DELETE FROM employees WHERE emp_no = ?` + + insertDepartment = `INSERT INTO departments( id, dept_no, dept_name ) values (?, ?, ?)` + updateDepartment = `UPDATE departments SET dept_name = ? WHERE id = ?` + selectDepartment = `SELECT id, dept_name FROM departments WHERE id = ?` ) type _CRUDSuite struct { @@ -80,6 +84,29 @@ func (suite *_CRUDSuite) TestInsert() { } } +func (suite *_CRUDSuite) TestInsertEncryption() { + result, err := suite.db.Exec(insertDepartment, 1, "1001", "sunset") + if suite.NoErrorf(err, "insert row error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "insert row error: %v", err) { + suite.Equal(int64(1), affected) + } + } + + rows, err := suite.db.Query(selectDepartment, 1) + if suite.NoErrorf(err, "select row error: %v", err) { + var ( + id int64 + deptName string + ) + for rows.Next() { + err := rows.Scan(&id, &deptName) + suite.NoError(err) + suite.T().Logf("id: %d, dept name: %s", id, deptName) + } + } +} + func (suite *_CRUDSuite) TestSelect() { rows, err := suite.db.Query(selectEmployee, 100001) if suite.NoErrorf(err, "select row error: %v", err) { @@ -107,6 +134,29 @@ func (suite *_CRUDSuite) TestUpdate() { } } +func (suite *_CRUDSuite) TestUpdateEncryption() { + result, err := suite.db.Exec(updateDepartment, "moonlight", 1) + if suite.NoErrorf(err, "update department error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "update department error: %v", err) { + suite.Equal(int64(1), affected) + } + } + + rows, err := suite.db.Query(selectDepartment, 1) + if suite.NoErrorf(err, "select row error: %v", err) { + var ( + id int64 + deptName string + ) + for rows.Next() { + err := rows.Scan(&id, &deptName) + suite.NoError(err) + suite.T().Logf("id: %d, dept name: %s", id, deptName) + } + } +} + func (suite *_CRUDSuite) TearDownSuite() { result, err := suite.db.Exec(deleteEmployee, 100001) if suite.NoErrorf(err, "delete row error: %v", err) {