Skip to content

Commit

Permalink
Support COPY (#748)
Browse files Browse the repository at this point in the history
Parse COPY command and extract table and column names and options. Send the COPY query to all shards. Receive all CopyData and parse them. Check if they correspond to one shard. Send CopyData to that one shard. Then send CopyDone to all shards and receive responses.
  • Loading branch information
diPhantxm authored Sep 14, 2024
1 parent 6f7dc6d commit 924330d
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 92 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ require (
github.com/lib/pq v1.10.9
github.com/libp2p/go-reuseport v0.4.0
github.com/opentracing/opentracing-go v1.2.0
github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1
github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.33.0
github.com/sevlyar/go-daemon v0.1.6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ github.com/opencontainers/image-spec v1.0.2 h1:9yCKha/T5XdGtO0q9Q9a6T5NUCsTn/DrB
github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0=
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1 h1:AwlQkwnrqRyL8lqZTTAzfQ09niEc+6oFiDvQkMImTPE=
github.com/pg-sharding/lyx v0.0.0-20240819153240-bbdc782d01c1/go.mod h1:2dPBQAhqv/30mhzj2yBXQkXhsGJQ8GhM+oWOfbGua58=
github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c h1:4sXBG7ZDtG/rN2jqgmzsMawfcTKQvTCTTo8iQ7eR6VU=
github.com/pg-sharding/lyx v0.0.0-20240823123817-e655173c284c/go.mod h1:2dPBQAhqv/30mhzj2yBXQkXhsGJQ8GhM+oWOfbGua58=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
Expand Down
46 changes: 31 additions & 15 deletions router/frontend/frontend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,15 @@ func TestFrontendSimpleCopyIn(t *testing.T) {
qr := mockqr.NewMockQueryRouter(ctrl)
cmngr := mockcmgr.NewMockPoolMgr(ctrl)

sh1 := mocksh.NewMockShard(ctrl)
sh1.EXPECT().Name().AnyTimes().Return("sh1")
sh1.EXPECT().SHKey().AnyTimes().Return(kr.ShardKey{Name: "sh1"})
sh1.EXPECT().ID().AnyTimes().Return(uint(1))
sh2 := mocksh.NewMockShard(ctrl)
sh2.EXPECT().Name().AnyTimes().Return("sh2")
sh2.EXPECT().SHKey().AnyTimes().Return(kr.ShardKey{Name: "sh2"})
sh2.EXPECT().ID().AnyTimes().Return(uint(2))

frrule := &config.FrontendRule{
DB: "db1",
Usr: "user1",
Expand All @@ -337,7 +346,7 @@ func TestFrontendSimpleCopyIn(t *testing.T) {
beRule := &config.BackendRule{}

srv.EXPECT().Name().AnyTimes().Return("serv1")
srv.EXPECT().Datashards().AnyTimes().Return([]shard.Shard{})
srv.EXPECT().Datashards().AnyTimes().Return([]shard.Shard{sh1, sh2})

cl.EXPECT().Server().AnyTimes().Return(srv)
cl.EXPECT().MaintainParams().AnyTimes().Return(false)
Expand Down Expand Up @@ -375,22 +384,29 @@ func TestFrontendSimpleCopyIn(t *testing.T) {

cmngr.EXPECT().TXEndCB(gomock.Any()).AnyTimes()

tableref := &lyx.RangeVar{
RelationName: "xx",
}

qr.EXPECT().Route(gomock.Any(), &lyx.Copy{
TableRef: &lyx.RangeVar{
RelationName: "xx",
},
Where: &lyx.AExprEmpty{},
IsFrom: true,
}, gomock.Any()).Return(routingstate.ShardMatchState{
Route: &routingstate.DataShardRoute{
Shkey: kr.ShardKey{
Name: "sh1",
},
},
}, nil).Times(1)
TableRef: tableref,
Where: &lyx.AExprEmpty{},
IsFrom: true,
}, gomock.Any()).Return(routingstate.MultiMatchState{}, nil).Times(1)

qr.EXPECT().Route(gomock.Any(), &lyx.Insert{
TableRef: tableref,
SubSelect: &lyx.ValueClause{Values: []lyx.Node{&lyx.AExprSConst{Value: "1"}}},
}, cl).Times(4).Return(routingstate.ShardMatchState{Route: &routingstate.DataShardRoute{Shkey: sh1.SHKey()}}, nil)

qr.EXPECT().DataShardsRoutes().AnyTimes().Return([]*routingstate.DataShardRoute{
&routingstate.DataShardRoute{Shkey: sh1.SHKey()},
&routingstate.DataShardRoute{Shkey: sh2.SHKey()}},
)

route := route.NewRoute(beRule, frrule, map[string]*config.Shard{
"sh1": {},
"sh2": {},
})

cl.EXPECT().Route().AnyTimes().Return(route)
Expand All @@ -401,12 +417,12 @@ func TestFrontendSimpleCopyIn(t *testing.T) {

cl.EXPECT().Receive().Times(1).Return(query, nil)

cl.EXPECT().Receive().Times(4).Return(&pgproto3.CopyData{}, nil)
cl.EXPECT().Receive().Times(4).Return(&pgproto3.CopyData{Data: []byte("1\n")}, nil)
cl.EXPECT().Receive().Times(1).Return(&pgproto3.CopyDone{}, nil)

srv.EXPECT().Send(query).Times(1).Return(nil)

srv.EXPECT().Send(&pgproto3.CopyData{}).Times(4).Return(nil)
sh1.EXPECT().Send(&pgproto3.CopyData{Data: []byte("1\n")}).Times(4).Return(nil)
srv.EXPECT().Send(&pgproto3.CopyDone{}).Times(1).Return(nil)

srv.EXPECT().Receive().Times(1).Return(&pgproto3.CopyInResponse{}, nil)
Expand Down
19 changes: 3 additions & 16 deletions router/qrouter/proxy_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,21 +667,6 @@ func (qr *ProxyQrouter) deparseShardingMapping(

_ = qr.deparseFromNode(stmt.TableRef, meta)

return qr.routeByClause(ctx, clause, meta)
case *lyx.Copy:
if !stmt.IsFrom {
return fmt.Errorf("copy from stdin is not implemented")
}

_ = qr.deparseFromNode(stmt.TableRef, meta)

clause := stmt.Where

if clause == nil {
// will not work
return nil
}

return qr.routeByClause(ctx, clause, meta)
}

Expand Down Expand Up @@ -1001,13 +986,15 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s
return routingstate.RandomMatchState{}, nil
}

case *lyx.Delete, *lyx.Update, *lyx.Copy:
case *lyx.Delete, *lyx.Update:
// UPDATE and/or DELETE, COPY stmts, which
// would be routed with their WHERE clause
err := qr.deparseShardingMapping(ctx, stmt, meta)
if err != nil {
return nil, err
}
case *lyx.Copy:
return routingstate.MultiMatchState{}, nil
default:
spqrlog.Zero.Debug().Interface("statement", stmt).Msg("proxy-routing message to all shards")
}
Expand Down
21 changes: 2 additions & 19 deletions router/qrouter/proxy_routing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1426,25 +1426,8 @@ func TestCopySingleShard(t *testing.T) {
for _, tt := range []tcase{
{
query: "COPY xx FROM STDIN WHERE i = 1;",
exp: routingstate.ShardMatchState{
Route: &routingstate.DataShardRoute{
Shkey: kr.ShardKey{
Name: "sh1",
},
Matchedkr: &kr.KeyRange{
ShardID: "sh1",
ID: "id1",
Distribution: distribution,
LowerBound: []interface{}{
int64(1),
},

ColumnTypes: []string{qdb.ColumnTypeInteger},
},
},
TargetSessionAttrs: "any",
},
err: nil,
exp: routingstate.MultiMatchState{},
err: nil,
},
} {
parserRes, err := lyx.Parse(tt.query)
Expand Down
95 changes: 79 additions & 16 deletions router/relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"math/rand"
"strings"
"time"

"github.com/pg-sharding/lyx/lyx"
Expand Down Expand Up @@ -67,7 +68,7 @@ type RelayStateMgr interface {
RelayRunCommand(msg pgproto3.FrontendMessage, waitForResp bool, replyCl bool) error

ProcQuery(query pgproto3.FrontendMessage, waitForResp bool, replyCl bool) (txstatus.TXStatus, []pgproto3.BackendMessage, bool, error)
ProcCopy(query pgproto3.FrontendMessage) error
ProcCopy(stmt *lyx.Copy, data *pgproto3.CopyData, expRoute *routingstate.DataShardRoute) error

ProcCommand(query pgproto3.FrontendMessage, waitForResp bool, replyCl bool) error

Expand Down Expand Up @@ -656,15 +657,72 @@ func (rst *RelayStateImpl) RelayRunCommand(msg pgproto3.FrontendMessage, waitFor
}

// TODO : unit tests
func (rst *RelayStateImpl) ProcCopy(query pgproto3.FrontendMessage) error {
func (rst *RelayStateImpl) ProcCopy(stmt *lyx.Copy, data *pgproto3.CopyData, expRoute *routingstate.DataShardRoute) error {
spqrlog.Zero.Debug().
Uint("client", rst.Client().ID()).
Type("query-type", query).
Msg("client process copy")
_ = rst.Client().ReplyDebugNotice(fmt.Sprintf("executing your query %v", query)) // TODO perfomance issue
_ = rst.Client().ReplyDebugNotice(fmt.Sprintf("executing your query %v", data)) // TODO perfomance issue
rst.Client().RLock()
defer rst.Client().RUnlock()
return rst.Client().Server().Send(query)

// Read delimiter from COPY options
delimiter := byte('\t')
for _, opt := range stmt.Options {
if o := opt.(*lyx.Option); strings.ToLower(o.Name) == "delimiter" {
delimiter = o.Arg.(*lyx.AExprSConst).Value[0]
}
}

// Parse data
// and decide where to route
prevDelimiter := 0
prevLine := 0
valueClause := &lyx.ValueClause{}
for i, b := range data.Data {
if i+2 < len(data.Data) && string(data.Data[i:i+2]) == "\\." {
prevLine = len(data.Data)
break
}
if b == '\n' || b == delimiter {
valueClause.Values = append(valueClause.Values, &lyx.AExprSConst{Value: string(data.Data[prevDelimiter:i])})
prevDelimiter = i + 1
}
if b != '\n' {
continue
}

// check where this tuple should go
r, err := rst.QueryRouter().Route(context.TODO(), &lyx.Insert{TableRef: stmt.TableRef, Columns: stmt.Columns, SubSelect: valueClause}, rst.Cl)
if err != nil {
return err
}

smt, ok := r.(routingstate.ShardMatchState)
if !ok {
return fmt.Errorf("multishard copy is not supported")
}

if expRoute.Shkey.Name == "" {
*expRoute = *smt.Route
}
if smt.Route.Shkey.Name != expRoute.Shkey.Name {
return fmt.Errorf("multishard copy is not supported")
}

valueClause = &lyx.ValueClause{}
prevLine = i + 1
}

for _, sh := range rst.Client().Server().Datashards() {
if expRoute != nil && sh.Name() == expRoute.Shkey.Name {
err := sh.Send(&pgproto3.CopyData{Data: data.Data[:prevLine]})
data.Data = data.Data[prevLine:]
return err
}
}

// shouldn't exit from here
return nil
}

// TODO : unit tests
Expand All @@ -680,16 +738,16 @@ func (rst *RelayStateImpl) ProcCopyComplete(query *pgproto3.FrontendMessage) err
}

for {
if msg, err := rst.Client().Server().Receive(); err != nil {
msg, err := rst.Client().Server().Receive()
if err != nil {
return err
} else {
switch msg.(type) {
case *pgproto3.CommandComplete, *pgproto3.ErrorResponse:
return rst.Client().Send(msg)
default:
if err := rst.Client().Send(msg); err != nil {
return err
}
}
switch msg.(type) {
case *pgproto3.CommandComplete, *pgproto3.ErrorResponse:
return rst.Client().Send(msg)
default:
if err := rst.Client().Send(msg); err != nil {
return err
}
}
}
Expand Down Expand Up @@ -748,16 +806,21 @@ func (rst *RelayStateImpl) ProcQuery(query pgproto3.FrontendMessage, waitForResp
return txstatus.TXERR, nil, false, err
}

q := rst.qp.Stmt().(*lyx.Copy)

if err := func() error {
msg := &pgproto3.CopyData{Data: make([]byte, 0)}
route := &routingstate.DataShardRoute{}
for {
cpMsg, err := rst.Client().Receive()
if err != nil {
return err
}

switch cpMsg.(type) {
switch newMsg := cpMsg.(type) {
case *pgproto3.CopyData:
if err := rst.ProcCopy(cpMsg); err != nil {
msg.Data = append(msg.Data, newMsg.Data...)
if err = rst.ProcCopy(q, msg, route); err != nil {
return err
}
case *pgproto3.CopyDone, *pgproto3.CopyFail:
Expand Down
31 changes: 24 additions & 7 deletions router/server/multishard.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ const (
RunningState
ServerErrorState
CommandCompleteState
CopyState
CopyOutState
CopyInState
)

type MultiShardServer struct {
Expand Down Expand Up @@ -178,7 +179,8 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) {
var saveRd *pgproto3.RowDescription = nil
var saveCC *pgproto3.CommandComplete = nil
var saveRFQ *pgproto3.ReadyForQuery = nil
/* Step one: ensure all shard backend are stared */
var saveCIn *pgproto3.CopyInResponse = nil
/* Step one: ensure all shard backend are started */
for i := range m.activeShards {
for {
// all shards should be in rfq state
Expand All @@ -200,12 +202,19 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) {

switch retMsg := msg.(type) {
case *pgproto3.CopyOutResponse:
if m.multistate != InitialState && m.multistate != CopyState {
if m.multistate != InitialState && m.multistate != CopyOutState {
return nil, MultiShardSyncBroken
}
m.states[i] = ShardCopyState
m.multistate = CopyState
m.multistate = CopyOutState
m.copyBuf = append(m.copyBuf, retMsg)
case *pgproto3.CopyInResponse:
if m.multistate != InitialState && m.multistate != CopyInState {
return nil, MultiShardSyncBroken
}
m.states[i] = ShardCopyState
m.multistate = CopyInState
saveCIn = retMsg
case *pgproto3.CommandComplete:
m.states[i] = ShardCCState
saveCC = retMsg //
Expand Down Expand Up @@ -257,7 +266,7 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) {
m.multistate = InitialState
return saveRFQ, nil
}
if m.multistate == CopyState {
if m.multistate == CopyOutState {
n := len(m.copyBuf)
var currMsg *pgproto3.CopyOutResponse
m.copyBuf, currMsg = m.copyBuf[n-2:], m.copyBuf[n-1]
Expand All @@ -267,10 +276,14 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) {
Msg("miltishard server: flush copy buff")
return currMsg, nil
}
if m.multistate == CopyInState {
m.multistate = RunningState
return saveCIn, nil
}

m.multistate = RunningState
return saveRd, nil
case CopyState:
case CopyOutState:
if len(m.copyBuf) > 0 {
spqrlog.Zero.Debug().Msg("miltishard server: flush copy buff")
n := len(m.copyBuf)
Expand Down Expand Up @@ -302,7 +315,7 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) {
spqrlog.Zero.Info().
Uint("shard", m.activeShards[i].ID()).
Type("message-type", msg).
Msg("multishard server: recived message from shard")
Msg("multishard server: received message from shard")

switch msg.(type) {
case *pgproto3.CommandComplete:
Expand All @@ -322,6 +335,10 @@ func (m *MultiShardServer) Receive() (pgproto3.BackendMessage, error) {
return &pgproto3.CommandComplete{
CommandTag: []byte{}, // XXX : fix this
}, nil
case CopyInState:
return &pgproto3.CommandComplete{
CommandTag: []byte{},
}, nil
case RunningState:
/* Step two: fetch all datarow ms gs */
for i := range m.activeShards {
Expand Down
Loading

0 comments on commit 924330d

Please sign in to comment.