diff --git a/router/qrouter/proxy_routing.go b/router/qrouter/proxy_routing.go index 04f4af212..b8802826a 100644 --- a/router/qrouter/proxy_routing.go +++ b/router/qrouter/proxy_routing.go @@ -43,7 +43,7 @@ type RoutingMetadataContext struct { // SELECT * FROM a join b WHERE a.c1 = and a.c2 = // can be routed with different rules rels map[RelationFQN]struct{} - exprs map[RelationFQN]map[string]string + exprs map[RelationFQN]map[string][]string // cached CTE names cteNames map[string]struct{} @@ -65,7 +65,7 @@ func NewRoutingMetadataContext(params [][]byte, paramsFormatCodes []int16) *Rout rels: map[RelationFQN]struct{}{}, cteNames: map[string]struct{}{}, tableAliases: map[string]RelationFQN{}, - exprs: map[RelationFQN]map[string]string{}, + exprs: map[RelationFQN]map[string][]string{}, unparsed_columns: map[string]struct{}{}, params: params, } @@ -97,17 +97,21 @@ func (meta *RoutingMetadataContext) RFQNIsCTE(resolvedRelation RelationFQN) bool } // TODO : unit tests -func (meta *RoutingMetadataContext) RecordConstExpr(resolvedRelation RelationFQN, colname string, expr string) { +func (meta *RoutingMetadataContext) RecordConstExpr(resolvedRelation RelationFQN, colname string, expr string) error { if meta.RFQNIsCTE(resolvedRelation) { // CTE, skip - return + return nil } meta.rels[resolvedRelation] = struct{}{} if _, ok := meta.exprs[resolvedRelation]; !ok { - meta.exprs[resolvedRelation] = map[string]string{} + meta.exprs[resolvedRelation] = map[string][]string{} } delete(meta.unparsed_columns, colname) - meta.exprs[resolvedRelation][colname] = expr + if _, ok := meta.exprs[resolvedRelation][colname]; !ok { + meta.exprs[resolvedRelation][colname] = make([]string, 0) + } + meta.exprs[resolvedRelation][colname] = append(meta.exprs[resolvedRelation][colname], expr) + return nil } // TODO : unit tests @@ -185,11 +189,11 @@ func (qr *ProxyQrouter) DeparseKeyWithRangesInternal(_ context.Context, key stri return nil, FailedToFindKeyRange } -func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMetadataContext, resolvedRelation RelationFQN, colname, value string) { +func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMetadataContext, resolvedRelation RelationFQN, colname, value string) error { /* do not process non-distributed relations or columns not from relation distribution key */ if ds, err := qr.Mgr().GetRelationDistribution(context.TODO(), resolvedRelation.RelationName); err != nil { - return + return nil } else { // TODO: optimize ok := false @@ -201,12 +205,12 @@ func (qr *ProxyQrouter) RecordDistributionKeyColumnValueOnRFQN(meta *RoutingMeta } if !ok { // some junk column - return + return nil } } // will not work not ints - meta.RecordConstExpr(resolvedRelation, colname, value) + return meta.RecordConstExpr(resolvedRelation, colname, value) } // TODO : unit tests @@ -232,30 +236,27 @@ func (qr *ProxyQrouter) RecordDistributionKeyExprOnRFQN(meta *RoutingMetadataCon // ??? protoc violation } - qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(routeParam)) - return nil + return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(routeParam)) case *lyx.AExprSConst: - qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(e.Value)) - return nil + return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(e.Value)) case *lyx.AExprIConst: val := fmt.Sprintf("%d", e.Value) - qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(val)) - return nil + return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, string(val)) default: return ComplexQuery } } -func (qr *ProxyQrouter) RecordDistributionKeyColumnValue(meta *RoutingMetadataContext, alias, colname, value string) { +func (qr *ProxyQrouter) RecordDistributionKeyColumnValue(meta *RoutingMetadataContext, alias, colname, value string) error { resolvedRelation, err := meta.ResolveRelationByAlias(alias) if err != nil { // failed to resolve relation, skip column meta.unparsed_columns[colname] = struct{}{} - return + return nil } - qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, value) + return qr.RecordDistributionKeyColumnValueOnRFQN(meta, resolvedRelation, colname, value) } // routeByClause de-parses sharding column-value pair from Where clause of the query @@ -281,27 +282,37 @@ func (qr *ProxyQrouter) routeByClause(ctx context.Context, expr lyx.Node, meta * switch rght := texpr.Right.(type) { case *lyx.ParamRef: if rght.Number <= len(meta.params) { - qr.RecordDistributionKeyColumnValue(meta, alias, colname, string(meta.params[rght.Number-1])) + if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, string(meta.params[rght.Number-1])); err != nil { + return err + } } // else error out? case *lyx.AExprSConst: // TBD: postpone routing from here to root of parsing tree - qr.RecordDistributionKeyColumnValue(meta, alias, colname, rght.Value) + if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, rght.Value); err != nil { + return err + } case *lyx.AExprIConst: // TBD: postpone routing from here to root of parsing tree // maybe expimely inefficient. Will be fixed in SPQR-2.0 - qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", rght.Value)) + if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", rght.Value)); err != nil { + return err + } case *lyx.AExprList: if len(rght.List) != 0 { expr := rght.List[0] switch bexpr := expr.(type) { case *lyx.AExprSConst: // TBD: postpone routing from here to root of parsing tree - qr.RecordDistributionKeyColumnValue(meta, alias, colname, bexpr.Value) + if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, bexpr.Value); err != nil { + return err + } case *lyx.AExprIConst: // TBD: postpone routing from here to root of parsing tree // maybe expimely inefficient. Will be fixed in SPQR-2.0 - qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", bexpr.Value)) + if err := qr.RecordDistributionKeyColumnValue(meta, alias, colname, fmt.Sprintf("%d", bexpr.Value)); err != nil { + return err + } } } case *lyx.FuncApplication: @@ -900,7 +911,7 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s ok := true - var hashedKey []byte + var hashedKeys [][]byte // TODO: multi-column routing. This works only for one-dim routing for i := 0; i < len(distrKey); i++ { @@ -913,20 +924,22 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s col := distrKey[i].Column - val, valOk := meta.exprs[rfqn][col] + vals, valOk := meta.exprs[rfqn][col] if !valOk { ok = false break } - hashedKey, err = hashfunction.ApplyHashFunction([]byte(val), hf) - - spqrlog.Zero.Debug().Str("key", meta.exprs[rfqn][col]).Str("hashed key", string(hashedKey)).Msg("applying hash function on key") + hashedKeys = make([][]byte, len(vals)) + for i, val := range vals { + hashedKeys[i], err = hashfunction.ApplyHashFunction([]byte(val), hf) + spqrlog.Zero.Debug().Str("key", meta.exprs[rfqn][col][i]).Str("hashed key", string(hashedKeys[i])).Msg("applying hash function on key") - if err != nil { - spqrlog.Zero.Debug().Err(err).Msg("failed to apply hash function") - ok = false - break + if err != nil { + spqrlog.Zero.Debug().Err(err).Msg("failed to apply hash function") + ok = false + break + } } } @@ -934,24 +947,24 @@ func (qr *ProxyQrouter) routeWithRules(ctx context.Context, stmt lyx.Node, sph s // skip this relation continue } + for _, hashedKey := range hashedKeys { + currroute, err := qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs) + if err != nil { + route_err = err + spqrlog.Zero.Debug().Err(route_err).Msg("temporarily skip the route error") + continue + } - currroute, err := qr.DeparseKeyWithRangesInternal(ctx, string(hashedKey), krs) - if err != nil { - route_err = err - spqrlog.Zero.Debug().Err(route_err).Msg("temporarily skip the route error") - continue - } - - spqrlog.Zero.Debug(). - Interface("currroute", currroute). - Str("table", rfqn.RelationName). - Msg("calculated route for table/cols") - - route = routingstate.Combine(route, routingstate.ShardMatchState{ - Route: currroute, - TargetSessionAttrs: tsa, - }) + spqrlog.Zero.Debug(). + Interface("currroute", currroute). + Str("table", rfqn.RelationName). + Msg("calculated route for table/cols") + route = routingstate.Combine(route, routingstate.ShardMatchState{ + Route: currroute, + TargetSessionAttrs: tsa, + }) + } } if route == nil && route_err != nil { return nil, route_err diff --git a/router/qrouter/proxy_routing_test.go b/router/qrouter/proxy_routing_test.go index 8c62d3c9e..a72506229 100644 --- a/router/qrouter/proxy_routing_test.go +++ b/router/qrouter/proxy_routing_test.go @@ -320,6 +320,49 @@ func TestCTE(t *testing.T) { TargetSessionAttrs: "any", }, }, + { + query: ` + WITH xxxx AS ( + SELECT * from t where i = 1 + ), + zzzz AS ( + UPDATE t + SET a = 0 + WHERE i = 12 + ) + SELECT * FROM xxxx; + `, + err: nil, + exp: routingstate.SkipRoutingState{}, + }, + { + query: ` + WITH xxxx AS ( + SELECT * from t where i = 1 + ), + zzzz AS ( + UPDATE t + SET a = 0 + WHERE i = 2 + ) + SELECT * FROM xxxx; + `, + err: nil, + exp: routingstate.ShardMatchState{ + Route: &routingstate.DataShardRoute{ + Shkey: kr.ShardKey{ + Name: "sh1", + }, + Matchedkr: &kr.KeyRange{ + ShardID: "sh1", + ID: "id1", + Distribution: distribution, + LowerBound: []byte("1"), + }, + }, + TargetSessionAttrs: "any", + }, + }, } { parserRes, err := lyx.Parse(tt.query) @@ -327,7 +370,11 @@ func TestCTE(t *testing.T) { tmp, err := pr.Route(context.TODO(), parserRes, session.NewDummyHandler(distribution)) - assert.NoError(err, "query %s", tt.query) + if tt.err == nil { + assert.NoError(err, "query %s", tt.query) + } else { + assert.Error(err, "query %s", tt.query) + } assert.Equal(tt.exp, tmp, tt.query) } @@ -612,6 +659,24 @@ func TestSingleShard(t *testing.T) { err: nil, }, + { + query: "SELECT * FROM t WHERE i = 12 AND j = 1;", + exp: routingstate.ShardMatchState{ + Route: &routingstate.DataShardRoute{ + Shkey: kr.ShardKey{ + Name: "sh2", + }, + Matchedkr: &kr.KeyRange{ + ShardID: "sh2", + ID: "id2", + Distribution: distribution, + LowerBound: []byte("11"), + }, + }, + TargetSessionAttrs: "any", + }, + err: nil, + }, { query: "SELECT * FROM t WHERE i = 12 UNION ALL SELECT * FROM xxmixed WHERE i = 22;", exp: routingstate.ShardMatchState{ diff --git a/test/regress/tests/router/expected/single_shard_joins.out b/test/regress/tests/router/expected/single_shard_joins.out index ed322c854..15d75710c 100644 --- a/test/regress/tests/router/expected/single_shard_joins.out +++ b/test/regress/tests/router/expected/single_shard_joins.out @@ -47,7 +47,7 @@ NOTICE: send query to shard(s) : sh2 12 | 13 (2 rows) -SELECT * FROM sshjt1 WHERE i = 12 AND j =1; +SELECT * FROM sshjt1 WHERE i = 12 AND j = 1; NOTICE: send query to shard(s) : sh2 i | j ---+--- diff --git a/test/regress/tests/router/sql/single_shard_joins.sql b/test/regress/tests/router/sql/single_shard_joins.sql index f78c4bea6..aae5b28f8 100644 --- a/test/regress/tests/router/sql/single_shard_joins.sql +++ b/test/regress/tests/router/sql/single_shard_joins.sql @@ -14,7 +14,7 @@ INSERT INTO sshjt1 (i, j) VALUES(12, 12); INSERT INTO sshjt1 (i, j) VALUES(12, 13); SELECT * FROM sshjt1 WHERE i = 12; -SELECT * FROM sshjt1 WHERE i = 12 AND j =1; +SELECT * FROM sshjt1 WHERE i = 12 AND j = 1; SELECT * FROM sshjt1 a join sshjt1 b WHERE a.i = 12 ON TRUE; SELECT * FROM sshjt1 a join sshjt1 b ON TRUE WHERE a.i = 12;