From 924330d0a5a23c3f560db730b682d2cfb4fd5bad Mon Sep 17 00:00:00 2001 From: Kirill Date: Sat, 14 Sep 2024 17:46:35 +0300 Subject: [PATCH] Support COPY (#748) Parse COPY command and extract table and column names and options. Send the COPY query to all shards. Receive all CopyData and parse them. Check if they correspond to one shard. Send CopyData to that one shard. Then send CopyDone to all shards and receive responses. --- go.mod | 2 +- go.sum | 4 +- router/frontend/frontend_test.go | 46 ++++++--- router/qrouter/proxy_routing.go | 19 +--- router/qrouter/proxy_routing_test.go | 21 +--- router/relay/relay.go | 95 +++++++++++++++---- router/server/multishard.go | 31 ++++-- .../tests/router/expected/copy_routing.out | 36 +++++-- .../regress/tests/router/sql/copy_routing.sql | 19 ++-- 9 files changed, 181 insertions(+), 92 deletions(-) diff --git a/go.mod b/go.mod index 105e64417..d7405820e 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/lib/pq v1.10.9 github.com/libp2p/go-reuseport v0.4.0 github.com/opentracing/opentracing-go v1.2.0 - github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1 + github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.33.0 github.com/sevlyar/go-daemon v0.1.6 diff --git a/go.sum b/go.sum index 8bc4d86f1..6329f36c4 100644 --- a/go.sum +++ b/go.sum @@ -164,8 +164,8 @@ github.com/opencontainers/image-spec v1.0.2 h1:9yCKha/T5XdGtO0q9Q9a6T5NUCsTn/DrB github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= -github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1 h1:AwlQkwnrqRyL8lqZTTAzfQ09niEc+6oFiDvQkMImTPE= -github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1/go.mod h1:2dPBQAhqv/30mhzj2yBXQkXhsGJQ8GhM+oWOfbGua58= +github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c h1:4sXBG7ZDtG/rN2jqgmzsMawfcTKQvTCTTo8iQ7eR6VU= +github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c/go.mod h1:2dPBQAhqv/30mhzj2yBXQkXhsGJQ8GhM+oWOfbGua58= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/router/frontend/frontend_test.go b/router/frontend/frontend_test.go index cdb8daa1a..4b2de149b 100644 --- a/router/frontend/frontend_test.go +++ b/router/frontend/frontend_test.go @@ -329,6 +329,15 @@ func TestFrontendSimpleCopyIn(t *testing.T) { qr := mockqr.NewMockQueryRouter(ctrl) cmngr := mockcmgr.NewMockPoolMgr(ctrl) + sh1 := mocksh.NewMockShard(ctrl) + sh1.EXPECT().Name().AnyTimes().Return("sh1") + sh1.EXPECT().SHKey().AnyTimes().Return(kr.ShardKey{Name: "sh1"}) + sh1.EXPECT().ID().AnyTimes().Return(uint(1)) + sh2 := mocksh.NewMockShard(ctrl) + sh2.EXPECT().Name().AnyTimes().Return("sh2") + sh2.EXPECT().SHKey().AnyTimes().Return(kr.ShardKey{Name: "sh2"}) + sh2.EXPECT().ID().AnyTimes().Return(uint(2)) + frrule := &config.FrontendRule{ DB: "db1", Usr: "user1", @@ -337,7 +346,7 @@ func TestFrontendSimpleCopyIn(t *testing.T) { beRule := &config.BackendRule{} srv.EXPECT().Name().AnyTimes().Return("serv1") - srv.EXPECT().Datashards().AnyTimes().Return([]shard.Shard{}) + srv.EXPECT().Datashards().AnyTimes().Return([]shard.Shard{sh1, sh2}) cl.EXPECT().Server().AnyTimes().Return(srv) cl.EXPECT().MaintainParams().AnyTimes().Return(false) @@ -375,22 +384,29 @@ func TestFrontendSimpleCopyIn(t *testing.T) { cmngr.EXPECT().TXEndCB(gomock.Any()).AnyTimes() + tableref := &lyx.RangeVar{ + RelationName: "xx", + } + qr.EXPECT().Route(gomock.Any(), &lyx.Copy{ - TableRef: &lyx.RangeVar{ - RelationName: "xx", - }, - Where: &lyx.AExprEmpty{}, - IsFrom: true, - }, gomock.Any()).Return(routingstate.ShardMatchState{ - Route: &routingstate.DataShardRoute{ - Shkey: kr.ShardKey{ - Name: "sh1", - }, - }, - }, nil).Times(1) + TableRef: tableref, + Where: &lyx.AExprEmpty{}, + IsFrom: true, + }, gomock.Any()).Return(routingstate.MultiMatchState{}, nil).Times(1) + + qr.EXPECT().Route(gomock.Any(), &lyx.Insert{ + TableRef: tableref, + SubSelect: &lyx.ValueClause{Values: []lyx.Node{&lyx.AExprSConst{Value: "1"}}}, + }, cl).Times(4).Return(routingstate.ShardMatchState{Route: &routingstate.DataShardRoute{Shkey: sh1.SHKey()}}, nil) + + qr.EXPECT().DataShardsRoutes().AnyTimes().Return([]*routingstate.DataShardRoute{ + &routingstate.DataShardRoute{Shkey: sh1.SHKey()}, + &routingstate.DataShardRoute{Shkey: sh2.SHKey()}}, + ) route := route.NewRoute(beRule, frrule, map[string]*config.Shard{ "sh1": {}, + "sh2": {}, }) cl.EXPECT().Route().AnyTimes().Return(route) @@ -401,12 +417,12 @@ func TestFrontendSimpleCopyIn(t *testing.T) { cl.EXPECT().Receive().Times(1).Return(query, nil) - cl.EXPECT().Receive().Times(4).Return(&pgproto3.CopyData{}, nil) + cl.EXPECT().Receive().Times(4).Return(&pgproto3.CopyData{Data: []byte("1\n")}, nil) cl.EXPECT().Receive().Times(1).Return(&pgproto3.CopyDone{}, nil) srv.EXPECT().Send(query).Times(1).Return(nil) - srv.EXPECT().Send(&pgproto3.CopyData{}).Times(4).Return(nil) + sh1.EXPECT().Send(&pgproto3.CopyData{Data: []byte("1\n")}).Times(4).Return(nil) srv.EXPECT().Send(&pgproto3.CopyDone{}).Times(1).Return(nil) srv.EXPECT().Receive().Times(1).Return(&pgproto3.CopyInResponse{}, nil) diff --git a/router/qrouter/proxy_routing.go b/router/qrouter/proxy_routing.go index b1d630fc7..9a2b0cd9c 100644 --- a/router/qrouter/proxy_routing.go +++ b/router/qrouter/proxy_routing.go @@ -667,21 +667,6 @@ func (qr *ProxyQrouter) deparseShardingMapping( _ = qr.deparseFromNode(stmt.TableRef, meta) - return qr.routeByClause(ctx, clause, meta) - case *lyx.Copy: - if !stmt.IsFrom { - return fmt.Errorf("copy from stdin is not implemented") - } - - _ = qr.deparseFromNode(stmt.TableRef, meta) - - clause := stmt.Where - - if clause == nil { - // will not work - return nil - } - return qr.routeByClause(ctx, clause, meta) } @@ -1001,13 +986,15 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s return routingstate.RandomMatchState{}, nil } - case *lyx.Delete, *lyx.Update, *lyx.Copy: + case *lyx.Delete, *lyx.Update: // UPDATE and/or DELETE, COPY stmts, which // would be routed with their WHERE clause err := qr.deparseShardingMapping(ctx, stmt, meta) if err != nil { return nil, err } + case *lyx.Copy: + return routingstate.MultiMatchState{}, nil default: spqrlog.Zero.Debug().Interface("statement", stmt).Msg("proxy-routing message to all shards") } diff --git a/router/qrouter/proxy_routing_test.go b/router/qrouter/proxy_routing_test.go index 53c3a1f2f..7865bf672 100644 --- a/router/qrouter/proxy_routing_test.go +++ b/router/qrouter/proxy_routing_test.go @@ -1426,25 +1426,8 @@ func TestCopySingleShard(t *testing.T) { for _, tt := range []tcase{ { query: "COPY xx FROM STDIN WHERE i = 1;", - exp: routingstate.ShardMatchState{ - Route: &routingstate.DataShardRoute{ - Shkey: kr.ShardKey{ - Name: "sh1", - }, - Matchedkr: &kr.KeyRange{ - ShardID: "sh1", - ID: "id1", - Distribution: distribution, - LowerBound: []interface{}{ - int64(1), - }, - - ColumnTypes: []string{qdb.ColumnTypeInteger}, - }, - }, - TargetSessionAttrs: "any", - }, - err: nil, + exp: routingstate.MultiMatchState{}, + err: nil, }, } { parserRes, err := lyx.Parse(tt.query) diff --git a/router/relay/relay.go b/router/relay/relay.go index bf7b9a2f3..396991a3c 100644 --- a/router/relay/relay.go +++ b/router/relay/relay.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand" + "strings" "time" "github.com/pg-sharding/lyx/lyx" @@ -67,7 +68,7 @@ type RelayStateMgr interface { RelayRunCommand(msg pgproto3.FrontendMessage, waitForResp bool, replyCl bool) error ProcQuery(query pgproto3.FrontendMessage, waitForResp bool, replyCl bool) (txstatus.TXStatus, []pgproto3.BackendMessage, bool, error) - ProcCopy(query pgproto3.FrontendMessage) error + ProcCopy(stmt *lyx.Copy, data *pgproto3.CopyData, expRoute *routingstate.DataShardRoute) error ProcCommand(query pgproto3.FrontendMessage, waitForResp bool, replyCl bool) error @@ -656,15 +657,72 @@ func (rst *RelayStateImpl) RelayRunCommand(msg pgproto3.FrontendMessage, waitFor } // TODO : unit tests -func (rst *RelayStateImpl) ProcCopy(query pgproto3.FrontendMessage) error { +func (rst *RelayStateImpl) ProcCopy(stmt *lyx.Copy, data *pgproto3.CopyData, expRoute *routingstate.DataShardRoute) error { spqrlog.Zero.Debug(). Uint("client", rst.Client().ID()). - Type("query-type", query). Msg("client process copy") - _ = rst.Client().ReplyDebugNotice(fmt.Sprintf("executing your query %v", query)) // TODO perfomance issue + _ = rst.Client().ReplyDebugNotice(fmt.Sprintf("executing your query %v", data)) // TODO perfomance issue rst.Client().RLock() defer rst.Client().RUnlock() - return rst.Client().Server().Send(query) + + // Read delimiter from COPY options + delimiter := byte('\t') + for _, opt := range stmt.Options { + if o := opt.(*lyx.Option); strings.ToLower(o.Name) == "delimiter" { + delimiter = o.Arg.(*lyx.AExprSConst).Value[0] + } + } + + // Parse data + // and decide where to route + prevDelimiter := 0 + prevLine := 0 + valueClause := &lyx.ValueClause{} + for i, b := range data.Data { + if i+2 < len(data.Data) && string(data.Data[i:i+2]) == "\\." { + prevLine = len(data.Data) + break + } + if b == '\n' || b == delimiter { + valueClause.Values = append(valueClause.Values, &lyx.AExprSConst{Value: string(data.Data[prevDelimiter:i])}) + prevDelimiter = i + 1 + } + if b != '\n' { + continue + } + + // check where this tuple should go + r, err := rst.QueryRouter().Route(context.TODO(), &lyx.Insert{TableRef: stmt.TableRef, Columns: stmt.Columns, SubSelect: valueClause}, rst.Cl) + if err != nil { + return err + } + + smt, ok := r.(routingstate.ShardMatchState) + if !ok { + return fmt.Errorf("multishard copy is not supported") + } + + if expRoute.Shkey.Name == "" { + *expRoute = *smt.Route + } + if smt.Route.Shkey.Name != expRoute.Shkey.Name { + return fmt.Errorf("multishard copy is not supported") + } + + valueClause = &lyx.ValueClause{} + prevLine = i + 1 + } + + for _, sh := range rst.Client().Server().Datashards() { + if expRoute != nil && sh.Name() == expRoute.Shkey.Name { + err := sh.Send(&pgproto3.CopyData{Data: data.Data[:prevLine]}) + data.Data = data.Data[prevLine:] + return err + } + } + + // shouldn't exit from here + return nil } // TODO : unit tests @@ -680,16 +738,16 @@ func (rst *RelayStateImpl) ProcCopyComplete(query *pgproto3.FrontendMessage) err } for { - if msg, err := rst.Client().Server().Receive(); err != nil { + msg, err := rst.Client().Server().Receive() + if err != nil { return err - } else { - switch msg.(type) { - case *pgproto3.CommandComplete, *pgproto3.ErrorResponse: - return rst.Client().Send(msg) - default: - if err := rst.Client().Send(msg); err != nil { - return err - } + } + switch msg.(type) { + case *pgproto3.CommandComplete, *pgproto3.ErrorResponse: + return rst.Client().Send(msg) + default: + if err := rst.Client().Send(msg); err != nil { + return err } } } @@ -748,16 +806,21 @@ func (rst *RelayStateImpl) ProcQuery(query pgproto3.FrontendMessage, waitForResp return txstatus.TXERR, nil, false, err } + q := rst.qp.Stmt().(*lyx.Copy) + if err := func() error { + msg := &pgproto3.CopyData{Data: make([]byte, 0)} + route := &routingstate.DataShardRoute{} for { cpMsg, err := rst.Client().Receive() if err != nil { return err } - switch cpMsg.(type) { + switch newMsg := cpMsg.(type) { case *pgproto3.CopyData: - if err := rst.ProcCopy(cpMsg); err != nil { + msg.Data = append(msg.Data, newMsg.Data...) + if err = rst.ProcCopy(q, msg, route); err != nil { return err } case *pgproto3.CopyDone, *pgproto3.CopyFail: diff --git a/router/server/multishard.go b/router/server/multishard.go index 37e952df9..47dc69ab8 100644 --- a/router/server/multishard.go +++ b/router/server/multishard.go @@ -31,7 +31,8 @@ const ( RunningState ServerErrorState CommandCompleteState - CopyState + CopyOutState + CopyInState ) type MultiShardServer struct { @@ -178,7 +179,8 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { var saveRd *pgproto3.RowDescription = nil var saveCC *pgproto3.CommandComplete = nil var saveRFQ *pgproto3.ReadyForQuery = nil - /* Step one: ensure all shard backend are stared */ + var saveCIn *pgproto3.CopyInResponse = nil + /* Step one: ensure all shard backend are started */ for i := range m.activeShards { for { // all shards should be in rfq state @@ -200,12 +202,19 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { switch retMsg := msg.(type) { case *pgproto3.CopyOutResponse: - if m.multistate != InitialState && m.multistate != CopyState { + if m.multistate != InitialState && m.multistate != CopyOutState { return nil, MultiShardSyncBroken } m.states[i] = ShardCopyState - m.multistate = CopyState + m.multistate = CopyOutState m.copyBuf = append(m.copyBuf, retMsg) + case *pgproto3.CopyInResponse: + if m.multistate != InitialState && m.multistate != CopyInState { + return nil, MultiShardSyncBroken + } + m.states[i] = ShardCopyState + m.multistate = CopyInState + saveCIn = retMsg case *pgproto3.CommandComplete: m.states[i] = ShardCCState saveCC = retMsg // @@ -257,7 +266,7 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { m.multistate = InitialState return saveRFQ, nil } - if m.multistate == CopyState { + if m.multistate == CopyOutState { n := len(m.copyBuf) var currMsg *pgproto3.CopyOutResponse m.copyBuf, currMsg = m.copyBuf[n-2:], m.copyBuf[n-1] @@ -267,10 +276,14 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { Msg("miltishard server: flush copy buff") return currMsg, nil } + if m.multistate == CopyInState { + m.multistate = RunningState + return saveCIn, nil + } m.multistate = RunningState return saveRd, nil - case CopyState: + case CopyOutState: if len(m.copyBuf) > 0 { spqrlog.Zero.Debug().Msg("miltishard server: flush copy buff") n := len(m.copyBuf) @@ -302,7 +315,7 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { spqrlog.Zero.Info(). Uint("shard", m.activeShards[i].ID()). Type("message-type", msg). - Msg("multishard server: recived message from shard") + Msg("multishard server: received message from shard") switch msg.(type) { case *pgproto3.CommandComplete: @@ -322,6 +335,10 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) { return &pgproto3.CommandComplete{ CommandTag: []byte{}, // XXX : fix this }, nil + case CopyInState: + return &pgproto3.CommandComplete{ + CommandTag: []byte{}, + }, nil case RunningState: /* Step two: fetch all datarow ms gs */ for i := range m.activeShards { diff --git a/test/regress/tests/router/expected/copy_routing.out b/test/regress/tests/router/expected/copy_routing.out index 89e5df1b7..114a33580 100644 --- a/test/regress/tests/router/expected/copy_routing.out +++ b/test/regress/tests/router/expected/copy_routing.out @@ -34,8 +34,8 @@ ALTER DISTRIBUTION ds1 ATTACH RELATION copy_test DISTRIBUTION KEY id; \c regress CREATE TABLE copy_test (id int); NOTICE: send query to shard(s) : sh1,sh2 -COPY copy_test FROM STDIN WHERE id <= 10; -NOTICE: send query to shard(s) : sh1 +COPY copy_test(id) FROM STDIN WHERE id <= 10; +NOTICE: send query to shard(s) : sh1,sh2 SELECT * FROM copy_test WHERE id <= 10; NOTICE: send query to shard(s) : sh1 id @@ -47,10 +47,24 @@ NOTICE: send query to shard(s) : sh1 5 (5 rows) -COPY copy_test FROM STDIN WHERE id <= 30; -NOTICE: send query to shard(s) : sh2 -SELECT * FROM copy_test WHERE id <= 30 ORDER BY copy_test; -NOTICE: send query to shard(s) : sh2 +COPY copy_test(id) FROM STDIN; +NOTICE: send query to shard(s) : sh1,sh2 +ERROR: client processing error: multishard copy is not supported, tx status IDLE +SELECT * FROM copy_test; +NOTICE: send query to shard(s) : sh1,sh2 + id +---- + 1 + 2 + 3 + 4 + 5 +(5 rows) + +COPY copy_test(id) FROM STDIN; +NOTICE: send query to shard(s) : sh1,sh2 +SELECT * FROM copy_test; +NOTICE: send query to shard(s) : sh1,sh2 id ---- 1 @@ -58,10 +72,12 @@ NOTICE: send query to shard(s) : sh2 3 4 5 - 12 - 22 - 23 -(8 rows) + 41 + 42 + 43 + 44 + 45 +(10 rows) DROP TABLE copy_test; NOTICE: send query to shard(s) : sh1,sh2 diff --git a/test/regress/tests/router/sql/copy_routing.sql b/test/regress/tests/router/sql/copy_routing.sql index 8e4010484..8b8ef16dc 100644 --- a/test/regress/tests/router/sql/copy_routing.sql +++ b/test/regress/tests/router/sql/copy_routing.sql @@ -7,20 +7,17 @@ ALTER DISTRIBUTION ds1 ATTACH RELATION copy_test DISTRIBUTION KEY id; \c regress CREATE TABLE copy_test (id int); -COPY copy_test FROM STDIN WHERE id <= 10; +COPY copy_test(id) FROM STDIN WHERE id <= 10; 1 2 3 4 5 -12 -3434 -43 \. SELECT * FROM copy_test WHERE id <= 10; -COPY copy_test FROM STDIN WHERE id <= 30; +COPY copy_test(id) FROM STDIN; 1 2 3 @@ -35,7 +32,17 @@ COPY copy_test FROM STDIN WHERE id <= 30; 43 \. -SELECT * FROM copy_test WHERE id <= 30 ORDER BY copy_test; +SELECT * FROM copy_test; + +COPY copy_test(id) FROM STDIN; +41 +42 +43 +44 +45 +\. + +SELECT * FROM copy_test; DROP TABLE copy_test;