Skip to content

Commit

Permalink
importinto/lightning: fix negative value cast to upper bound in non s…
Browse files Browse the repository at this point in the history
…trict sql mode (#58641)

close #58613
  • Loading branch information
D3Hunter authored Jan 6, 2025
1 parent cf208cf commit 4f78f12
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 12 deletions.
1 change: 1 addition & 0 deletions pkg/lightning/backend/kv/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ go_library(
"//pkg/table/tblctx",
"//pkg/tablecodec",
"//pkg/types",
"//pkg/util",
"//pkg/util/chunk",
"//pkg/util/codec",
"//pkg/util/context",
Expand Down
7 changes: 2 additions & 5 deletions pkg/lightning/backend/kv/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/pkg/table"
"github.com/pingcap/tidb/pkg/table/tblctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/timeutil"
Expand All @@ -48,11 +49,7 @@ type litExprContext struct {

// NewExpressionContext creates a new `*ExprContext` for lightning import.
func newLitExprContext(sqlMode mysql.SQLMode, sysVars map[string]string, timestamp int64) (*litExprContext, error) {
flags := types.DefaultStmtFlags.
WithTruncateAsWarning(!sqlMode.HasStrictMode()).
WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()).
WithIgnoreZeroInDate(!sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode() ||
!sqlMode.HasNoZeroInDateMode() || !sqlMode.HasNoZeroDateMode())
flags := util.GetTypeFlagsForImportInto(types.DefaultStmtFlags, sqlMode)

errLevels := stmtctx.DefaultStmtErrLevels
errLevels[errctx.ErrGroupTruncate] = errctx.ResolveErrLevel(flags.IgnoreTruncateErr(), flags.TruncateAsWarning())
Expand Down
13 changes: 7 additions & 6 deletions pkg/lightning/backend/kv/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
)

func TestLitExprContext(t *testing.T) {
baseFlags := types.DefaultStmtFlags &^ types.FlagAllowNegativeToUnsigned
cases := []struct {
sqlMode mysql.SQLMode
sysVars map[string]string
Expand All @@ -47,7 +48,7 @@ func TestLitExprContext(t *testing.T) {
{
sqlMode: mysql.ModeNone,
timestamp: 1234567,
checkFlags: types.DefaultStmtFlags | types.FlagTruncateAsWarning | types.FlagIgnoreZeroInDateErr,
checkFlags: baseFlags | types.FlagTruncateAsWarning | types.FlagIgnoreZeroInDateErr,
checkErrLevel: func() errctx.LevelMap {
m := stmtctx.DefaultStmtErrLevels
m[errctx.ErrGroupTruncate] = errctx.LevelWarn
Expand All @@ -68,7 +69,7 @@ func TestLitExprContext(t *testing.T) {
{
sqlMode: mysql.ModeStrictTransTables | mysql.ModeNoZeroDate | mysql.ModeNoZeroInDate |
mysql.ModeErrorForDivisionByZero,
checkFlags: types.DefaultStmtFlags,
checkFlags: baseFlags,
checkErrLevel: func() errctx.LevelMap {
m := stmtctx.DefaultStmtErrLevels
m[errctx.ErrGroupTruncate] = errctx.LevelError
Expand All @@ -80,7 +81,7 @@ func TestLitExprContext(t *testing.T) {
},
{
sqlMode: mysql.ModeNoZeroDate | mysql.ModeNoZeroInDate | mysql.ModeErrorForDivisionByZero,
checkFlags: types.DefaultStmtFlags | types.FlagTruncateAsWarning | types.FlagIgnoreZeroInDateErr,
checkFlags: baseFlags | types.FlagTruncateAsWarning | types.FlagIgnoreZeroInDateErr,
checkErrLevel: func() errctx.LevelMap {
m := stmtctx.DefaultStmtErrLevels
m[errctx.ErrGroupTruncate] = errctx.LevelWarn
Expand All @@ -92,7 +93,7 @@ func TestLitExprContext(t *testing.T) {
},
{
sqlMode: mysql.ModeStrictTransTables | mysql.ModeNoZeroInDate,
checkFlags: types.DefaultStmtFlags | types.FlagIgnoreZeroInDateErr,
checkFlags: baseFlags | types.FlagIgnoreZeroInDateErr,
checkErrLevel: func() errctx.LevelMap {
m := stmtctx.DefaultStmtErrLevels
m[errctx.ErrGroupTruncate] = errctx.LevelError
Expand All @@ -104,7 +105,7 @@ func TestLitExprContext(t *testing.T) {
},
{
sqlMode: mysql.ModeStrictTransTables | mysql.ModeNoZeroDate,
checkFlags: types.DefaultStmtFlags | types.FlagIgnoreZeroInDateErr,
checkFlags: baseFlags | types.FlagIgnoreZeroInDateErr,
checkErrLevel: func() errctx.LevelMap {
m := stmtctx.DefaultStmtErrLevels
m[errctx.ErrGroupTruncate] = errctx.LevelError
Expand All @@ -116,7 +117,7 @@ func TestLitExprContext(t *testing.T) {
},
{
sqlMode: mysql.ModeStrictTransTables | mysql.ModeAllowInvalidDates,
checkFlags: types.DefaultStmtFlags | types.FlagIgnoreZeroInDateErr | types.FlagIgnoreInvalidDateErr,
checkFlags: baseFlags | types.FlagIgnoreZeroInDateErr | types.FlagIgnoreInvalidDateErr,
checkErrLevel: func() errctx.LevelMap {
m := stmtctx.DefaultStmtErrLevels
m[errctx.ErrGroupTruncate] = errctx.LevelError
Expand Down
4 changes: 3 additions & 1 deletion pkg/util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,12 +698,14 @@ func createTLSCertificates(certpath string, keypath string, rsaKeySize int) erro
// GetTypeFlagsForInsert gets the type flags for insert statement.
func GetTypeFlagsForInsert(baseFlags types.Flags, sqlMode mysql.SQLMode, ignoreErr bool) types.Flags {
strictSQLMode := sqlMode.HasStrictMode()
// see comments in ResetContextOfStmt for WithAllowNegativeToUnsigned part.
return baseFlags.
WithTruncateAsWarning(!strictSQLMode || ignoreErr).
WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()).
WithIgnoreZeroInDate(!sqlMode.HasNoZeroInDateMode() ||
!sqlMode.HasNoZeroDateMode() || !strictSQLMode || ignoreErr ||
sqlMode.HasAllowInvalidDatesMode())
sqlMode.HasAllowInvalidDatesMode()).
WithAllowNegativeToUnsigned(false)
}

// GetTypeFlagsForImportInto gets the type flags for import into statement which
Expand Down
9 changes: 9 additions & 0 deletions tests/realtikvtest/importintotest2/from_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,12 @@ func (s *mockGCSSuite) TestImportFromSelectStaleRead() {
s.tk.MustExec("import into dst from " + staleReadSQL)
s.tk.MustQuery("select * from dst").Check(testkit.Rows("1 a", "2 b"))
}

func (s *mockGCSSuite) TestCastNegativeToUnsigned() {
s.prepareAndUseDB("from_select")
s.tk.MustExec("create table dt(id int unsigned)")
s.ErrorContains(s.tk.ExecToErr("import into dt from select -1"), "constant -1 overflows int")
s.tk.MustExec("set sql_mode=''")
s.tk.MustExec("import into dt from select -1")
s.tk.MustQuery("select * from dt").Check(testkit.Rows("0"))
}

0 comments on commit 4f78f12

Please sign in to comment.