diff --git a/router/mock/qrouter/mock_qrouter.go b/router/mock/qrouter/mock_qrouter.go index 19ef81102..182b78b55 100644 --- a/router/mock/qrouter/mock_qrouter.go +++ b/router/mock/qrouter/mock_qrouter.go @@ -6,6 +6,7 @@ package mock import ( context "context" + "github.com/pg-sharding/spqr/pkg/models/kr" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -22,6 +23,8 @@ type MockQueryRouter struct { recorder *MockQueryRouterMockRecorder } +var _ qrouter.QueryRouter = &MockQueryRouter{} + // MockQueryRouterMockRecorder is the mock recorder for MockQueryRouter. type MockQueryRouterMockRecorder struct { mock *MockQueryRouter @@ -54,18 +57,18 @@ func (mr *MockQueryRouterMockRecorder) DataShardsRoutes() *gomock.Call { } // DeparseKeyWithRangesInternal mocks base method. -func (m *MockQueryRouter) DeparseKeyWithRangesInternal(ctx context.Context, key string, meta *qrouter.RoutingMetadataContext) (*routingstate.DataShardRoute, error) { +func (m *MockQueryRouter) DeparseKeyWithRangesInternal(ctx context.Context, key string, krs []*kr.KeyRange) (*routingstate.DataShardRoute, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeparseKeyWithRangesInternal", ctx, key, meta) + ret := m.ctrl.Call(m, "DeparseKeyWithRangesInternal", ctx, key, krs) ret0, _ := ret[0].(*routingstate.DataShardRoute) ret1, _ := ret[1].(error) return ret0, ret1 } // DeparseKeyWithRangesInternal indicates an expected call of DeparseKeyWithRangesInternal. -func (mr *MockQueryRouterMockRecorder) DeparseKeyWithRangesInternal(ctx, key, meta interface{}) *gomock.Call { +func (mr *MockQueryRouterMockRecorder) DeparseKeyWithRangesInternal(ctx, key, krs interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeparseKeyWithRangesInternal", reflect.TypeOf((*MockQueryRouter)(nil).DeparseKeyWithRangesInternal), ctx, key, meta) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeparseKeyWithRangesInternal", reflect.TypeOf((*MockQueryRouter)(nil).DeparseKeyWithRangesInternal), ctx, key, krs) } // Initialize mocks base method. diff --git a/router/qrouter/explain.go b/router/qrouter/explain.go index a0b8cce54..31570a2a6 100644 --- a/router/qrouter/explain.go +++ b/router/qrouter/explain.go @@ -11,7 +11,7 @@ import ( // TODO : unit tests func (qr *ProxyQrouter) Explain(ctx context.Context, stmt *lyx.Explain, cli *clientinteractor.PSQLInteractor) error { - meta := NewRoutingMetadataContext(nil, nil, cli.GetDistribution(), nil, nil) + meta := NewRoutingMetadataContext(cli.GetDistribution(), nil, nil) switch node := stmt.Stmt.(type) { case *lyx.VariableSetStmt: diff --git a/router/qrouter/proxy_routing.go b/router/qrouter/proxy_routing.go index 5e24253b8..72dedb4af 100644 --- a/router/qrouter/proxy_routing.go +++ b/router/qrouter/proxy_routing.go @@ -8,7 +8,6 @@ import ( "github.com/pg-sharding/spqr/pkg/config" "github.com/pg-sharding/spqr/pkg/models/hashfunction" "github.com/pg-sharding/spqr/pkg/models/kr" - "github.com/pg-sharding/spqr/pkg/models/shrule" "github.com/pg-sharding/spqr/pkg/session" "github.com/pg-sharding/spqr/pkg/spqrlog" "github.com/pg-sharding/spqr/qdb" @@ -64,8 +63,6 @@ type RoutingMetadataContext struct { // INSERT INTO x (...) SELECT 7 TargetList []lyx.Node - rls []*shrule.ShardingRule - krs []*kr.KeyRange distribution string params [][]byte @@ -73,30 +70,13 @@ type RoutingMetadataContext struct { // TODO: include client ops and metadata here } -func (m *RoutingMetadataContext) CheckColumnRls(colname string) bool { - for i := range m.rls { - for _, c := range m.rls[i].Entries() { - if c.Column == colname { - return true - } - } - } - return false -} - -func NewRoutingMetadataContext( - krs []*kr.KeyRange, - rls []*shrule.ShardingRule, - ds string, - params [][]byte, paramsFormatCodes []int16) *RoutingMetadataContext { +func NewRoutingMetadataContext(ds string, params [][]byte, paramsFormatCodes []int16) *RoutingMetadataContext { meta := &RoutingMetadataContext{ rels: map[RelationFQN][]string{}, tableAliases: map[string]RelationFQN{}, exprs: map[RelationFQN]map[string]string{}, unparsed_columns: map[string]struct{}{}, - krs: krs, - rls: rls, distribution: ds, params: params, } @@ -171,19 +151,19 @@ func (qr *ProxyQrouter) DeparseExprShardingEntries(expr lyx.Node, meta *RoutingM } // TODO : unit tests -func (qr *ProxyQrouter) DeparseKeyWithRangesInternal(ctx context.Context, key string, meta *RoutingMetadataContext) (*routingstate.DataShardRoute, error) { +func (qr *ProxyQrouter) DeparseKeyWithRangesInternal(_ context.Context, key string, krs []*kr.KeyRange) (*routingstate.DataShardRoute, error) { spqrlog.Zero.Debug(). Str("key", key). Msg("checking key") spqrlog.Zero.Debug(). Str("key", key). - Int("key-ranges-count", len(meta.krs)). + Int("key-ranges-count", len(krs)). Msg("checking key with key ranges") var matched_krkey *kr.KeyRange = nil - for _, krkey := range meta.krs { + for _, krkey := range krs { if kr.CmpRangesLessEqual(krkey.LowerBound, []byte(key)) && (matched_krkey == nil || kr.CmpRangesLessEqual(matched_krkey.LowerBound, krkey.LowerBound)) { matched_krkey = krkey @@ -205,7 +185,7 @@ func (qr *ProxyQrouter) DeparseKeyWithRangesInternal(ctx context.Context, key st } // TODO : unit tests -func (qr *ProxyQrouter) RouteKeyWithRanges(ctx context.Context, expr lyx.Node, meta *RoutingMetadataContext, hf hashfunction.HashFunctionType) (*routingstate.DataShardRoute, error) { +func (qr *ProxyQrouter) RouteKeyWithRanges(ctx context.Context, expr lyx.Node, meta *RoutingMetadataContext, krs []*kr.KeyRange, hf hashfunction.HashFunctionType) (*routingstate.DataShardRoute, error) { switch e := expr.(type) { case *lyx.ParamRef: if e.Number > len(meta.params) { @@ -233,7 +213,7 @@ func (qr *ProxyQrouter) RouteKeyWithRanges(ctx context.Context, expr lyx.Node, m } spqrlog.Zero.Debug().Str("key", string(routeParam)).Str("hashed key", string(hashedKey)).Msg("applying hash function on key") - return qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), meta) + return qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs) case *lyx.AExprSConst: hashedKey, err := hashfunction.ApplyHashFunction([]byte(e.Value), hf) if err != nil { @@ -241,7 +221,7 @@ func (qr *ProxyQrouter) RouteKeyWithRanges(ctx context.Context, expr lyx.Node, m } spqrlog.Zero.Debug().Str("key", e.Value).Str("hashed key", string(hashedKey)).Msg("applying hash function on key") - return qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), meta) + return qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs) case *lyx.AExprIConst: val := fmt.Sprintf("%d", e.Value) hashedKey, err := hashfunction.ApplyHashFunction([]byte(val), hf) @@ -250,7 +230,7 @@ func (qr *ProxyQrouter) RouteKeyWithRanges(ctx context.Context, expr lyx.Node, m } spqrlog.Zero.Debug().Int("key", e.Value).Str("hashed key", string(hashedKey)).Msg("applying hash function on key") - return qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), meta) + return qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs) default: return nil, ComplexQuery } @@ -746,17 +726,8 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s } /* TODO: delay this until step 2. */ - krs, err := qr.mgr.ListKeyRanges(ctx, queryDistribution) - if err != nil { - return nil, err - } - rls, err := qr.mgr.ListShardingRules(ctx, queryDistribution) - if err != nil { - return nil, err - } - - meta := NewRoutingMetadataContext(krs, rls, queryDistribution, sph.BindParams(), sph.BindParamFormatCodes()) + meta := NewRoutingMetadataContext(queryDistribution, sph.BindParams(), sph.BindParamFormatCodes()) tsa := config.TargetSessionAttrsAny @@ -931,9 +902,16 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s var route_err error for rfqn, cols := range meta.rels { - /* - * - */ + // TODO: check by whole RFQN + ds, err := qr.mgr.GetRelationDistribution(ctx, rfqn.RelationName) + if err != nil { + return nil, err + } + + krs, err := qr.mgr.ListKeyRanges(ctx, ds.Id) + if err != nil { + return nil, err + } if rule, err := MatchShardingRule(ctx, rfqn.RelationName, cols, qr.mgr.QDB()); err != nil { for _, col := range cols { @@ -953,7 +931,7 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s continue } - currroute, err := qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), meta) + currroute, err := qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs) if err != nil { route_err = err spqrlog.Zero.Debug().Err(route_err).Msg("temporarily skip the route error") @@ -996,6 +974,15 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s } } + ds, err := qr.mgr.GetRelationDistribution(ctx, meta.InsertStmtRel) + if err != nil { + return nil, err + } + krs, err := qr.mgr.ListKeyRanges(ctx, ds.Id) + if err != nil { + return nil, err + } + hf, err := hashfunction.HashFunctionByName(rule.Entries[0].HashFunction) if err != nil { /* failed to resolve hash function */ @@ -1005,7 +992,7 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s meta.offsets = offsets routed := false if len(meta.offsets) != 0 && len(meta.TargetList) > meta.offsets[0] { - currroute, err := qr.RouteKeyWithRanges(ctx, meta.TargetList[meta.offsets[0]], meta, hf) + currroute, err := qr.RouteKeyWithRanges(ctx, meta.TargetList[meta.offsets[0]], meta, krs, hf) if err == nil { /* else failed, ignore */ spqrlog.Zero.Debug(). @@ -1023,7 +1010,7 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s if len(meta.offsets) != 0 && len(meta.ValuesLists) > meta.offsets[0] && !routed && meta.ValuesLists != nil { // only first value from value list - currroute, err := qr.RouteKeyWithRanges(ctx, meta.ValuesLists[meta.offsets[0]], meta, hf) + currroute, err := qr.RouteKeyWithRanges(ctx, meta.ValuesLists[meta.offsets[0]], meta, krs, hf) if err == nil { /* else failed, ignore */ spqrlog.Zero.Debug(). Interface("current-route", currroute). diff --git a/router/qrouter/proxy_routing_test.go b/router/qrouter/proxy_routing_test.go index bbb092ea2..9bc33cba3 100644 --- a/router/qrouter/proxy_routing_test.go +++ b/router/qrouter/proxy_routing_test.go @@ -7,7 +7,6 @@ import ( "github.com/pg-sharding/spqr/pkg/config" "github.com/pg-sharding/spqr/pkg/coord/local" "github.com/pg-sharding/spqr/pkg/models/kr" - "github.com/pg-sharding/spqr/pkg/models/shrule" "github.com/pg-sharding/spqr/pkg/session" "github.com/pg-sharding/spqr/qdb" "github.com/pg-sharding/spqr/router/qrouter" @@ -20,42 +19,6 @@ import ( const MemQDBPath = "memqdb.json" -func TestCheckColumnRls(t *testing.T) { - assert := assert.New(t) - - rmc := qrouter.NewRoutingMetadataContext( - nil, - []*shrule.ShardingRule{ - shrule.NewShardingRule( - "", - "", - []shrule.ShardingRuleEntry{ - *shrule.NewShardingRuleEntry("col1", ""), - *shrule.NewShardingRuleEntry("col2", ""), - }, - "", - ), - shrule.NewShardingRule( - "", - "", - []shrule.ShardingRuleEntry{ - *shrule.NewShardingRuleEntry("col3", ""), - }, - "", - ), - }, - "", - nil, - nil, - ) - - assert.True(rmc.CheckColumnRls("col1"), "col1 should be in rls") - assert.True(rmc.CheckColumnRls("col2"), "col2 should be in rls") - assert.True(rmc.CheckColumnRls("col3"), "col3 should be in rls") - - assert.False(rmc.CheckColumnRls("col4"), "col4 should not be in rls") -} - func TestMultiShardRouting(t *testing.T) { assert := assert.New(t) @@ -647,6 +610,14 @@ func TestJoins(t *testing.T) { Name: "sshjt1", ColumnNames: []string{"i"}, }, + "xjoin": { + Name: "xjoin", + ColumnNames: []string{"i"}, + }, + "yjoin": { + Name: "yjoin", + ColumnNames: []string{"i"}, + }, }, }) @@ -713,14 +684,14 @@ func TestJoins(t *testing.T) { }, { - query: "SELECT * FROM xjoin JOIN yjoin on id=w_id where w_idx = 15 ORDER BY id;'", + query: "SELECT * FROM xjoin JOIN yjoin on id=w_id where w_idx = 15 ORDER BY id;", exp: routingstate.MultiMatchState{}, err: nil, }, // sharding columns, but unparsed { - query: "SELECT * FROM xjoin JOIN yjoin on id=w_id where i = 15 ORDER BY id;'", + query: "SELECT * FROM xjoin JOIN yjoin on id=w_id where i = 15 ORDER BY id;", exp: routingstate.ShardMatchState{ Route: &routingstate.DataShardRoute{ Shkey: kr.ShardKey{ @@ -969,126 +940,6 @@ func TestCopySingleShard(t *testing.T) { } } -func TestInsertMultiDistribution(t *testing.T) { - assert := assert.New(t) - - type tcase struct { - query string - distribution string - exp routingstate.RoutingState - err error - } - db, _ := qdb.NewMemQDB(MemQDBPath) - distribution1 := "ds1" - distribution2 := "ds2" - - assert.NoError(db.CreateDistribution(context.TODO(), qdb.NewDistribution(distribution1, nil))) - assert.NoError(db.CreateDistribution(context.TODO(), qdb.NewDistribution(distribution2, nil))) - - assert.NoError(db.AddShardingRule(context.TODO(), &qdb.ShardingRule{ - ID: "id1", - DistributionId: distribution1, - TableName: "", - Entries: []qdb.ShardingRuleEntry{ - { - Column: "i", - }, - }, - })) - - assert.NoError(db.AddShardingRule(context.TODO(), &qdb.ShardingRule{ - ID: "id2", - DistributionId: distribution2, - TableName: "", - Entries: []qdb.ShardingRuleEntry{ - { - Column: "i", - }, - }, - })) - - assert.NoError(db.AddKeyRange(context.TODO(), &qdb.KeyRange{ - ShardID: "sh1", - DistributionId: distribution1, - KeyRangeID: "id1", - LowerBound: []byte("1"), - })) - - assert.NoError(db.AddKeyRange(context.TODO(), &qdb.KeyRange{ - ShardID: "sh2", - DistributionId: distribution2, - KeyRangeID: "id2", - LowerBound: []byte("1"), - })) - - lc := local.NewLocalCoordinator(db) - - pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, - }, lc, &config.QRouter{ - DefaultRouteBehaviour: "BLOCK", - }) - - assert.NoError(err) - - for _, tt := range []tcase{ - { - - query: "INSERT INTO xxxdst1(i) VALUES(5);", - distribution: distribution1, - exp: routingstate.ShardMatchState{ - Route: &routingstate.DataShardRoute{ - Shkey: kr.ShardKey{ - Name: "sh1", - }, - Matchedkr: &kr.KeyRange{ - ShardID: "sh1", - ID: "id1", - Distribution: distribution1, - LowerBound: []byte("1"), - }, - }, - TargetSessionAttrs: "any", - }, - err: nil, - }, - { - query: "INSERT INTO xxxdst1(i) VALUES(5);", - distribution: distribution2, - exp: routingstate.ShardMatchState{ - Route: &routingstate.DataShardRoute{ - Shkey: kr.ShardKey{ - Name: "sh2", - }, - Matchedkr: &kr.KeyRange{ - ShardID: "sh2", - ID: "id2", - Distribution: distribution2, - LowerBound: []byte("1"), - }, - }, - TargetSessionAttrs: "any", - }, - err: nil, - }, - } { - parserRes, err := lyx.Parse(tt.query) - - assert.NoError(err, "query %s", tt.query) - - tmp, err := pr.Route(context.TODO(), parserRes, session.NewDummyHandler(tt.distribution)) - - assert.NoError(err, "query %s", tt.query) - - assert.Equal(tt.exp, tmp, tt.query) - } -} - func TestSetStmt(t *testing.T) { assert := assert.New(t) diff --git a/router/qrouter/qrouter.go b/router/qrouter/qrouter.go index 73c06e681..29282162d 100644 --- a/router/qrouter/qrouter.go +++ b/router/qrouter/qrouter.go @@ -3,6 +3,7 @@ package qrouter import ( "context" "fmt" + "github.com/pg-sharding/spqr/pkg/models/kr" "github.com/pg-sharding/spqr/pkg/config" "github.com/pg-sharding/spqr/pkg/meta" @@ -21,7 +22,7 @@ type QueryRouter interface { WorldShardsRoutes() []*routingstate.DataShardRoute DataShardsRoutes() []*routingstate.DataShardRoute - DeparseKeyWithRangesInternal(ctx context.Context, key string, meta *RoutingMetadataContext) (*routingstate.DataShardRoute, error) + DeparseKeyWithRangesInternal(ctx context.Context, key string, krs []*kr.KeyRange) (*routingstate.DataShardRoute, error) Initialized() bool Initialize() bool diff --git a/router/relay/qstate.go b/router/relay/qstate.go index 2bc6fbf47..d6ec350ac 100644 --- a/router/relay/qstate.go +++ b/router/relay/qstate.go @@ -14,7 +14,6 @@ import ( "github.com/pg-sharding/spqr/pkg/spqrlog" "github.com/pg-sharding/spqr/pkg/txstatus" "github.com/pg-sharding/spqr/router/parser" - "github.com/pg-sharding/spqr/router/qrouter" "github.com/pg-sharding/spqr/router/routehint" "github.com/pg-sharding/spqr/router/routingstate" "github.com/pg-sharding/spqr/router/statistics" @@ -31,19 +30,18 @@ func deparseRouteHint(rst RelayStateMgr, params map[string]string, distribution if val, ok := params[session.SPQR_SHARDING_KEY]; ok { spqrlog.Zero.Debug().Str("sharding key", val).Msg("checking hint key") - krs, err := rst.QueryRouter().Mgr().ListKeyRanges(context.TODO(), distribution) - - if err != nil { - return nil, err + dsId := "" + if dsId, ok = params[session.SPQR_DISTRIBUTION]; !ok { + return nil, spqrerror.New(spqrerror.SPQR_NO_DISTRIBUTION, "got sharding key in comment without distribution") } - rls, err := rst.QueryRouter().Mgr().ListShardingRules(context.TODO(), distribution) + ctx := context.TODO() + krs, err := rst.QueryRouter().Mgr().ListKeyRanges(ctx, dsId) if err != nil { return nil, err } - meta := qrouter.NewRoutingMetadataContext(krs, rls, distribution, nil, nil) - ds, err := rst.QueryRouter().DeparseKeyWithRangesInternal(context.TODO(), val, meta) + ds, err := rst.QueryRouter().DeparseKeyWithRangesInternal(context.TODO(), val, krs) if err != nil { return nil, err }